정글/알고리즘

[백준/Python] 1197번 : 최소 스패닝 트리

nkdev 2025. 3. 30. 03:34

문제

https://www.acmicpc.net/problem/1197

풀이

크루스칼 알고리즘을 사용해 최소 스패닝 트리를 만들었다. 

시간 초과 원인?

가중치가 가장 큰 간선에서 비로소 스패닝 트리가 만들어질 수도 있으므로, 어차피 최악의 경우에는 모든 간선을 확인해야 함

시간 초과의 근본적인 원인은 union-find에 어떤 최적화도 적용되어 있지 않아 높이가 아주 높은 트리가 만들어질 수 있기 때문

해결 방법

union 함수에서 parent 합치기

# 시간 초과
def union(a, b):
    a_ = find(a)
    b_ = find(b)
    if a_>b_:
        p[a_]=b_
    else:
        p[b_]=a_
rank = [1] * (v+1)  # 트리의 랭크(높이) 저장
def union(a, b):
    a_ = find(a)
    b_ = find(b)
    if a_ != b_:
        if rank[a_] > rank[b_]:  # 랭크가 높은 쪽으로 합침
            p[b_] = a_
        elif rank[a_] < rank[b_]:
            p[a_] = b_
        else:
            p[b_] = a_
            rank[a_] += 1

find 함수 경로 압축하기 :

  • find(x) 호출 후 x의 부모들을 루트로 직접 연결하여 이후에 find(x)를 호출할 때 더 빠르게 부모를 찾을 수 있게 함
  • find()가 호출될 때마다 트리 높이가 낮아져 이후 연산을 O(1) 수준으로 최적화할 수 있음
# 시간 초과
def find(x):
    if p[x]==x:
        return x
    return find(p[x])

find(x)함수는 노드 x의 루트를 찾음. 하지만 x가 중간에 부모를 가리키고 있으면 루트를 찾을 때까지 부모를 계속 따라가야 함

이 때 find는 재귀적으로 계속 부모를 추적함

def find(x):
    if p[x]!=x:
        p[x] = find(p[x])
    return p[x]

find(x)를 호출할 때마다 부모들이 직접 루트로 연결되기 때문에 트리의 깊이가 계속 얕아짐

예를 들어 x가 부모 y를 가리키고 y가 부모 z를 가리킨다면 find(x)를 호출할 때마다 x는 z를 바로 가리키게 된다.

그 후에는 find(x)를 다시 호출해도 x는 바로 z를 찾을 수 있다.

 

코드

import sys, heapq
sys.setrecursionlimit(10**9)
v, e = map(int, sys.stdin.readline().split())
graph = []
for i in range(e):
    x, y, w = map(int, sys.stdin.readline().split())
    heapq.heappush(graph,(w, x, y))
p = [i for i in range(v+1)]
total_weight = 0
total_edges = 0
def find(x):
    if p[x]!=x:
        p[x] = find(p[x])
    return p[x]
rank = [1] * (v+1)  # 트리의 랭크(높이) 저장

def union(a, b):
    a_ = find(a)
    b_ = find(b)
    if a_ != b_:
        if rank[a_] > rank[b_]:  # 랭크가 높은 쪽으로 합침
            p[b_] = a_
        elif rank[a_] < rank[b_]:
            p[a_] = b_
        else:
            p[b_] = a_
            rank[a_] += 1
while graph:
    w, a, b = heapq.heappop(graph)
    if find(a)!=find(b):
        total_weight += w
        union(a, b)
        total_edges += 1
        if total_edges == v-1:
            break
print(total_weight)

https://www.acmicpc.net/board/view/148711

https://www.acmicpc.net/board/view/149643