[BOJ1693] 트리 색칠하기

문제 설명

문제 링크

  • 트리 DP를 공부하면서 재밌는 문제를 풀게 되었다.
  • 먼저 문제 설명을 하자면,
  • n개의 정점으로 이루어진 트리가 있고, 이 트리의 정점을 색칠하려 한다.
  • 색칠은 1~n번의 색깔로 색칠을 할 수 있고, 각 번호의 색으로 색칠을 할때 각 번호만큼 해당하는 비용이 든다.
  • 조건으로는 인접한 정점은 다른 색으로 색칠을 해야한다.
  • 이때, 전체 정점을 색칠하는데 드는 비용을 최소화 하는 것이다.

풀이 과정

  • 이 문제는 전형적인 트리 DP문제이긴 하다.
  • 하지만 DP를 모든 경우로 보게 될 경우 DP[n][n]만큼을 봐야하기때문에 시간복잡도가 O(N^2)이 된다.
  • 이를 줄이기 위해 아이디어가 필요한데, 이는 아래 링크의 koosaga님이 올려주신 글을 참고했다.
  • https://www.acmicpc.net/board/view/13972
  • 결론을 말하면, n개의 점을 색칠할 때, 최소의 비용을 찾기 위해 모든 색깔을 다 칠해볼 필요가 없다.

N개의 노드들을 최소 비용으로 색칠할 때, log2(N)의 색깔만 있으면 구할 수 있다.

  • 이 부분에 대해 증명 과정을 필기를 통해 대신한다.

코드

import java.util.*;
import java.io.*;

public class BOJ1693_트리색칠하기 {
    static long[][] dp;
    static boolean[] visit;
    static List<Integer>[] adj;
    static List<Integer>[] child;
    static int color;

    public static void main(String[] args) throws IOException {
        BufferedReader input = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(input.readLine());
        color = getColor(n) + 1;
        dp = new long[n + 1][color + 1];
        visit = new boolean[n + 1];
        adj = new List[n + 1];
        child = new List[n + 1];
        for (int i = 0; i <= n; i++) {
            adj[i] = new ArrayList<>();
            child[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            StringTokenizer st = new StringTokenizer(input.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            adj[a].add(b);
            adj[b].add(a);
        }
        make_tree(1);
        for (int i = 1; i <= color; i++) {
            dp[1][i] = getColor(1, i);
        }
        long output = dp[1][1];
        for (int i = 2; i <= color; i++) {
            output = Math.min(output, dp[1][i]);
        }
        System.out.println(output);

    }

    static long getColor(int n, int c) {
        if (dp[n][c]!=0) return dp[n][c];
        dp[n][c] = c;
        for (int i = 0; i < child[n].size(); i++) {
            int nn = child[n].get(i);
            long cost = Integer.MAX_VALUE;
            for (int j = 1; j <= color; j++) {
                if (j != c) {
                    cost = Math.min(cost, getColor(nn, j));
                }
            }
            dp[n][c] += cost;
        }
        return dp[n][c];
    }
    static void make_tree(int n) {
        visit[n] = true;
        for (int i = 0; i < adj[n].size(); i++) {
            int nn = adj[n].get(i);
            if (!visit[nn]) {
                child[n].add(nn);
                make_tree(nn);

            }
        }
    }
    static int getColor(int n) {
        return (int) (Math.log10(n) / Math.log10(2));
    }
}

댓글남기기