정글/알고리즘

[백준/Python] 10830번 : 행렬 제곱

nkdev 2025. 3. 23. 20:42

문제

https://www.acmicpc.net/problem/10830

행렬 A의 B제곱을 구하는 문제이다. 1 ≤ B ≤ 100,000,000,000

A^B의 각 원소를 1,000으로 나눈 나머지를 출력해야 하고 

원소 개수 N에 대해 2 ≤ N ≤ 5를 만족한다. 

풀이

N번의 거듭제곱을 그냥 계산하면 시간 복잡도가 O(N)이다. 

반면 분할 정복을 사용해 거듭제곱을 계산하면 O(logN) 안에 계산할 수 있다.

 

  • 분할 정복을 이용한 거듭제곱
# 분할정복을 사용한 거듭제곱 구현하기
# 지수를 2로 나눈 나머지가 0인 횟수만큼 자기 자신과 곱셈하기 -> 시간 복잡도가 O(nlogn)
n = 16 # 지수
a = 2 # 밑
ret = a
while n>=2:
    if n%2==0: # n이 짝수이면
       ret *= ret
    else: # n이 홀수이면
        ret *= ret*a
    n//=2
print(ret)

오답노트

틀린 코드 :

import sys
sys.setrecursionlimit(100000)

N, B = map(int, sys.stdin.readline().split())
arr = []

for i in range(N):
    arr.append(list(map(int, sys.stdin.readline().split())))

for i in range(N):
    for j in range(N):
        arr[i][j] %= 1000

# 행렬 A와 B의 곱을 리턴
def mul(A, B, n):
    ret = [[0]*n for _ in range(n)]# 두 행렬의 곱
    for i in range(n): # 행
        for j in range(n):
            for k in range(n):
                ret[i][j] += A[i][k] * B[k][j]
    return ret

def solve(arr_, b, n):
    ret = arr_
    while b>=2:
        if b%2==0: # 지수가 짝수이면 자기 자신과 곱셈
            ret = mul(ret, ret, n)
        else: # 지수가 홀수이면 자기 자신과 원래 행렬을 곱해주기
            ret = mul(mul(ret, ret, n), arr_, n)
        b//=2
    return ret

result = solve(arr, B, N) # 행렬 arr를 지수 B번 곱하기 -> N은 행렬 길이
for i in range(N):
    for j in range(N):
        print(result[i][j]%1000, end=" ")
    print()

지수가 짝수이면 자기 자신과 곱셈하고

지수가 홀수이면 자기 자신과 곱셈한 후 원래 행렬을 한 번 더 곱해주었다.

이 방법을 사용하면 지수가 5일 때 6번 거듭제곱을 하게 된다.

 

  • 홀수가 어느 depth에 나왔는지에 따라 원래 행렬(밑)을 몇 번 곱하는지 결정하면 된다...?

만약 지수가 10인 경우 지수를 2로 계속 나눠서 작게 만들면

10 -> 5 -> 2 -> 1

이렇게 되는데 지수가 홀수일 때 (5일 때)의 값을 구하기 위해서는

이전 해를 자기 자신만 거듭제곱해주면 안 되고 원래 행렬을 두 번 더 곱해줘야 한다.

 

x^10

= (x^5) * (x^5) <- 지수가 5일 때

= (x^2 * x^2 * x^1) * (x^2 * x^2 * x^1) <- 원래 행렬을 한 번씩 더 곱해줘야 함

= (대충 x^1로 분할되어 만들어진 식)

 

이 부분을 재귀호출로 풀 때는 지수가 홀수인 경우에 조건을 달아서 넘겨주면 될 것 같은데

반복문으로 풀고 있어서 어떻게 처리해야 할 지 고민 중이다.

일단 재귀로 풀어볼까..

 

재귀로 풀었더니 바로 풀렸다.

import sys
sys.setrecursionlimit(100000)

N, B = map(int, sys.stdin.readline().split())
arr = []

for i in range(N):
    arr.append(list(map(int, sys.stdin.readline().split())))

for i in range(N):
    for j in range(N):
        arr[i][j] %= 1000

# 행렬 A와 B의 곱을 리턴
def mul(A, B):
    n = len(A)
    ret = [[0]*n for _ in range(n)]# 두 행렬의 곱
    for i in range(n): # 행
        for j in range(n): # 열
            for k in range(n):
                ret[i][j] += A[i][k] * B[k][j]
            ret[i][j] %= 1000
    return ret

def recur(b):
    global arr, N

    # 지수가 1이면 원래 행렬을 리턴
    if b==1:
        return arr

    # 지수가 2이면 원래 행렬을 곱셈한 후 각 요소를 1000으로 나눈 값 리턴
    if b==2:
        return mul(arr, arr)

    # 지수가 짝수이면 지수//2의 결과를 거듭제곱한 값을 리턴
    if b%2==0:
        tmp = recur(b//2)
        return mul(tmp, tmp)

    # 지수가 홀수이면 지수//2의 결과를 거듭제곱한 값에 원래 행렬을 한 번 더 곱한 값을 리턴
    else:
        tmp = recur(b//2)
        return mul(mul(tmp, tmp), arr)

result = recur(B)
for i in range(N):
    for j in range(N):
        print(result[i][j], end=" ")
    print()

하루 넘게 반복문으로 푸는 방법을 고민해봤는데 풀리지 않아서 재귀로 풀어봤는데 이 방법이 더 쉬웠다.

 

반복문으로 푸는 경우 코드는 아래와 같다.

import sys
sys.setrecursionlimit(100000)

N, B = map(int, sys.stdin.readline().split())
arr = []

for i in range(N):
    arr.append(list(map(int, sys.stdin.readline().split())))

for i in range(N):
    for j in range(N):
        arr[i][j] %= 1000

# 행렬 A와 B의 곱을 리턴
def mul(A, B):
    n = len(A)
    ret = [[0]*n for _ in range(n)]  # 두 행렬의 곱
    for i in range(n):  # 행
        for j in range(n):  # 열
            for k in range(n):
                ret[i][j] += A[i][k] * B[k][j]
            ret[i][j] %= 1000
    return ret

# 반복문 기반 분할정복
def matrix_power(arr, b):
    n = len(arr)
    result = [[1 if i == j else 0 for j in range(n)] for i in range(n)]  # 단위 행렬 초기화
    base = arr.copy()  # 초기 행렬 복사

    while b > 0:
        if b % 2 == 1:  # 지수가 홀수일 경우 현재 base를 결과에 곱함
            result = mul(result, base)
        base = mul(base, base)  # base를 제곱
        b //= 2  # 지수를 반으로 줄임

    return result

result = matrix_power(arr, B)

for i in range(N):
    for j in range(N):
        print(result[i][j], end=" ")
    print()