Saturday, January 26, 2013

Graphs - Finding the Minimum Spanning Tree - A Generic Scala implementation of Prim's algorithm



In this post I provide a basic implementation of Prim's algorithm in Scala.

I also tried to make it generic, so that it works with any edge cost type that can be ordered.

Here is how I defined the Node, Edge, and UndirectedGraph classes (only the declaration of UndirectedGraph is shown here):

case class Node(name: String)
case class Edge[T <% Ordered[T]](xNodeName: String, yNodeName: String, cost: T)
case class UndirectedGraph[T <% Ordered[T]](nodes: List[Node], edgetList: List[Edge[T]])

Here is the declaration of the method that computes the minimum spanning tree:


def computeMinimumSpanningTreeUsingPrim(): UndirectedGraph[T]

I am also using this opportunity to introduce "views and view bounds", another great Scala feature. I won't go into much detail about them here.

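In short, "T <% Ordered[T]" means that an implicit conversion (a "view") from T to Ordered[T] must be available. As long as such a view exists, the algorithm works for any cost type T, for example Int or Double.

Note that "T <% Ordered[T]" is one of the most commonly used view bounds in Scala.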
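Under the hood, a view bound such as "T <% Ordered[T]" adds an extra implicit parameter of type T => Ordered[T]. View bounds were later deprecated in Scala 2, and Scala 3 no longer accepts them; the usual replacement is a context bound on Ordering. Here is a minimal sketch in that style, not taken from the original code: a generic "cheapest item" helper that mirrors what minCost does below, including keeping the first item on ties. The names GenericCost and cheapest are hypothetical.

```scala
import scala.math.Ordering.Implicits._

// Hypothetical helper, not part of the post's code: the context bound T: Ordering
// plays the role of the view bound T <% Ordered[T] used in the original classes.
object GenericCost {
  // Returns the cheapest (name, cost) pair, or None for an empty list.
  // On ties the first pair is kept, as minCost does in the original code.
  def cheapest[T: Ordering](items: List[(String, T)]): Option[(String, T)] =
    items.foldLeft(Option.empty[(String, T)]) {
      case (Some(best), item) if best._2 <= item._2 => Some(best)
      case (_, item)                                => Some(item)
    }
}
```

For example, GenericCost.cheapest(List("A-B" -> 9, "A-F" -> 6, "A-G" -> 3)) returns Some(("A-G", 3)). The same call works unchanged with Double or String costs, because the standard library provides an Ordering for each of them.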


Wikipedia also has a good explanation of Prim's algorithm, so I won't go into more detail here.

Here is the link:

http://en.wikipedia.org/wiki/Prim%27s_algorithm



And here is the code:


package org.madeforall.graph.mst.prim
/**
 * @author Nicolae Caralicea
 * @version 1.0, 26/01/2013
 */

case class Node(name: String)
case class Edge[T <% Ordered[T]](xNodeName: String, yNodeName: String, cost: T)

case class UndirectedGraph[T <% Ordered[T]](nodes: List[Node], edgetList: List[Edge[T]]) {
  case class NodesEdge[T <% Ordered[T]](xNode: Node, yNode: Node, cost: T)

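  def computeMinimumSpanningTreeUsingPrim(): UndirectedGraph[T] =
    nodes match {
      case Nil => UndirectedGraph(Nil, Nil) // an empty graph has an empty spanning tree
      case start :: rest => computeMinimumSpanningTree(rest, UndirectedGraph(List(start), Nil))
    }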

  @annotation.tailrec
  private def computeMinimumSpanningTree(leftNodes: List[Node], mst: UndirectedGraph[T]): UndirectedGraph[T] =
    leftNodes match {
      case Nil => mst
      case _ =>
        // xNode belongs to leftNodes, and yNode belongs to mst
        minCost(cost(leftNodes, mst)) match {
          case None => throw new IllegalArgumentException("disconnected graph")
          case Some(minCostEdge) =>
            val updatedLeftNodes = leftNodes diff List(minCostEdge.xNode)
            val updatedMst = UndirectedGraph(
              minCostEdge.xNode :: mst.nodes,
              toEdge(minCostEdge) :: mst.edgetList)
            computeMinimumSpanningTree(updatedLeftNodes, updatedMst)
        }
    }

  private def toEdge[T <% Ordered[T]](nodesEdge: NodesEdge[T]): Edge[T] =
    Edge(nodesEdge.xNode.name, nodesEdge.yNode.name, nodesEdge.cost)

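  // For every node outside the tree, its cheapest edge into the tree (nodes with no such edge are skipped)
  private def cost(outsideNodes: List[Node], graph: UndirectedGraph[T]): List[NodesEdge[T]] =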
    for {
      node <- outsideNodes
      minCostEdgeNodeToGraph <- minCostEdgeOfNodeToGraph(node, graph.nodes)
    } yield minCostEdgeNodeToGraph

  // The cheapest edge in the list (the first one on ties), or None for an empty list
  private def minCost[T <% Ordered[T]](nodesEdgeList: List[NodesEdge[T]]): Option[NodesEdge[T]] =
    nodesEdgeList.foldLeft(Option.empty[NodesEdge[T]]) {
      case (Some(best), itm) if best.cost <= itm.cost => Some(best)
      case (_, itm) => Some(itm)
    }

  // The cost of the edge joining nodeA and nodeB (in either direction), if there is one
  private def cost(nodeA: Node, nodeB: Node): Option[T] =
    edgetList.find(item =>
      (item.xNodeName == nodeA.name && item.yNodeName == nodeB.name) ||
        (item.yNodeName == nodeA.name && item.xNodeName == nodeB.name)
    ).map(_.cost)

  // The cheapest edge between node and any node of graph (the nodes already in the tree), if any
  private def minCostEdgeOfNodeToGraph(node: Node, graph: List[Node]): Option[NodesEdge[T]] = {
    val edges = for {
      n <- graph
      cost <- cost(n, node)
    } yield NodesEdge(node, n, cost)
    minCost(edges)
  }
}
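At every step, the implementation above finds the cheapest edge from each remaining node to the tree by scanning the whole edge list for every (outside node, tree node) pair. That is on the order of V² · E comparisons per added node, which is fine for small graphs but slow for large ones. A common alternative, the "lazy" heap-based variant of Prim's algorithm, keeps candidate edges in a binary heap (scala.collection.mutable.PriorityQueue) and discards edges whose far end is already in the tree, for O(E log E) work overall. The sketch below is mine, not the project's code. The names SimpleEdge and PrimPQ are hypothetical, node names are assumed to be distinct, every edge endpoint is assumed to appear in the node list, and an Ordering context bound replaces the view bound. Instead of throwing, it returns None for a disconnected graph.

```scala
import scala.collection.mutable

// Hypothetical names, not from the post: SimpleEdge and PrimPQ.
final case class SimpleEdge[T](from: String, to: String, cost: T)

object PrimPQ {
  /** Lazy Prim's algorithm with a binary heap.
    * Returns the tree edges, each oriented from the tree towards the node it adds,
    * or None if the graph is disconnected. Node names are assumed to be distinct,
    * and every edge endpoint is assumed to appear in nodes.
    */
  def mst[T](nodes: List[String], edges: List[SimpleEdge[T]])(implicit ord: Ordering[T]): Option[List[SimpleEdge[T]]] =
    nodes match {
      case Nil => Some(Nil)
      case start :: _ =>
        // Undirected adjacency: every edge is stored once from each endpoint.
        val adjacency: Map[String, List[SimpleEdge[T]]] =
          edges
            .flatMap(e => List(e, SimpleEdge(e.to, e.from, e.cost)))
            .groupBy(_.from)

        // PriorityQueue dequeues the largest element first, so reverse the cost ordering.
        val heap = mutable.PriorityQueue.empty[SimpleEdge[T]](Ordering.by[SimpleEdge[T], T](_.cost).reverse)
        val inTree = mutable.Set(start)
        val treeEdges = mutable.ListBuffer.empty[SimpleEdge[T]]

        heap ++= adjacency.getOrElse(start, Nil)
        while (heap.nonEmpty && inTree.size < nodes.size) {
          val e = heap.dequeue()
          // Skip stale entries whose far end joined the tree after they were queued.
          if (!inTree.contains(e.to)) {
            inTree += e.to
            treeEdges += e
            heap ++= adjacency.getOrElse(e.to, Nil).filterNot(next => inTree.contains(next.to))
          }
        }
        if (inTree.size == nodes.size) Some(treeEdges.toList) else None
    }
}
```

On the graph used in the test below, this sketch should return 9 edges with the same total cost (38) as the expected tree. Because the heap breaks ties between equal costs in no particular order, the individual edges could differ from that list if the graph has more than one minimum spanning tree.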

To test computeMinimumSpanningTreeUsingPrim, you can use something like this:

package org.madeforall.extractor.test
import org.scalatest._
import org.scalatest.matchers._
import org.madeforall.graph.mst.prim._
/**
 * @author Nicolae Caralicea
 * @version 1.0, 26/01/2013
 */
class TestMst extends FlatSpec with ShouldMatchers {

  "The minimum spanning tree computed with Prim's algorithm" should "match the expected tree" in {
    val graph: UndirectedGraph[Double] = UndirectedGraph(
      List(
        Node("A"), Node("B"), Node("C"), Node("D"),
        Node("E"), Node("F"), Node("G"), Node("H"),
        Node("M"), Node("N")),
      List(
        Edge("A", "B", 9.0),
        Edge("A", "F", 6),
        Edge("A", "G", 3),
        Edge("B", "G", 9),
        Edge("B", "M", 8),
        Edge("B", "C", 18),
        Edge("C", "M", 10),
        Edge("C", "N", 3),
        Edge("C", "D", 4),
        Edge("D", "N", 1),
        Edge("D", "E", 4),
        Edge("E", "F", 9),
        Edge("E", "H", 9),
        Edge("E", "M", 7),
        Edge("F", "G", 4),
        Edge("F", "H", 2),
        Edge("H", "G", 2),
        Edge("H", "M", 8),
        Edge("M", "N", 9),
        Edge("M", "G", 9),
        Edge("N", "E", 5)))
    assert(graph.computeMinimumSpanningTreeUsingPrim() === UndirectedGraph[Double](
      List(Node("B"), Node("C"), Node("N"), Node("D"), Node("E"), Node("M"), Node("F"), Node("H"), Node("G"), Node("A")),
      List(Edge("B", "M", 8), Edge("C", "N", 3), Edge("N", "D", 1), Edge("D", "E", 4), Edge("E", "M", 7),
        Edge("M", "H", 8), Edge("F", "H", 2), Edge("H", "G", 2), Edge("G", "A", 3))))
  }
    
}
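As a sanity check on the expected result in this test, independent of Prim's algorithm, you can verify that the nine expected edges really form a spanning tree of the ten nodes and add up their costs. Here is a minimal sketch of such a check. SpanningTreeCheck is a hypothetical helper, not part of the project, and it assumes every edge endpoint appears in the node list. It uses a small union-find structure and relies on the fact that an edge list over n nodes is a spanning tree exactly when it has n - 1 edges and closes no cycle.

```scala
import scala.collection.mutable

// Hypothetical helper, not part of the project: checks a candidate spanning tree.
object SpanningTreeCheck {
  // True if the edges connect all nodes without a cycle, using exactly nodes.size - 1 edges.
  // Every edge endpoint is assumed to appear in nodes.
  def isSpanningTree(nodes: List[String], edges: List[(String, String, Double)]): Boolean = {
    val parent = mutable.Map.empty[String, String] ++= nodes.map(n => n -> n)

    // Union-find root lookup with path compression.
    def find(x: String): String = {
      val p = parent(x)
      if (p == x) x
      else {
        val root = find(p)
        parent(x) = root
        root
      }
    }

    val acyclic = edges.forall { case (a, b, _) =>
      val (ra, rb) = (find(a), find(b))
      if (ra == rb) false // this edge would close a cycle
      else { parent(ra) = rb; true }
    }
    edges.size == nodes.size - 1 && acyclic
  }

  // Sum of the edge costs.
  def totalCost(edges: List[(String, String, Double)]): Double = edges.map(_._3).sum
}
```

With the ten nodes and the nine expected edges from the test, isSpanningTree returns true and totalCost returns 38.0.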

If you want to use this code or contribute to the project in any way, here is the link to its GitHub repository:

https://github.com/ncaralicea/madeforall


