Prim's MST Algorithm

Prim's Minimum Spanning Tree (MST) algorithm is a greedy algorithm that finds a minimum spanning tree of a connected, undirected graph with weighted edges. A minimum spanning tree connects all the vertices of the graph using a subset of its edges whose total weight is as small as possible. The algorithm was first developed by Czech mathematician Vojtěch Jarník in 1930 and later rediscovered and popularized by American computer scientist Robert C. Prim in 1957. Prim's algorithm starts from an arbitrary vertex and grows the tree by repeatedly adding the minimum-weight edge that connects a vertex in the tree to a vertex not yet in the tree. It keeps the candidate edges in a priority queue (or another set-like data structure) so that the lightest one can be selected efficiently; the "eager" variant keeps only the lightest known edge for each vertex outside the tree. The process repeats until every vertex is in the tree. If the graph is not connected, restarting the algorithm from each vertex not yet reached produces a minimum spanning forest, one tree per connected component. One of the advantages of Prim's algorithm is its simplicity and ease of implementation using adjacency lists or adjacency matrices, making it a popular choice for solving real-world network design problems.
package org.gs.digraph

import org.gs.graph.{Edge, EdgeWeightedGraph}
import org.gs.queue.IndexMinPQ
import scala.annotation.tailrec
import scala.collection.mutable.Queue

/** Compute a minimum spanning tree (or forest) of an edge weighted graph using the eager
  * version of Prim's algorithm.
  *
  * Only the shortest known edge connecting each non-tree vertex to the tree remains on the
  * index priority queue; when a lighter edge is found its key is lowered with `decreaseKey`.
  * Prim's algorithm is restarted from every vertex that no earlier tree reached, so a
  * disconnected graph yields a minimum spanning forest with one tree per connected component.
  *
  * @constructor creates a new PrimMST with an EdgeWeightedGraph
  * @param g EdgeWeightedGraph
  * @see [[https://algs4.cs.princeton.edu/43mst/PrimMST.java.html]]
  * @author Scala translation by Gary Struthers from Java by Robert Sedgewick and Kevin Wayne.
  */
class PrimMST(g: EdgeWeightedGraph) {
  // edgeTo(v): lightest known edge joining v to the tree; stays null for each tree's root
  // and for vertices never reached.
  private val edgeTo = new Array[Edge](g.numV)
  // distTo(v): weight of edgeTo(v). Starts at +infinity (not Double.MaxValue) so that an edge
  // whose weight is exactly Double.MaxValue still counts as an improvement and is not dropped.
  private val distTo = Array.fill[Double](g.numV)(Double.PositiveInfinity)
  // marked(v): v has been added to a tree.
  private val marked = Array.fill[Boolean](g.numV)(false)
  private val pq = new IndexMinPQ[Double](g.numV)

  // Grow a tree from every vertex not reached by an earlier tree. The guard is evaluated
  // lazily, after each preceding prim call has marked its component.
  for {
    v <- 0 until g.numV
    if !marked(v)
  } prim(v)

  /** Add v to the tree and relax every edge from v to a vertex not yet in the tree. */
  private def scan(v: Int): Unit = {
    marked(v) = true

    def scanEdge(e: Edge): Unit = {
      val w = e.other(v)
      // Only a strictly lighter edge replaces w's current best connection to the tree.
      if (!marked(w) && e.weight < distTo(w)) {
        distTo(w) = e.weight
        edgeTo(w) = e
        if (pq.contains(w)) pq.decreaseKey(w, distTo(w)) else pq.insert(w, distTo(w))
      }
    }

    g.adj(v).foreach(scanEdge)
  }

  /** Grow one tree rooted at s until the priority queue is exhausted. */
  private def prim(s: Int): Unit = {
    distTo(s) = 0.0
    pq.insert(s, distTo(s))
    // s is the only entry on the (empty) queue here, so the first delMin returns s and scans
    // it; scanning s before the loop as well would just repeat that work.

    @tailrec
    def loop(): Unit = if (!pq.isEmpty) {
      scan(pq.delMin)
      loop()
    }

    loop()
  }

  /** returns sum of edge weights in the MST (summed over all trees if the graph is disconnected) */
  def weight(): Double = edges().foldLeft(0.0)(_ + _.weight)

  /** returns edges of the MST, in order of the index of the vertex each edge attached to its tree */
  def edges(): List[Edge] = edgeTo.iterator.filter(_ != null).toList
}

LANGUAGE:

DARK MODE: