정글/알고리즘

[알고리즘] 최소 스패닝 트리 (Kruskal 알고리즘, Prime 알고리즘)

nkdev 2025. 3. 28. 21:19

최소 스패닝 트리

신장 트리 (Spanning Tree)

  • 그래프에서 일부 간선을 선택해 만든 트리
  • 최소 연결로 이루어짐 즉, 모든 노드가 가장 적은 간선 수로 이어져있는 경우에 해당
  • 정점이 n개일 때 간선이 n-1개이면 스패닝 트리
  • BFS, DFS로 신장 트리 찾기 가능 (탐색 도중 사용한 간선을 모으면 됨)

최소 비용 신장 트리 (Minimum Spanning Tree, MST)

  • 스패닝 트리 중 사용된 간선들의 가중치 합이 최소인 트리
  • 각 간선의 가중치가 동일하지 않을 경우, 단순히 적은 간선을 쓴다고 최소비용이 되는 건 아님
  • 사이클이 없어야 함
  • 구현 방법으로는 Kruskal, Prime 알고리즘이 있는데 둘 다 그리디한 방법임

Kruskal Algorithm

  • 간선을 비용이 낮은 것부터 v-1개 선택하며 떨어져있던 노드를 연결하는 방법 참고
  1. 간선들을 가중치의 오름차순으로 정렬
    • 간선의 양 끝 정점과 간선의 가중치를 멤버로 가지는 node클래스를 선언한다.
    • node클래스의 객체를 ArrayList에 넣은 후 가중치 기준으로 Collections.sort()처리한다.
    • node클래스 안에 Comparable의 compareTo()을 오버라이딩하면 무엇을 기준으로 정렬처리할지 지정할 수 있다. 참고
  2. 정렬된 간선 리스트에서 순서대로 간선 선택
    • 정렬된 list의 처음부터 순서대로 간선을 선택하면 가장 낮은 가중치 먼저 선택
    • union&find 알고리즘 사용하여 사이클을 형성하는 간선 제외 참고 링크
    • 간선을 선택할 때마다 간선의 양 끝 정점의 부모 노드(루트 노드)를 같게 만들어 두 정점이 같은 집합에 속하게 만듦
      새로운 간선의 양 끝 정점의 부모 노드가 같다면 그래프가 순환하는 것.
      노드 1과 2를 잇는 간선이 선택됨 -> 노드 1, 2의 부모 노드를 같게 만듦
      노드 2와 3을 잇는 간선이 선택됨 -> 노드 2, 3의 부모 노드를 같게 만듦
      이때 다음으로 가중치가 작은 간선의 양 끝이 노드 1, 3이라도 사이클이 생기기 때문에 선택하면 안 된다.
      노드 1, 3의 부모 노드가 같은지 아닌지 판별해서 사이클이 생기는지 알 수 있다.
  3. 해당 간선을 MST 집합에 추가
  4. 2-3번을 간선 개수만큼 반복

자바 코드 :

//백준 1197번 문제
import java.io.*;
import java.util.*;

public class Main {
    static List<node> list = new ArrayList<>();
    static int[] parent;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st;

        st = new StringTokenizer(br.readLine());
        int V = Integer.parseInt(st.nextToken());
        int E = Integer.parseInt(st.nextToken());
        int sum = 0;
		
        //부모 노드 저장
        parent = new int[V+1];
        for(int i=0; i<parent.length; i++){
            parent[i] = i;
        }
		
        //리스트에 노드를 저장
        for(int i=0; i<E; i++){
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int c = Integer.parseInt(st.nextToken());
            list.add(new node(a, b, c));
        }
        
        Collections.sort(list);
        
        //리스트에 저장된 노드를 차례로 꺼냄
        for(int i=0; i<list.size(); i++){
            node n = list.get(i);
            if(!isSameParent(n.a, n.b)){ //같은 집합에 속해있지 않으면 사이클 형성 안 되므로 union
                union(n.a, n.b);
                sum += n.c;
            }
        }
        bw.write(sum+" ");
        bw.close();
    }
    static void union(int a, int b){
        a = find(a);
        b = find(b);
        //if(a != b)
        //    parent[b] = a;
        if(x < y) parent[y] = x;
        else if(x > y) parent[x] = y;
        
    }
    static boolean isSameParent(int a, int b){
        a = find(a);
        b = find(b);
        if(a != b)
            return false;
        return true;
    }
    static int find(int x){
        if(parent[x] == x)
            return x;
        return parent[x] = find(parent[x]);
    }
}
class node implements Comparable<node>{
    int a, b, c;
    node(int a, int b, int c){
        this.a = a;
        this.b = b;
        this.c = c;
    }
    
