Wednesday, April 1, 2009

Shunting Yard Algorithm

So, the next JUG meeting is introducing the concept of a code kata. One suggestion for an example topic was the Shunting Yard Algorithm. I remember using an HP 42S RPN calculator all through school, and well into my professional life. So, I thought I'd give it a go.

The Wikipedia entry outlines the algorithm in English pseudocode, which can be implemented fairly easily in Scala like this:

package org.okcjug.kata.shuntingyard

import scala.util.parsing.combinator._
import scala.util.matching.Regex
import scala.util.parsing.input.CharSequenceReader
import scala.collection.mutable.Stack
import scala.collection.mutable.Queue

object ShuntingYardWikipediaParser extends RegexParsers {

  // the actual algorithm
  def rpn(s: String) = {
    val stack = new Stack[String]
    val queue = new Queue[String]

    for (t: String <- parseAll(tokens, new CharSequenceReader(s)).get) {

      if (t isNumber) { queue += t }
      if (t isOperator) {
        while ((!(stack isEmpty) && (stack.top isOperator)) && (
            ((t isLeftAssociative) && (t isLowerOrEqualPrecedenceThan stack.top)) ||
            ((t isRightAssociative) && (t isLowerPrecedenceThan stack.top)))) {
          queue += stack.pop
        }
        stack push t
      }
      if (t isLeftParenthesis) { stack push t }
      if (t isRightParenthesis) {
        while (!(stack isEmpty) && !(stack.top isLeftParenthesis)) { queue += stack.pop }
        if (stack isEmpty) error("mismatched parens") else stack.pop // discard the "("
      }
    }
    while (!(stack isEmpty)) {
      if (stack.top isLeftParenthesis) { error("mismatched parens") }
      queue += stack.pop
    }
    queue.mkString(" ")
  }

  // definitions
  val number = """(\d+(\.\d+)?)""".r
  val product = """[\*\/]""".r
  val term = """[\+\-]""".r
  val exponent = """\^""".r
  val negation = """\!""".r
  val leftParens = """\(""".r
  val rightParens = """\)""".r
  val parens = leftParens | rightParens
  val token = number | product | term | exponent | negation | parens ^^ { case s: String => s }
  val tokens = token*

  // tricky implicit conversions to make the syntax prettier
  implicit def stringToSmartToken(s: String): SmartToken = { SmartToken(s) }
  case class SmartToken(s: String) {
    implicit def regexMatches(r: Regex): Boolean = { s.matches(r.toString) }
    def isNumber: Boolean = number
    def isOperator = product || term || exponent || negation
    def isLeftAssociative = isOperator && ( product || term || negation )
    def isRightAssociative = isOperator && exponent
    def isLeftParenthesis: Boolean = leftParens
    def isRightParenthesis: Boolean = rightParens
    def precedence = {
      if (term) 1
      else if (product) 2
      else if (exponent) 3
      else if (negation) 4
      else error("Not an operator")
    }
    def isLowerPrecedenceThan(o: String) = {
      (s precedence) < (o precedence)
    }
    def isLowerOrEqualPrecedenceThan(o: String) = {
      (s precedence) <= (o precedence)
    }
  }
}



Note that I ignored the possibility of a function as described in the Wikipedia entry, but added the concept of a unary operator (negation, which I represent as a prefixed "!").

The interesting part starts at line 41 of the listing (counting from the package declaration), under the "// definitions" comment. Through line 47 it holds the regex definitions of the tokens to be read. Scala can convert a string to a regex using the ".r" method.
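As a small illustration of the ".r" conversion, here is a minimal sketch that reuses the number pattern from the listing. The RegexDemo object and its isNumber helper are names made up for this example; they are not part of the kata code.

```scala
import scala.util.matching.Regex

object RegexDemo {
  // Same pattern as the `number` definition in the listing above.
  val number: Regex = """(\d+(\.\d+)?)""".r

  // String.matches only succeeds if the whole string matches the pattern.
  // (isNumber is a hypothetical helper for this sketch.)
  def isNumber(s: String): Boolean = s.matches(number.toString)
}
```

With this, RegexDemo.isNumber("3.14") is true, while "3." and "+" are rejected because the pattern has to cover the entire string.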

Lines 48 - 50 (parens, token and tokens) take the regular expressions and combine them into parser combinators, which produce string tokens.

To make the string tokens syntactically pleasing, we implicitly convert them into a "SmartToken" (lines 53 - 75: the stringToSmartToken conversion and the SmartToken case class), which has methods relevant to this particular algorithm. Another interesting thing is line 55 (the implicit regexMatches method), which auto-magically lets a regex (defined in an outer scope) be treated as a boolean: whether that regex matches the SmartToken's wrapped string. Yee-haw! A few more refactoring passes and it will look like Ruby while still allowing you to argue with the compiler over types.
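The String-to-wrapper half of that trick can be shown in isolation. The following is only a sketch: ImplicitDemo, Token, stringToToken and classify are hypothetical names, the regex-to-Boolean conversion is replaced by an explicit private helper, and the scala.language.implicitConversions import assumes a compiler newer than the one this post was written against (it only silences a feature warning).

```scala
import scala.language.implicitConversions
import scala.util.matching.Regex

object ImplicitDemo {
  val number: Regex = """\d+(\.\d+)?""".r
  val operator: Regex = """[-+*/^!]""".r

  // Hypothetical wrapper, analogous to SmartToken but with explicit matching.
  case class Token(s: String) {
    private def matches(r: Regex): Boolean = s.matches(r.toString)
    def isNumber: Boolean = matches(number)
    def isOperator: Boolean = matches(operator)
  }

  // With this conversion in scope, a plain String gains Token's methods.
  implicit def stringToToken(s: String): Token = Token(s)

  // `t` is a String; the compiler inserts stringToToken(t) where needed.
  def classify(t: String): String =
    if (t.isNumber) "number"
    else if (t.isOperator) "operator"
    else "other"
}
```

Calling ImplicitDemo.classify("42") goes through stringToToken behind the scenes, which is the same mechanism that lets the algorithm above write (t isNumber) on a bare String.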

But, I'll be honest. If I hadn't found the Wikipedia description of the algorithm, I probably wouldn't have implemented it the same way. Also, if I'm going to take the trouble to parse a language, I might as well build an abstract syntax tree.

Here is my elaborate AST model:

package org.okcjug.kata.shuntingyard

import scala.util.parsing.combinator._

abstract class Expression {
  def rpn: List[String]
}

case class NumberExpr(n: String) extends Expression {
  def rpn = List(n)
}

abstract class Operator extends Expression
case class UnaryOp(op: String, e: Expression) extends Operator {
  def rpn = e.rpn ::: List(op)
}
case class BinaryOp(op: String, l: Expression, r: Expression) extends Operator {
  def rpn = l.rpn ::: r.rpn ::: List(op)
}
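To see how the rpn methods amount to a post-order walk of the tree, here is a minimal sketch that builds a tree by hand, without any parser. The AstDemo object is a made-up wrapper holding local copies of the classes (minus the intermediate Operator class) so that the example compiles on its own.

```scala
object AstDemo {
  // Local copies of the post's AST classes so the sketch is self-contained.
  sealed abstract class Expression { def rpn: List[String] }
  case class NumberExpr(n: String) extends Expression {
    def rpn = List(n)
  }
  case class UnaryOp(op: String, e: Expression) extends Expression {
    def rpn = e.rpn ::: List(op)
  }
  case class BinaryOp(op: String, l: Expression, r: Expression) extends Expression {
    def rpn = l.rpn ::: r.rpn ::: List(op)
  }

  // Hand-built tree for the input "3 + 4 * !2".
  val tree: Expression =
    BinaryOp("+", NumberExpr("3"),
      BinaryOp("*", NumberExpr("4"), UnaryOp("!", NumberExpr("2"))))

  // Operands are emitted before their operator: "3 4 2 ! * +"
  val output: String = tree.rpn.mkString(" ")
}
```

Each node emits its children first and its own operator last, so operator precedence lives entirely in the shape of the tree.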



And, I can parse it with this object:

package org.okcjug.kata.shuntingyard

import scala.util.parsing.combinator._
import scala.util.parsing.input.CharSequenceReader

object ShuntingYardParser extends RegexParsers {

  def expression(s: String) = parseAll(expr, new CharSequenceReader(s))
  def rpn(e: Expression) = e.rpn.mkString(" ")
  def rpn(s: String) = {
    val e = expression(s)
    if (e.successful) {
      e.get.rpn.mkString(" ")
    }
    else {
      error("Could not parse: " + s)
    }
  }

  // parser definitions
  def expr: Parser[Expression] = term
  def num = """\d+(\.\d+)?""".r | failure("Some number expected")
  def number = num ^^ { case s => NumberExpr(s) }
  def group = "(" ~> expr <~ ")" ^^ { case e => e }
  def singular: Parser[Expression] = number | group | unaryOp | failure("Some expression expected")
  def unaryOp = "!" ~ singular ^^ { case s ~ e => UnaryOp(s, e) }
  def exponent = rightChain(singular, "^")
  def product = leftChain(exponent, "*"|"/")
  def term = leftChain(product, "+"|"-")

  // I think this is redundant, but couldn't figure out the chainx1(...) methods in scala's Parser class
  def rightChain(p: Parser[Expression], x: Parser[String]) = {
    ((p ~ x)*) ~ p ^^ {
      case Nil ~ tail => tail
      case rest ~ tail => rest.foldRight(tail) {
        case (p ~ s, z) => BinaryOp(s, p, z)
      }
    }
  }
  def leftChain(p: Parser[Expression], x: Parser[String]) = {
    p ~ ((x ~ p)*) ^^ {
      case head ~ Nil => head
      case head ~ rest => rest.foldLeft(head) {
        case (z, s ~ p) => BinaryOp(s, z, p)
      }
    }
  }

}



All of the actual parsing is defined in lines 21 - 29 of the listing (expr through term, under the "// parser definitions" comment). The result is an AST that can be output in RPN form.

The really tricky part (for me) about this code was defining it in such a way that no recursive parser appears as the first (leftmost) term of a sequential combination. Such a left-recursive definition leads to immediate stack-overflow purgatory, because the parser recurses infinitely without consuming any input. Surprisingly, the compiler/type-checker allows it, and I've heard that it might be supported in a future implementation.
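The same idea can be sketched without parser combinators. The example below is a simplification, not the code above: it assumes the input is already split into tokens and only handles numbers joined by "+" and "-". LeftChainDemo and termRpn are hypothetical names. It shows the iterative shape (a head followed by repeated operator/operand pairs) that leftChain uses instead of a left-recursive rule.

```scala
object LeftChainDemo {
  // Grammar in iterative form: term ::= number (("+" | "-") number)*
  // The left-recursive form, term ::= term "+" number, would have the
  // function call itself before consuming a token, so it would never return.
  def termRpn(tokens: List[String]): List[String] = tokens match {
    case Nil => sys.error("Some number expected")
    case head :: rest =>
      // Fold each (operator, operand) pair onto the result, left to right.
      rest.grouped(2).foldLeft(List(head)) {
        case (acc, List(op @ ("+" | "-"), operand)) => acc ::: List(operand, op)
        case (_, bad) => sys.error("Unexpected input: " + bad.mkString(" "))
      }
  }
}
```

For example, LeftChainDemo.termRpn(List("3", "+", "4", "-", "2")) yields the tokens of "3 4 + 2 -", the same left-associative result as the first test in the spec below.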

Well, do they work?

I guess we should actually test these implementations to see if they do what they are supposed to. It probably would have been better still to define some tests up front and let them guide the development of the algorithm. Scala has a wonderful testing facility called specs that allows you to specify tests in a very intuitive way. Drop a library in my path and make some specifications:


package org.okcjug.kata.shuntingyard

import org.specs._
import ShuntingYardParser._

object ShuntingYardTest extends Specification {

  def passStdTests(parser: { def rpn(s: String): String }) = {
    "be able to add and subtract" in {
      parser.rpn("3+4-2") must beEqual("3 4 + 2 -")
    }

    "do unary negation correctly" in {
      parser.rpn("! 3 + 2") must beEqual("3 ! 2 +")
      parser.rpn("!(3 + 2)^2") must beEqual("3 2 + ! 2 ^")
    }

    "parse Wikipedia example correctly" in {
      val input = "3 + 4 * 2 / (1-5) ^2 ^3"
      val result = parser.rpn(input)
      result must beEqual("3 4 2 * 1 5 - 2 3 ^ ^ / +")
    }
  }

  "ShuntingYardParser" should {
    passStdTests(ShuntingYardParser)
  }

  "ShuntingYardParser" can {
    "blow the stack" in {
      val sb = new StringBuffer
      val max = 1000
      for(i <- 1 to max) { sb append "(" }
      sb append "(4-3)"
      for(i <- 1 to max) { sb append "*(4-3))" }
      ShuntingYardParser.rpn(sb.toString) must throwA(new StackOverflowError)
    }
  }

  "ShuntingYardWikipediaParser" should {
    passStdTests(ShuntingYardWikipediaParser)

    "not blow the stack" in {
      val sb = new StringBuffer
      val max = 1000
      for(i <- 1 to max) { sb append "(" }
      sb append "(4-3)"
      for(i <- 1 to max) { sb append "*(4-3))" }
      ShuntingYardWikipediaParser.rpn(sb.toString) mustNot throwAn(new Exception)
    }
  }
}



This spec compiles to an executable class that produces this output:

Specification "ShuntingYardTest"
ShuntingYardParser should
+ be able to add and subtract
+ do unary negation correctly
+ parse Wikipedia example correctly

Total for SUT "ShuntingYardParser":
Finished in 0 second, 0 ms
3 examples, 4 expectations, 0 failure, 0 error

ShuntingYardParser can
+ blow the stack

Total for SUT "ShuntingYardParser":
Finished in 0 second, 0 ms
1 example, 1 expectation, 0 failure, 0 error

ShuntingYardWikipediaParser should
+ be able to add and subtract
+ do unary negation correctly
+ parse Wikipedia example correctly
+ not blow the stack

Total for SUT "ShuntingYardWikipediaParser":
Finished in 0 second, 0 ms
4 examples, 5 expectations, 0 failure, 0 error

Total for specification "ShuntingYardTest":
Finished in 0 second, 563 ms
8 examples, 10 expectations, 0 failure, 0 error



Notice that I'm testing both implementations of the parser with the same code, except for one test (blow the stack). This is because the AST version will blow up when it encounters input that nests thousands of levels deep. I verify this, and also verify that it doesn't affect the Wikipedia version. I am deeply embarrassed that a Wikipedia algorithm is more full-featured than my home-grown version.

What next?

Well, we could actually calculate the results of the input from either the AST or the RPN output. Both of these would be almost trivial to implement. We could randomly generate input strings, feed them into both parsers and compare the results. It is nice to have Scala's parser combinators as a tool in your belt. There is a dead zone between what can/should be done with one-off regular expressions and generating a full-fledged language with an external tool (JavaCC, ANTLR, etc.). Another option for these situations is to use an internal DSL (like the specs library I used). But, that's a topic for another post...
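As a rough idea of what calculating from the RPN form could look like, here is a minimal sketch. It is not part of the kata code: RpnEvaluator is a made-up name, the input is assumed to be the space-separated string that either rpn(...) method returns, and "!" is assumed to mean arithmetic negation.

```scala
object RpnEvaluator {
  // Evaluates a space-separated RPN string, using a List as the stack
  // (the head of the list is the top of the stack).
  def eval(rpn: String): Double = {
    val stack = rpn.trim.split("\\s+").foldLeft(List.empty[Double]) {
      case (r :: l :: rest, "+") => (l + r) :: rest
      case (r :: l :: rest, "-") => (l - r) :: rest
      case (r :: l :: rest, "*") => (l * r) :: rest
      case (r :: l :: rest, "/") => (l / r) :: rest
      case (r :: l :: rest, "^") => math.pow(l, r) :: rest
      case (x :: rest, "!")      => -x :: rest // assumption: "!" negates
      case (st, num)             => num.toDouble :: st
    }
    stack match {
      case result :: Nil => result
      case _             => sys.error("Malformed RPN: " + rpn)
    }
  }
}
```

For instance, RpnEvaluator.eval("3 4 + 2 -") gives 5.0. Because the fold keeps its stack on the heap, it should not share the recursion-depth problem of the AST parser, which matches the explicit-stack approach of the Wikipedia algorithm.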