Monday, April 1, 2013

Scala DSL for parsing and evaluating arithmetic expressions

In this post I want to show you a simple way to parse and evaluate arithmetic expressions using Scala parser combinators.

So, I will try to do the following:
  1. create a parser that recognizes complex arithmetic expressions
  2. parse an arithmetic expression and produce the parse result as a list of strings corresponding to the postfix notation of the original expression
  3. evaluate that postfix list to compute the value of the arithmetic expression
If you are not familiar with Scala parser combinators or postfix notation, I recommend reading the following resources to better understand the code:

Chapter 31, "Combinator Parsing", of Programming in Scala, First Edition, by Martin Odersky, Lex Spoon, and Bill Venners (December 10, 2008)

Reverse Polish notation

And here is the code:


import scala.util.parsing.combinator._
/**
 * Parses arithmetic expressions made of floating-point literals, the binary
 * operators `+`, `-`, `*`, `/` and parentheses, translates them to postfix
 * (Reverse Polish) notation, and evaluates the result.
 *
 * Every parser below yields the postfix token list of the text it matched,
 * e.g. `"1 + 2 * 3"` parses to `List("1", "2", "3", "*", "+")`.
 * `*` and `/` bind tighter than `+` and `-`; all four are left-associative.
 *
 * @author Nicolae Caralicea
 * @version 1.0, 04/01/2013
 */
class Arithm extends JavaTokenParsers {

  /** An expression: a term followed by any number of `+ term` / `- term` tails. */
  def expr: Parser[List[String]] = term ~ rep(addTerm | minusTerm) ^^ {
    case termValue ~ repValue => termValue ::: repValue.flatten
  }

  /** `+ term`, emitted as the term's postfix tokens followed by `"+"`. */
  def addTerm: Parser[List[String]] = "+" ~> term ^^ (_ ::: List("+"))

  /** `- term`, emitted as the term's postfix tokens followed by `"-"`. */
  def minusTerm: Parser[List[String]] = "-" ~> term ^^ (_ ::: List("-"))

  /** A term: a factor followed by any number of `* factor` / `/ factor` tails. */
  def term: Parser[List[String]] = factor ~ rep(multiplyFactor | divideFactor) ^^ {
    case factorValue ~ repFactor => factorValue ::: repFactor.flatten
  }

  /** `* factor`, emitted as the factor's postfix tokens followed by `"*"`. */
  def multiplyFactor: Parser[List[String]] = "*" ~> factor ^^ (_ ::: List("*"))

  /** `/ factor`, emitted as the factor's postfix tokens followed by `"/"`. */
  def divideFactor: Parser[List[String]] = "/" ~> factor ^^ (_ ::: List("/"))

  /** A factor: a numeric literal or a parenthesized sub-expression. */
  def factor: Parser[List[String]] = floatingPointConstant | parantExpr

  /**
   * A single numeric literal, as recognized by `floatingPointNumber`.
   * That rule accepts an optional leading `-`, so `2 * -3` is accepted.
   */
  def floatingPointConstant: Parser[List[String]] = floatingPointNumber ^^ (List(_))

  /** `( expr )`; the parentheses only group and emit no tokens of their own. */
  def parantExpr: Parser[List[String]] = "(" ~> expr <~ ")"

  /**
   * Parses `expression` and evaluates it.
   *
   * @param expression the arithmetic expression; whitespace between tokens is skipped
   * @return the value of the expression, computed with `Double` arithmetic
   *         (so division by zero yields `Infinity` or `NaN` rather than throwing)
   * @throws RuntimeException if `expression` is not a complete, well-formed expression;
   *                          the message is the parser's failure report
   */
  def evaluateExpr(expression: String): Double = parseAll(expr, expression) match {
    case Success(postfix, _) => evaluatePostfix(postfix)
    case failure: NoSuccess  => throw new RuntimeException(failure.toString)
  }

  /**
   * Evaluates a postfix token list using an immutable `List[Double]` as the
   * operand stack (head = top of stack).
   *
   * @param postfixExpressionList numeric tokens and the operators `+ - * /` in postfix order
   * @return the single value left on the stack
   * @throws IllegalArgumentException if the list is not a well-formed postfix expression
   *                                  (cannot happen for lists produced by [[expr]])
   */
  private def evaluatePostfix(postfixExpressionList: List[String]): Double = {
    // Pops the two topmost operands and pushes `f(left, right)`; the right operand is on top.
    def applyOperator(stack: List[Double], op: String, f: (Double, Double) => Double): List[Double] =
      stack match {
        case right :: left :: rest => f(left, right) :: rest
        case _ => throw new IllegalArgumentException(s"Operator '$op' is missing an operand")
      }

    val finalStack = postfixExpressionList.foldLeft(List.empty[Double]) { (stack, token) =>
      token match {
        case "*"    => applyOperator(stack, token, _ * _)
        case "/"    => applyOperator(stack, token, _ / _)
        case "+"    => applyOperator(stack, token, _ + _)
        case "-"    => applyOperator(stack, token, _ - _)
        case number => number.toDouble :: stack
      }
    }

    finalStack match {
      case result :: Nil => result
      case _ =>
        throw new IllegalArgumentException(
          s"Malformed postfix expression: ${postfixExpressionList.mkString(" ")}")
    }
  }
}

/**
 * Smoke test for [[Arithm]]: evaluates one nested expression with the DSL and
 * checks it against the same expression computed directly by the compiler.
 */
object TestArithmDSL {
  def main(args: Array[String]): Unit = {
    val input = "(12 + 4 * 6) * ((2 + 3 * ( 4 + 2 ) ) * ( 5 + 12 ))"
    // Reference value: the identical expression as Int arithmetic, widened to Double.
    val reference: Double = (12 + 4 * 6) * ((2 + 3 * (4 + 2)) * (5 + 12))
    val computed = new Arithm().evaluateExpr(input)
    assert(computed == reference)
  }
}