Published on

Notes on Minimum Spanning Tree

Authors

Introduction

A minimum spanning tree (MST) of a connected, undirected, weighted graph is a tree that connects all of its nodes using a subset of the graph's edges whose total weight is as small as possible.

Algorithms

Two classic greedy algorithms solve this problem:

Prim’s Algorithm

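  • Vertex-based
  • Data structure: priority queue (min-heap)
  • Well suited to dense graphs: with an adjacency matrix and a plain array of keys it runs in O(V²); the heap-based version below, on adjacency lists, runs in O(E log V)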

The tree starts from a single initial node. At each step, the algorithm picks the lowest-weight edge that connects a node in the tree to a node outside it, and adds that edge and node to the tree. This repeats until every node is in the tree.

#include <algorithm>
#include <functional>
#include <limits>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

// Prim's Algorithm for Minimum Spanning Tree
// Returns total weight of MST
// Prim's Algorithm for Minimum Spanning Tree
// Returns total weight of MST
//
// graph[u] lists the {neighbor, weight} pairs of vertex u; for an undirected
// graph each edge should appear in the lists of both endpoints.
// A tree is grown from every vertex not yet reached, so for a disconnected
// graph the result is the total weight of the minimum spanning forest
// (the same quantity kruskalMST computes). An empty graph yields 0.
int primMST(const std::vector<std::vector<std::pair<int, int>>>& graph) {
    const int n = static_cast<int>(graph.size());
    std::vector<bool> visited(n, false);
    // key[v] = cheapest known edge weight connecting v to the current tree
    std::vector<int> key(n, std::numeric_limits<int>::max());

    // Min heap of {key, vertex}
    using Entry = std::pair<int, int>;
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> pq;

    int mstWeight = 0;

    for (int root = 0; root < n; root++) {
        if (visited[root]) continue; // already part of an earlier tree

        // Start a new tree at root (no edge weight is paid for the root)
        key[root] = 0;
        pq.push({0, root});

        while (!pq.empty()) {
            const int weight = pq.top().first;
            const int u = pq.top().second;
            pq.pop();

            // Outdated heap entry: u is already in the tree
            if (visited[u]) continue;

            visited[u] = true;
            mstWeight += weight;

            // Offer cheaper connections to neighbors still outside the tree
            for (const auto& edge : graph[u]) {
                const int v = edge.first;
                const int w = edge.second;
                if (!visited[v] && w < key[v]) {
                    key[v] = w;
                    pq.push({w, v});
                }
            }
        }
    }

    return mstWeight;
}

Kruskal’s Algorithm

  • Edge-based algorithm
  • Data structure: Union-Find (disjoint set)
  • Well suited to sparse graphs: sorting the edges dominates, giving O(E log E)

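The algorithm considers edges in increasing order of weight. It adds each edge to the result unless its two endpoints are already connected (i.e., already in the same component), because adding such an edge would create a cycle. The tree is complete once V − 1 edges have been added; otherwise the process ends when all edges have been considered.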

#include <vector>

// Disjoint Set data structure
// Disjoint Set (Union-Find) data structure with union by rank and path
// compression. Elements are the integers 0..n-1.
class DisjointSet {
    std::vector<int> parent; // parent[x] == x means x is the root of its set
    std::vector<int> rank;   // upper bound on the height of each root's tree
public:
    // Creates n singleton sets: initially every element is its own parent.
    DisjointSet(int n) : parent(n), rank(n, 0) {
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    // Returns the root of x's set, re-pointing every node on the path
    // directly at the root (path compression).
    int Find(int x) {
        if (parent[x] != x) {
            parent[x] = Find(parent[x]);
        }
        return parent[x];
    }

    // Merges the sets containing x and y, attaching the lower-rank root under
    // the higher-rank one. Returns true if a merge happened, false if x and y
    // were already in the same set.
    bool Union(int x, int y) {
        int px = Find(x), py = Find(y);
        if (px == py) return false; // already in the same set

        if (rank[px] < rank[py]) {
            parent[px] = py;
        } else if (rank[px] > rank[py]) {
            parent[py] = px;
        } else {
            parent[py] = px;
            rank[px]++;
        }
        return true;
    }
};

// Kruskal's Algorithm for Minimum Spanning Tree
// Returns total weight of MST
//
// graph[u] lists the {neighbor, weight} pairs of vertex u. Each undirected
// edge is taken from the list of its lower-numbered endpoint (u < v), so it
// must at least appear there; self-loops are ignored. For a disconnected
// graph the result is the total weight of the minimum spanning forest.
int kruskalMST(const std::vector<std::vector<std::pair<int, int>>>& graph) {
    const int n = static_cast<int>(graph.size());
    std::vector<std::tuple<int, int, int>> edges; // {weight, u, v}

    // Collect all edges
    for (int u = 0; u < n; u++) {
        for (const auto& edge : graph[u]) {
            const int v = edge.first;
            const int w = edge.second;
            if (u < v) { // Add each edge only once
                edges.emplace_back(w, u, v);
            }
        }
    }

    // Sort edges by weight (ties broken by endpoints)
    std::sort(edges.begin(), edges.end());

    DisjointSet ds(n);
    int mstWeight = 0;
    int edgesUsed = 0;

    // Process edges in ascending order of weight
    for (const auto& edge : edges) {
        const int w = std::get<0>(edge);
        const int u = std::get<1>(edge);
        const int v = std::get<2>(edge);

        // Union succeeds only if u and v were in different components,
        // i.e. the edge does not close a cycle.
        if (ds.Union(u, v)) {
            mstWeight += w;
            // A spanning tree on n vertices has exactly n - 1 edges;
            // no later edge can be accepted.
            if (++edgesUsed == n - 1) break;
        }
    }

    return mstWeight;
}

References

[1] GeeksforGeeks

[2] Wikipedia