View

반응형

트리의 가중치

백준 1289 트리 DP Python Java

문제 분석

  • 트리: N개 정점, N-1개 간선 (가중치 있음)
  • 경로의 가중치: 경로 상 간선 가중치의
  • 트리의 가중치: 모든 정점 쌍 (u, v) 간 경로 가중치의
  • 제약: N ≤ 100,000 / W ≤ 1,000 / 답은 mod 109+7

핵심: 합이 아니라 이므로 "각 간선 기여도" 기법은 사용 불가 → 트리 DP 필요

접근법

트리 DP — O(N)

트리를 루팅한 뒤, 각 정점 v에서:

  • dp[v] = v에서 서브트리 내 모든 후손까지 경로 가중치(곱)의 합
  • h[c] = wc × (1 + dp[c]) — v에서 자식 c의 서브트리 전체로 가는 경로 가중치 합

정점 v의 기여:

  1. v → 후손 경로: Σ h[ci]
  2. 서로 다른 서브트리를 잇는 경로 (v가 LCA): Σi<j h[ci] × h[cj]

2번은 running sum으로 O(자식 수)에 계산 → 전체 O(N)

구분 설명
dp[v] v → 서브트리 내 모든 후손까지 경로 곱의 합
h[c] v → 자식 c 서브트리 전체 경로 곱 합 = w × (1 + dp[c])
정점 v 기여 (v 출발 경로) + (v 경유 교차 경로) = Σh + Σhi×hj
교차 계산 running sum 기법으로 나눗셈 없이 O(자식 수)

풀이 (Python)

import sys

def solve():
    data = sys.stdin.buffer.read().split()
    idx = 0
    N = int(data[idx]); idx += 1
    MOD = 10**9 + 7

    if N == 1:
        print(0)
        return

    # 인접 리스트 구성
    adj = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a, b, w = int(data[idx]), int(data[idx+1]), int(data[idx+2])
        idx += 3
        adj[a].append((b, w))
        adj[b].append((a, w))

    # DFS로 트리 구조 파악 (루트: 1)
    visited = [False] * (N + 1)
    order = []                              # BFS/DFS 순서
    children = [[] for _ in range(N + 1)]   # children[v] = [(자식, 간선가중치)]
    stack = [1]
    visited[1] = True
    while stack:
        v = stack.pop()
        order.append(v)
        for u, w in adj[v]:
            if not visited[u]:
                visited[u] = True
                children[v].append((u, w))
                stack.append(u)

    # 리프부터 역순으로 DP 계산
    dp = [0] * (N + 1)  # dp[v] = v에서 서브트리 후손까지 경로 곱의 합
    ans = 0

    for v in reversed(order):
        running = 0  # 지금까지 처리한 자식 서브트리의 h값 합
        cross = 0    # 서로 다른 서브트리 쌍의 h곱 합
        for c, w in children[v]:
            h = w * (1 + dp[c]) % MOD       # v→c 서브트리 전체 경로 가중치 합
            cross = (cross + running * h) % MOD  # 이전 서브트리들과의 교차 경로
            running = (running + h) % MOD
        dp[v] = running
        ans = (ans + running + cross) % MOD  # v에서 출발 + v를 경유하는 경로

    print(ans)

solve()

풀이 (Java)

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine().trim());
        long MOD = 1_000_000_007L;

        if (N == 1) {
            System.out.println(0);
            return;
        }

        // 인접 리스트 구성
        List<int[]>[] adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) adj[i] = new ArrayList<>();

        for (int i = 0; i < N - 1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int w = Integer.parseInt(st.nextToken());
            adj[a].add(new int[]{b, w});
            adj[b].add(new int[]{a, w});
        }

        // BFS로 트리 구조 파악 (루트: 1)
        int[] order = new int[N];
        boolean[] visited = new boolean[N + 1];
        List<int[]>[] children = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) children[i] = new ArrayList<>();

        int front = 0, back = 0;
        order[back++] = 1;
        visited[1] = true;
        while (front < back) {
            int v = order[front++];
            for (int[] edge : adj[v]) {
                int u = edge[0], w = edge[1];
                if (!visited[u]) {
                    visited[u] = true;
                    children[v].add(new int[]{u, w});
                    order[back++] = u;
                }
            }
        }

        // 리프부터 역순 DP
        long[] dp = new long[N + 1];
        long ans = 0;

        for (int i = N - 1; i >= 0; i--) {
            int v = order[i];
            long running = 0, cross = 0;
            for (int[] child : children[v]) {
                int c = child[0], w = child[1];
                long h = (long) w % MOD * ((1 + dp[c]) % MOD) % MOD;
                cross = (cross + running % MOD * (h % MOD)) % MOD;
                running = (running + h) % MOD;
            }
            dp[v] = running;
            ans = (ans + running + cross) % MOD;
        }

        System.out.println(ans);
    }
}

