View
반응형
순열 교환
백준
다항식 · NTT
스털링 수
더블링
Java
문제 분석
- 순열 A가 주어졌을 때, 정확히 k번 교환해서 A를 만드는 서로 다른 순열 B의 개수
- k = 1부터 n−1까지 모든 값에 대해 답을 구해야 함
- 제약: N ≤ 100,000 / 답은 mod 109+7
핵심 관찰: 답은 A에 무관하다! → “정확히 k개 전치의 곱으로 표현 가능한 순열의 수”와 동치
접근법
1. 수학적 환원
- 순열 σ의 사이클 수 = c일 때, 최소 전치 수 = n − c
- σ가 정확히 k개 전치의 곱 ⇔
n − c ≤ kandn − c ≡ k (mod 2) - ∴ f(k) = Σ S(n, j) (j ≥ n−k, j ≡ n−k mod 2)
여기서 S(n, j)는 제1종 부호없는 스털링 수 (사이클 수가 j인 순열의 수)
2. 스털링 수 계산
- P(x) = x(x+1)(x+2)···(x+n−1) 의 xj 계수 = S(n, j)
- 이 다항식을 효율적으로 계산해야 함
3. 더블링 — O(N log N)
핵심 아이디어: x(2m) = x(m) · (x+m)(m)
- Pm(x) = x(x+1)···(x+m−1) 을 재귀로 구함
- Pm(x+m) 은 다항식 shift로 O(m log m)에 계산
- Pm(x) × Pm(x+m) = P2m(x)
- n이 홀수면 마지막에 (x+n−1) 하나 더 곱함
T(n) = T(n/2) + O(n log n) → O(n log n) (기하급수 감소)
4. 다항식 shift
f(x) → f(x+c) 변환:
- Aj = f[j] · j! 을 뒤집고, Ck = ck/k! 과 컨볼루션
- bi = D[m−i] / i! (D = reverse(A) * C)
- 한 번의 다항식 곱셈 = O(m log m)
5. 3-mod NTT
109+7은 NTT 비친화적 → 998244353, 985661441, 754974721 세 소수로 NTT 후 CRT로 합산
6. 최종 답 산출
패리티별 접미사 합: E[c] = S(n,c) + E[c+2], f(k) = E[n−k]
예제 트레이스 (n=3)
P(x) = x(x+1)(x+2) = x³ + 3x² + 2x → S(3,1)=2, S(3,2)=3, S(3,3)=1 접미사 합 (패리티별): E[3] = S(3,3) = 1 E[2] = S(3,2) + 0 = 3 (E[4]=0) E[1] = S(3,1) + E[3] = 2+1 = 3 답: k=1: E[3-1] = E[2] = 3 k=2: E[3-2] = E[1] = 3 출력: 3 3 ✓ 검증 (k=1): (1,3,2), (2,1,3), (3,2,1) → 1번 교환으로 (3,1,2) 가능 = 3개 ✓
시간 복잡도
O(N log N)
공간 복잡도
O(N)
삽질 기록
| 시도 | 방식 | 결과 |
|---|---|---|
| 1차 | 분할정복 O(N log²N) + 3-mod NTT | TLE |
| 2차 | 더블링 O(N log N) + naive 혼합 | AC |
Java에서 3-mod NTT 상수가 무거워서 log 하나 차이가 AC/TLE를 갈랐다.
클린 코드 — Java
import java.io.*;
import java.util.*;
public class Main {
static final long MOD = 1_000_000_007L;
static final long[] P = {998244353, 985661441, 754974721};
static final long[] G = {3, 3, 11};
static long pw(long a, long b, long m) {
long r = 1; a %= m;
for (; b > 0; b >>= 1, a = a * a % m)
if ((b & 1) == 1) r = r * a % m;
return r;
}
static long[] fact, ifact;
static void initFact(int n) {
fact = new long[n + 1]; ifact = new long[n + 1];
fact[0] = 1;
for (int i = 1; i <= n; i++) fact[i] = fact[i-1] * i % MOD;
ifact[n] = pw(fact[n], MOD - 2, MOD);
for (int i = n - 1; i >= 0; i--) ifact[i] = ifact[i+1] * (i+1) % MOD;
}
static final int MAXN = 1 << 18;
static long[] ga = new long[MAXN], gb = new long[MAXN];
static void ntt(long[] a, int n, boolean inv, int id) {
long mod = P[id], g = G[id];
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; (j & bit) != 0; bit >>= 1) j ^= bit;
j ^= bit;
if (i < j) { long t = a[i]; a[i] = a[j]; a[j] = t; }
}
for (int len = 2; len <= n; len <<= 1) {
long w = inv ? pw(g, mod-1-(mod-1)/len, mod) : pw(g, (mod-1)/len, mod);
for (int i = 0; i < n; i += len) {
long wn = 1;
for (int j = 0; j < len/2; j++) {
long u = a[i+j], v = a[i+j+len/2] * wn % mod;
a[i+j] = (u+v) % mod;
a[i+j+len/2] = (u-v+mod) % mod;
wn = wn * w % mod;
}
}
}
if (inv) { long ni = pw(n, mod-2, mod); for (int i = 0; i < n; i++) a[i] = a[i]*ni%mod; }
}
static final long INV1 = pw(P[0], P[1]-2, P[1]);
static final long M12 = P[0]%P[2] * (P[1]%P[2]) % P[2];
static final long INV12 = pw(M12, P[2]-2, P[2]);
static long[] nttMul(long[] a, long[] b) {
int al = a.length, bl = b.length, sz = al+bl-1;
int n = 1; while (n < sz) n <<= 1;
long[][] rr = new long[3][sz];
for (int id = 0; id < 3; id++) {
System.arraycopy(a,0,ga,0,al); Arrays.fill(ga,al,n,0);
System.arraycopy(b,0,gb,0,bl); Arrays.fill(gb,bl,n,0);
ntt(ga,n,false,id); ntt(gb,n,false,id);
for (int i = 0; i < n; i++) ga[i] = ga[i]*gb[i]%P[id];
ntt(ga,n,true,id);
System.arraycopy(ga,0,rr[id],0,sz);
}
long[] res = new long[sz];
for (int i = 0; i < sz; i++) {
long a0=rr[0][i], a1=rr[1][i], a2=rr[2][i];
long t = (a1-a0%P[1]+P[1])%P[1]*INV1%P[1];
long x2 = (a0%P[2]+P[0]%P[2]*(t%P[2])%P[2])%P[2];
long s = (a2-x2+P[2])%P[2]*INV12%P[2];
long x = a0%MOD;
x = (x+P[0]%MOD*(t%MOD))%MOD;
x = (x+P[0]%MOD*(P[1]%MOD)%MOD*(s%MOD))%MOD;
res[i] = x;
}
return res;
}
static long[] naiveMul(long[] a, long[] b) {
long[] c = new long[a.length+b.length-1];
for (int i = 0; i < a.length; i++)
for (int j = 0; j < b.length; j++)
c[i+j] = (c[i+j]+a[i]*b[j])%MOD;
return c;
}
static long[] polyMul(long[] a, long[] b) {
if (Math.min(a.length, b.length) <= 64) return naiveMul(a, b);
return nttMul(a, b);
}
static long[] mulLinear(long[] f, long c) {
c %= MOD;
long[] g = new long[f.length+1];
for (int i = 0; i < f.length; i++) {
g[i] = (g[i]+f[i]*c)%MOD;
g[i+1] = f[i];
}
return g;
}
static long[] shift(long[] f, int c) {
int m = f.length - 1;
long[] A = new long[m+1];
for (int j = 0; j <= m; j++) A[j] = f[j]*fact[j]%MOD;
for (int i = 0, j = m; i < j; i++, j--) { long t=A[i]; A[i]=A[j]; A[j]=t; }
long[] C = new long[m+1];
long ck = 1, cm = c % MOD;
for (int k = 0; k <= m; k++) { C[k]=ck*ifact[k]%MOD; ck=ck*cm%MOD; }
long[] D = polyMul(A, C);
long[] g = new long[m+1];
for (int i = 0; i <= m; i++) g[i] = D[m-i]*ifact[i]%MOD;
return g;
}
static long[] compute(int n) {
if (n == 1) return new long[]{0, 1};
int m = n / 2;
long[] half = compute(m);
long[] shifted = shift(half, m);
long[] result = polyMul(half, shifted);
if (n % 2 == 1) result = mulLinear(result, n - 1);
return result;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) st.nextToken();
initFact(n + 1);
long[] coef = compute(n);
long[] E = new long[n + 3];
for (int c = n; c >= 1; c--) {
E[c] = coef[c];
if (c+2 <= n) E[c] = (E[c]+E[c+2])%MOD;
}
StringBuilder sb = new StringBuilder();
for (int k = 1; k < n; k++) {
if (k > 1) sb.append(' ');
sb.append(E[n-k]);
}
System.out.println(sb);
}
}
알고리즘 풀이 블로그
728x90
반응형
'Problem Solving > Baekjoon' 카테고리의 다른 글
| [백준] 30178: Perfect Triples (Java, Python) (0) | 2026.02.18 |
|---|---|
| [백준] 2702: 초6 수학 (Java) (0) | 2026.02.17 |
| [백준] 1289: 트리의 가중치 (Java, Python) (0) | 2026.02.16 |
| [백준] 21609: 상어 중학교 [Java] (0) | 2024.10.17 |
| [백준] 15989: 1, 2, 3 더하기 4 [Java] (0) | 2024.10.15 |
reply