    //가중치 기준으로 정렬
   public int compareTo(node n){
        return this.c - n.c;
    }
}

//find함수에서 마지막에 find(parent[a]); 을 리턴하면 parent[a] = find(parent[a])를 리턴하는 것보다 훨씬 오래 걸린다. (2배 이상 차이) 원인이 뭘까
//주의 : union할 때 x, y의 부모를 find()한 값으로 유니온해야 한다.

파이썬 코드 : 

v, e = map(int, input().split())
parent = [0]*(v+1)
for i in range(1, v+1):
    parent[i] = i

def find(x):
    if parent[x] != x: # 노드 x가 다른 노드(parent[x])와 연결되어 있다면
        parent[x] = find(parent[x]) # 그 다른 노드(parent[x])는 어떤 노드와 연결되어 있는지 찾기
    return parent[x]

def union(a, b):
    a = find(a)
    b = find(b)
    if a>b:
        parent[a] = parent[b]
    else:
        parent[b] = parent[a]

edges = []
total_cost = 0

for _ in range(e):
    a, b, cost = map(int, input().split())
    edges.append((cost, a, b))

edges.sort()

for i in range(e):
    cost, a, b = edges[i]
    if find(a) != find(b):
        union(a, b)
        total_cost += cost

print(total_cost)

 

Prime Algorithm

  • 임의의 정점에서 시작해 간선 비용이 낮은 것을 찾아 점차적으로 뻗어 나가는 방법 
  • 어떤 정점에서 시작하더라도 동일한 형태의 MST를 구할 수 있음
  • 우선순위 큐를 이용해 가중치가 최소인 간선을 구하면 O(ElogV)의 시간 복잡도로 MST 구할 수 있음
  1. 시작 노드를 선택하여 우선순위 큐에 (노드번호, 간선 비용) 형태로 넣는다.
  2. 우선순위 큐가 빌 때까지 다음을 반복한다.
    1. 큐에서 가중치가 최소인 노드를 꺼낸다.
    2. 아직 방문하지 않은 노드이면 방문처리 하고 MST에 추가한다.
    3. 꺼낸 노드와 연결된 모든 노드 중 아직 방문하지 않은 노드들을 우선순위 큐에 넣는다. 

파이썬 코드 :

import sys
import heapq

read = sys.stdin.readline

def prim(x):
  visited[x] = True # 시작 노드 방문
  route = [x]
  heap = graph[x] # 가중치가 최소인 간선을 선택하기 위한 최소힙(우선순위큐)
  heapq.heapify(heap) # 우선순위큐로 만들어줌
  res = 0 # 가중치를 저장할 변수
  
  while heap:
    weight, w = heapq.heappop(heap) # 힙(우선순위 큐)에서 최소값을 꺼내는 연산 O(logV)
    if visited[w] == False:
      visited[w] = True # 방문하지 않은 노드인 경우 방문 처리
      route.append(w)
      res += weight # 가중치를 더해준다
      for edge in graph[w]: # 현재 노드와 연결된 모든 간선을 확인 O(E)
        if visited[edge[1]] == False:
          heapq.heappush(heap, edge) # 방문여부를 판단해 아직 방문하지 않은 노드만 최소힙에 넣어줌
  return route, res
  
v, e = map(int, read().split()) # 노드의 수, 간선의 수
graph = [[] for _ in range(v + 1)] # 그래프 저장 리스트, 가중치와 종점을 저장
visited = [False] * (v + 1) # 방문 여부를 판단하기 위한 리스트

for _ in range(e):
  u, v, weight = map(int, read().split())
  # 무방향 그래프이기 때문에 시점과 종점을 모두 고려
  graph[u].append([weight, v]) # 최소힙을 사용하기 위해 가중치를 먼저 넣어준다
  graph[v].append([weight, u])

print(prim(1)) # 1번 노드에서 시작하는 경우

'''
입력:
8 10
1 2 13
1 3 15
1 8 17
2 7 4
3 4 11
3 5 13
4 5 12
5 8 9
6 7 3
7 8 6

출력:
([1, 2, 7, 6, 8, 5, 4, 3], 58)
'''

 

 

https://koosco.tistory.com/entry/Python-%ED%94%84%EB%A6%BC-%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98Prims-Algorithm