정글/알고리즘

[백준/Java] 11049번 : 행렬 곱셈

nkdev 2025. 4. 10. 02:42

문제

https://www.acmicpc.net/problem/11049

풀이

문제에서 주어진 원리 이해하기

행렬 곱셈 -> 앞 행렬의 열 수와 뒤 행렬의 행 수가 같아야 곱셈 가능

n*m행렬과 m*k 행렬의 곱셈 : 연산 수는 총 n*m*k번

 

그러나 행렬 N개를 곱하는 경우 곱셈 순서에 따라 연산 수가 달라질 수 있음

 

예)

A : 5*3

B : 3*2

C : 2*6

 

(AB)C

5*3과 3*2 -> 5*3*2 = 30

5*2와 2*6 -> 5*2*6 = 60

30 + 60 = 90

 

A(BC)

3*2와 2*6 -> 3*2*6 = 36

3*6과 5*3 -> 3*6*5 = 90

36 + 90 = 126

dp에 저장해야할 값?

처음에는 행렬A, 행렬B의 곱셈 결과로 몇 행, 몇 열짜리의 새로운 행렬이 생성되는지의 정보를 저장해야하나 생각했다.

해당 정보를 보고 연산 횟수를 계산해낼 수 있기 때문이다.

 

그러나 dp에서는 '반복되는 연산의 피연산자값'이 아니라 '반복되는 연산의 결과값'을 메모이제이션해야 한다.

dp의 핵심은 같은 연산을 반복하지 않도록 그 연산의 결과값을 저장해두고 재사용하는 것이기 때문이다.

따라서 이 문제에서는 '행렬 곱셈 시 얻을 수 있는 연산의 최소 개수'를 메모이제이션 해야 한다.

 

문제에서 행렬 곱셈은 입력받은 순서를 따라 연속적으로 이루어진다. 

그리고 곱셈 순서는 여러 가지가 될 수 있다.

만약 행렬이 4개 있다면

((AB)C)D

(A(BC))D

A((BC)D)

A(B(CD))

(AB)(CD)

이렇게 5가지 경우가 나올 수 있는데 그 중 가장 작은 연산 수를 구해야 한다.

 

점화식 도출 과정

내가 처음 짠 점화식 :

if(i==j) 
	dp[i][j]=0; 
else 
	dp[i][j]=min(dp[i][j-1],dp[i+1][j], dp[i][j-2]+dp[i+2][j])+val[0]*val[i]*val[i+1]

 

이 점화식이 틀린 이유 :

 

행렬 곱셈 최적화는 곱셈 순서를 다양한 분할 지점 k를 기준으로 나눠서 계산해야 한다.

예를 들어 A, B, C, D 행렬을 곱할 때는 ((AB)C)D, (A(BC))D, A((BC)D), (AB)(CD), A(B(CD)) 같은 모든 분할을 고려해야 함

그런데 해당 점화식은 모든 분할을 고려하지 않고 있다.

 

 

맞는 풀이 :

1️⃣ 큰 문제를 작은 문제로 나누기

앞서 행렬이 4개일 때는 ((AB)C)D, (A(BC))D, A((BC)D), (AB)(CD), A(B(CD))

이렇게 다섯 가지 경우 중 가장 연산 횟수가 작은 경우를 선택해야 한다고 했다.

 

'행렬 4개의 최소 연산 횟수'를 구하는 과정에서 미리 구해둔 '행렬 3개의 최소 연산 횟수'를 사용할 수 있다.

즉 이미 행렬 3개의 곱을 검사할 때 (AB)C, A(BC) 중 더 적은 연산 횟수가 뭔지 구했으므로

그 답을 메모이제이션해두고 행렬 4개의 곱을 검사할 때 갖다 쓰면 된다.

 

행렬 n개의 곱을 작은 여러 개의 문제로 나누는 방법은 분할 지점을 세우는 것이다.

n=4이면 분할 지점은 아래와 같이 세 개가 될 수 있다.

A/BCD 

AB/CD

ABC/D

 

2️⃣  bottom-up으로 dp값 채우기

문제를 가장 작은 경우로 분할하여 그 문제부터 답을 찾아 나가면서 큰 문제의 답을 찾을 수 있다.

즉 bottom-up으로 dp값을 채워 나가면서 최종 문제의 답을 구해보자.

 

먼저 가장 작은 경우를 생각해보자. 행렬 2개를 곱했을 때 최소 연산 횟수를 구한다.

행렬이 2개일 때는 분할 지점이 하나 뿐이다. (두 행렬의 중간 지점)

A/B

그래서 여러 분할을 비교할 필요 없이 그냥 [A의 행 × A의 열 × B의 열]을 계산한 값이 곧 최소 연산 횟수가 된다.

마찬가지로 BC, CD의 최소 연산 횟수도 구해준다.

 

다음으로 행렬 3개를 곱했을 때를 생각해보자.

분할 지점이 2개이다.

A/BC

AB/C

BC를 먼저 행렬 곱셈 하고 그 결과에 A를 곱셈하여 얻은 연산 횟수와 AB를 먼저 행렬 곱셈 하고 그 결과에 C를 곱셈하여 얻은 연산 횟수를 비교하여 더 작은 값이 최소 연산 횟수가 된다.

이 경우 아까 구해뒀던 행렬 2개짜리 곱의 최소 연산 횟수를 이용할 수 있다. 

[BC의 최소 연산 횟수 + A의 행 × A의 열 × C의 열] 

[AB의 최소 연산 횟수 + C의 행 × C의 열 × B의 열]

두 값 중 더 작은 값을 선택한다.

BCD도 마찬가지로 구해준다.

 

시작 지점, 끝 지점 인덱스를 설정하고 (s, e) 분할 지점을 설정한다. (k)

 

코드

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        ArrayList<Matrix> arr = new ArrayList<>();
        for(int i=0; i<n; i++){
            StringTokenizer st = new StringTokenizer(br.readLine());
            int row = Integer.parseInt(st.nextToken());
            int col = Integer.parseInt(st.nextToken());
            arr.add(new Matrix(row, col)); //각 행렬의 크기 저장
        }
        int[][] dp = new int[n][n]; //dp[i][j] = i번 부터 j번 행렬까지 모두 곱할 때의 최소 곱셈 횟수
        for(int i=0; i<n; i++){ //최소값을 구하기 위해 큰 수로 초기화
            Arrays.fill(dp[i], Integer.MAX_VALUE);
        }
        for(int i=0; i<n; i++){ //한 개의 행렬만 곱하는 경우 연산이 필요 없음
            dp[i][i] = 0;
        }
        for(int i=2; i<=n; i++){ //구간 길이 (몇 개의 행렬을 곱할 건지) -> 구간 길이가 i일 때 가능한 모든 시작점에서 구간을 검사
            for(int s=0; s<=n-i; s++){ //s = 구간의 시작 인덱스
                int e = s+i-1; //e = 구간의 끝 인덱스
                dp[s][e] = Integer.MAX_VALUE;
                for(int k=s; k<e; k++){ //[s, k], [k+1, e]로 구간을 나눠서 그 결과끼리 곱함
                    int cost = dp[s][k] + dp[k+1][e] + (arr.get(s).row * arr.get(k).col * arr.get(e).col); //s~k까지 곱한 최소 연산 수
                    dp[s][e] = Math.min(dp[s][e], cost);
                }
            }
        }
        System.out.println(dp[0][n-1]);
    }

}
class Matrix{
    int row, col;
    Matrix(int row, int col){
        this.row = row;
        this.col = col;
    }
}