예제 트레이스 (예제 3)

트리: 1-2(2), 2-3(3), 3-4(2), 3-5(2)

처리 순서 (역순): 4 → 5 → 3 → 2 → 1

v=4: 리프, dp=0, 기여=0
v=5: 리프, dp=0, 기여=0

v=3: 자식 {4(w=2), 5(w=2)}
  h[4] = 2×(1+0) = 2,  running=2, cross=0
  h[5] = 2×(1+0) = 2,  cross=0+2×2=4, running=4
  dp[3]=4, 기여=4+4=8  (경로: 3→4=2, 3→5=2, 4→5=4)

v=2: 자식 {3(w=3)}
  h[3] = 3×(1+4) = 15,  running=15, cross=0
  dp[2]=15, 기여=15  (경로: 2→3=3, 2→4=6, 2→5=6)

v=1: 자식 {2(w=2)}
  h[2] = 2×(1+15) = 32,  running=32, cross=0
  dp[1]=32, 기여=32  (경로: 1→2=2, 1→3=6, 1→4=12, 1→5=12)

합계: 0+0+8+15+32 = 55 ✓
시간 복잡도
O(N)
공간 복잡도
O(N)

클린 코드 — Python

import sys

def solve():
    data = sys.stdin.buffer.read().split()
    idx = 0
    N = int(data[idx]); idx += 1
    MOD = 10**9 + 7
    if N == 1:
        print(0)
        return
    adj = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a, b, w = int(data[idx]), int(data[idx+1]), int(data[idx+2])
        idx += 3
        adj[a].append((b, w))
        adj[b].append((a, w))
    visited = [False] * (N + 1)
    order = []
    children = [[] for _ in range(N + 1)]
    stack = [1]
    visited[1] = True
    while stack:
        v = stack.pop()
        order.append(v)
        for u, w in adj[v]:
            if not visited[u]:
                visited[u] = True
                children[v].append((u, w))
                stack.append(u)
    dp = [0] * (N + 1)
    ans = 0
    for v in reversed(order):
        running = 0
        cross = 0
        for c, w in children[v]:
            h = w * (1 + dp[c]) % MOD
            cross = (cross + running * h) % MOD
            running = (running + h) % MOD
        dp[v] = running
        ans = (ans + running + cross) % MOD
    print(ans)

solve()

클린 코드 — Java

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine().trim());
        long MOD = 1_000_000_007L;
        if (N == 1) { System.out.println(0); return; }
        List<int[]>[] adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < N - 1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int w = Integer.parseInt(st.nextToken());
            adj[a].add(new int[]{b, w});
            adj[b].add(new int[]{a, w});
        }
        int[] order = new int[N];
        boolean[] visited = new boolean[N + 1];
        List<int[]>[] children = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) children[i] = new ArrayList<>();
        int front = 0, back = 0;
        order[back++] = 1; visited[1] = true;
        while (front < back) {
            int v = order[front++];
            for (int[] edge : adj[v]) {
                int u = edge[0], w = edge[1];
                if (!visited[u]) {
                    visited[u] = true;
                    children[v].add(new int[]{u, w});
                    order[back++] = u;
                }
            }
        }
        long[] dp = new long[N + 1];
        long ans = 0;
        for (int i = N - 1; i >= 0; i--) {
            int v = order[i];
            long running = 0, cross = 0;
            for (int[] child : children[v]) {
                int c = child[0], w = child[1];
                long h = (long) w % MOD * ((1 + dp[c]) % MOD) % MOD;
                cross = (cross + running % MOD * (h % MOD)) % MOD;
                running = (running + h) % MOD;
            }
            dp[v] = running;
            ans = (ans + running + cross) % MOD;
        }
        System.out.println(ans);
    }
}
알고리즘 풀이 블로그
728x90
반응형
Share Link
reply
«   2026/03   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31