View
반응형
트리의 가중치
백준 1289 트리 DP Python Java
문제 분석
- 트리: N개 정점, N-1개 간선 (가중치 있음)
- 경로의 가중치: 경로 상 간선 가중치의 곱
- 트리의 가중치: 모든 정점 쌍 (u, v) 간 경로 가중치의 합
- 제약: N ≤ 100,000 / W ≤ 1,000 / 답은 mod 109+7
핵심: 합이 아니라 곱이므로 "각 간선 기여도" 기법은 사용 불가 → 트리 DP 필요
접근법
트리 DP — O(N)
트리를 루팅한 뒤, 각 정점 v에서:
dp[v]= v에서 서브트리 내 모든 후손까지 경로 가중치(곱)의 합h[c]= wc × (1 + dp[c]) — v에서 자식 c의 서브트리 전체로 가는 경로 가중치 합
정점 v의 기여:
- v → 후손 경로: Σ h[ci]
- 서로 다른 서브트리를 잇는 경로 (v가 LCA): Σi<j h[ci] × h[cj]
2번은 running sum으로 O(자식 수)에 계산 → 전체 O(N)
| 구분 | 설명 |
|---|---|
| dp[v] | v → 서브트리 내 모든 후손까지 경로 곱의 합 |
| h[c] | v → 자식 c 서브트리 전체 경로 곱 합 = w × (1 + dp[c]) |
| 정점 v 기여 | (v 출발 경로) + (v 경유 교차 경로) = Σh + Σhi×hj |
| 교차 계산 | running sum 기법으로 나눗셈 없이 O(자식 수) |
풀이 (Python)
import sys
def solve():
data = sys.stdin.buffer.read().split()
idx = 0
N = int(data[idx]); idx += 1
MOD = 10**9 + 7
if N == 1:
print(0)
return
# 인접 리스트 구성
adj = [[] for _ in range(N + 1)]
for _ in range(N - 1):
a, b, w = int(data[idx]), int(data[idx+1]), int(data[idx+2])
idx += 3
adj[a].append((b, w))
adj[b].append((a, w))
# DFS로 트리 구조 파악 (루트: 1)
visited = [False] * (N + 1)
order = [] # BFS/DFS 순서
children = [[] for _ in range(N + 1)] # children[v] = [(자식, 간선가중치)]
stack = [1]
visited[1] = True
while stack:
v = stack.pop()
order.append(v)
for u, w in adj[v]:
if not visited[u]:
visited[u] = True
children[v].append((u, w))
stack.append(u)
# 리프부터 역순으로 DP 계산
dp = [0] * (N + 1) # dp[v] = v에서 서브트리 후손까지 경로 곱의 합
ans = 0
for v in reversed(order):
running = 0 # 지금까지 처리한 자식 서브트리의 h값 합
cross = 0 # 서로 다른 서브트리 쌍의 h곱 합
for c, w in children[v]:
h = w * (1 + dp[c]) % MOD # v→c 서브트리 전체 경로 가중치 합
cross = (cross + running * h) % MOD # 이전 서브트리들과의 교차 경로
running = (running + h) % MOD
dp[v] = running
ans = (ans + running + cross) % MOD # v에서 출발 + v를 경유하는 경로
print(ans)
solve()
풀이 (Java)
import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int N = Integer.parseInt(br.readLine().trim());
long MOD = 1_000_000_007L;
if (N == 1) {
System.out.println(0);
return;
}
// 인접 리스트 구성
List<int[]>[] adj = new ArrayList[N + 1];
for (int i = 1; i <= N; i++) adj[i] = new ArrayList<>();
for (int i = 0; i < N - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
int w = Integer.parseInt(st.nextToken());
adj[a].add(new int[]{b, w});
adj[b].add(new int[]{a, w});
}
// BFS로 트리 구조 파악 (루트: 1)
int[] order = new int[N];
boolean[] visited = new boolean[N + 1];
List<int[]>[] children = new ArrayList[N + 1];
for (int i = 1; i <= N; i++) children[i] = new ArrayList<>();
int front = 0, back = 0;
order[back++] = 1;
visited[1] = true;
while (front < back) {
int v = order[front++];
for (int[] edge : adj[v]) {
int u = edge[0], w = edge[1];
if (!visited[u]) {
visited[u] = true;
children[v].add(new int[]{u, w});
order[back++] = u;
}
}
}
// 리프부터 역순 DP
long[] dp = new long[N + 1];
long ans = 0;
for (int i = N - 1; i >= 0; i--) {
int v = order[i];
long running = 0, cross = 0;
for (int[] child : children[v]) {
int c = child[0], w = child[1];
long h = (long) w % MOD * ((1 + dp[c]) % MOD) % MOD;
cross = (cross + running % MOD * (h % MOD)) % MOD;
running = (running + h) % MOD;
}
dp[v] = running;
ans = (ans + running + cross) % MOD;
}
System.out.println(ans);
}
}
예제 트레이스 (예제 3)
트리: 1-2(2), 2-3(3), 3-4(2), 3-5(2)
처리 순서 (역순): 4 → 5 → 3 → 2 → 1
v=4: 리프, dp=0, 기여=0
v=5: 리프, dp=0, 기여=0
v=3: 자식 {4(w=2), 5(w=2)}
h[4] = 2×(1+0) = 2, running=2, cross=0
h[5] = 2×(1+0) = 2, cross=0+2×2=4, running=4
dp[3]=4, 기여=4+4=8 (경로: 3→4=2, 3→5=2, 4→5=4)
v=2: 자식 {3(w=3)}
h[3] = 3×(1+4) = 15, running=15, cross=0
dp[2]=15, 기여=15 (경로: 2→3=3, 2→4=6, 2→5=6)
v=1: 자식 {2(w=2)}
h[2] = 2×(1+15) = 32, running=32, cross=0
dp[1]=32, 기여=32 (경로: 1→2=2, 1→3=6, 1→4=12, 1→5=12)
합계: 0+0+8+15+32 = 55 ✓
시간 복잡도
O(N)
공간 복잡도
O(N)
클린 코드 — Python
import sys
def solve():
data = sys.stdin.buffer.read().split()
idx = 0
N = int(data[idx]); idx += 1
MOD = 10**9 + 7
if N == 1:
print(0)
return
adj = [[] for _ in range(N + 1)]
for _ in range(N - 1):
a, b, w = int(data[idx]), int(data[idx+1]), int(data[idx+2])
idx += 3
adj[a].append((b, w))
adj[b].append((a, w))
visited = [False] * (N + 1)
order = []
children = [[] for _ in range(N + 1)]
stack = [1]
visited[1] = True
while stack:
v = stack.pop()
order.append(v)
for u, w in adj[v]:
if not visited[u]:
visited[u] = True
children[v].append((u, w))
stack.append(u)
dp = [0] * (N + 1)
ans = 0
for v in reversed(order):
running = 0
cross = 0
for c, w in children[v]:
h = w * (1 + dp[c]) % MOD
cross = (cross + running * h) % MOD
running = (running + h) % MOD
dp[v] = running
ans = (ans + running + cross) % MOD
print(ans)
solve()
클린 코드 — Java
import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int N = Integer.parseInt(br.readLine().trim());
long MOD = 1_000_000_007L;
if (N == 1) { System.out.println(0); return; }
List<int[]>[] adj = new ArrayList[N + 1];
for (int i = 1; i <= N; i++) adj[i] = new ArrayList<>();
for (int i = 0; i < N - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
int w = Integer.parseInt(st.nextToken());
adj[a].add(new int[]{b, w});
adj[b].add(new int[]{a, w});
}
int[] order = new int[N];
boolean[] visited = new boolean[N + 1];
List<int[]>[] children = new ArrayList[N + 1];
for (int i = 1; i <= N; i++) children[i] = new ArrayList<>();
int front = 0, back = 0;
order[back++] = 1; visited[1] = true;
while (front < back) {
int v = order[front++];
for (int[] edge : adj[v]) {
int u = edge[0], w = edge[1];
if (!visited[u]) {
visited[u] = true;
children[v].add(new int[]{u, w});
order[back++] = u;
}
}
}
long[] dp = new long[N + 1];
long ans = 0;
for (int i = N - 1; i >= 0; i--) {
int v = order[i];
long running = 0, cross = 0;
for (int[] child : children[v]) {
int c = child[0], w = child[1];
long h = (long) w % MOD * ((1 + dp[c]) % MOD) % MOD;
cross = (cross + running % MOD * (h % MOD)) % MOD;
running = (running + h) % MOD;
}
dp[v] = running;
ans = (ans + running + cross) % MOD;
}
System.out.println(ans);
}
}
알고리즘 풀이 블로그
728x90
반응형
'Problem Solving > Baekjoon' 카테고리의 다른 글
| [백준] 2702: 초6 수학 (Java) (0) | 2026.02.17 |
|---|---|
| [백준] 14877: 순열 교환 (Java) (0) | 2026.02.17 |
| [백준] 21609: 상어 중학교 [Java] (0) | 2024.10.17 |
| [백준] 15989: 1, 2, 3 더하기 4 [Java] (0) | 2024.10.15 |
| [백준] 19585: 전설 [Java] (0) | 2024.10.10 |
reply
