View

반응형

순열 교환

백준 다항식 · NTT 스털링 수 더블링 Java

문제 분석

  • 순열 A가 주어졌을 때, 정확히 k번 교환해서 A를 만드는 서로 다른 순열 B의 개수
  • k = 1부터 n−1까지 모든 값에 대해 답을 구해야 함
  • 제약: N ≤ 100,000 / 답은 mod 109+7

핵심 관찰: 답은 A에 무관하다! → “정확히 k개 전치의 곱으로 표현 가능한 순열의 수”와 동치

접근법

1. 수학적 환원

  • 순열 σ의 사이클 수 = c일 때, 최소 전치 수 = n − c
  • σ가 정확히 k개 전치의 곱 ⇔ n − c ≤ k and n − c ≡ k (mod 2)
  • f(k) = Σ S(n, j)   (j ≥ n−k, j ≡ n−k mod 2)

여기서 S(n, j)는 제1종 부호없는 스털링 수 (사이클 수가 j인 순열의 수)

2. 스털링 수 계산

  • P(x) = x(x+1)(x+2)···(x+n−1) 의 xj 계수 = S(n, j)
  • 이 다항식을 효율적으로 계산해야 함

3. 더블링 — O(N log N)

핵심 아이디어: x(2m) = x(m) · (x+m)(m)

  1. Pm(x) = x(x+1)···(x+m−1) 을 재귀로 구함
  2. Pm(x+m) 은 다항식 shift로 O(m log m)에 계산
  3. Pm(x) × Pm(x+m) = P2m(x)
  4. n이 홀수면 마지막에 (x+n−1) 하나 더 곱함

T(n) = T(n/2) + O(n log n) → O(n log n) (기하급수 감소)

4. 다항식 shift

f(x) → f(x+c) 변환:

  • Aj = f[j] · j! 을 뒤집고, Ck = ck/k! 과 컨볼루션
  • bi = D[m−i] / i!   (D = reverse(A) * C)
  • 한 번의 다항식 곱셈 = O(m log m)

5. 3-mod NTT

109+7은 NTT 비친화적 → 998244353, 985661441, 754974721 세 소수로 NTT 후 CRT로 합산

6. 최종 답 산출

패리티별 접미사 합: E[c] = S(n,c) + E[c+2],   f(k) = E[n−k]

예제 트레이스 (n=3)

P(x) = x(x+1)(x+2) = x³ + 3x² + 2x
→ S(3,1)=2, S(3,2)=3, S(3,3)=1

접미사 합 (패리티별):
  E[3] = S(3,3) = 1
  E[2] = S(3,2) + 0 = 3       (E[4]=0)
  E[1] = S(3,1) + E[3] = 2+1 = 3

답:
  k=1: E[3-1] = E[2] = 3
  k=2: E[3-2] = E[1] = 3

출력: 3 3 ✓

검증 (k=1):
  (1,3,2), (2,1,3), (3,2,1) → 1번 교환으로 (3,1,2) 가능 = 3개 ✓
시간 복잡도
O(N log N)
공간 복잡도
O(N)

삽질 기록

시도 방식 결과
1차분할정복 O(N log²N) + 3-mod NTTTLE
2차더블링 O(N log N) + naive 혼합AC

Java에서 3-mod NTT 상수가 무거워서 log 하나 차이가 AC/TLE를 갈랐다.

클린 코드 — Java

import java.io.*;
import java.util.*;

public class Main {
    static final long MOD = 1_000_000_007L;
    static final long[] P = {998244353, 985661441, 754974721};
    static final long[] G = {3, 3, 11};

    static long pw(long a, long b, long m) {
        long r = 1; a %= m;
        for (; b > 0; b >>= 1, a = a * a % m)
            if ((b & 1) == 1) r = r * a % m;
        return r;
    }

    static long[] fact, ifact;
    static void initFact(int n) {
        fact = new long[n + 1]; ifact = new long[n + 1];
        fact[0] = 1;
        for (int i = 1; i <= n; i++) fact[i] = fact[i-1] * i % MOD;
        ifact[n] = pw(fact[n], MOD - 2, MOD);
        for (int i = n - 1; i >= 0; i--) ifact[i] = ifact[i+1] * (i+1) % MOD;
    }

    static final int MAXN = 1 << 18;
    static long[] ga = new long[MAXN], gb = new long[MAXN];

    static void ntt(long[] a, int n, boolean inv, int id) {
        long mod = P[id], g = G[id];
        for (int i = 1, j = 0; i < n; i++) {
            int bit = n >> 1;
            for (; (j & bit) != 0; bit >>= 1) j ^= bit;
            j ^= bit;
            if (i < j) { long t = a[i]; a[i] = a[j]; a[j] = t; }
        }
        for (int len = 2; len <= n; len <<= 1) {
            long w = inv ? pw(g, mod-1-(mod-1)/len, mod) : pw(g, (mod-1)/len, mod);
            for (int i = 0; i < n; i += len) {
                long wn = 1;
                for (int j = 0; j < len/2; j++) {
                    long u = a[i+j], v = a[i+j+len/2] * wn % mod;
                    a[i+j] = (u+v) % mod;
                    a[i+j+len/2] = (u-v+mod) % mod;
                    wn = wn * w % mod;
                }
            }
        }
        if (inv) { long ni = pw(n, mod-2, mod); for (int i = 0; i < n; i++) a[i] = a[i]*ni%mod; }
    }

    static final long INV1 = pw(P[0], P[1]-2, P[1]);
    static final long M12 = P[0]%P[2] * (P[1]%P[2]) % P[2];
    static final long INV12 = pw(M12, P[2]-2, P[2]);

    static long[] nttMul(long[] a, long[] b) {
        int al = a.length, bl = b.length, sz = al+bl-1;
        int n = 1; while (n < sz) n <<= 1;
        long[][] rr = new long[3][sz];
        for (int id = 0; id < 3; id++) {
            System.arraycopy(a,0,ga,0,al); Arrays.fill(ga,al,n,0);
            System.arraycopy(b,0,gb,0,bl); Arrays.fill(gb,bl,n,0);
            ntt(ga,n,false,id); ntt(gb,n,false,id);
            for (int i = 0; i < n; i++) ga[i] = ga[i]*gb[i]%P[id];
            ntt(ga,n,true,id);
            System.arraycopy(ga,0,rr[id],0,sz);
        }
        long[] res = new long[sz];
        for (int i = 0; i < sz; i++) {
            long a0=rr[0][i], a1=rr[1][i], a2=rr[2][i];
            long t = (a1-a0%P[1]+P[1])%P[1]*INV1%P[1];
            long x2 = (a0%P[2]+P[0]%P[2]*(t%P[2])%P[2])%P[2];
            long s = (a2-x2+P[2])%P[2]*INV12%P[2];
            long x = a0%MOD;
            x = (x+P[0]%MOD*(t%MOD))%MOD;
            x = (x+P[0]%MOD*(P[1]%MOD)%MOD*(s%MOD))%MOD;
            res[i] = x;
        }
        return res;
    }

    static long[] naiveMul(long[] a, long[] b) {
        long[] c = new long[a.length+b.length-1];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < b.length; j++)
                c[i+j] = (c[i+j]+a[i]*b[j])%MOD;
        return c;
    }

    static long[] polyMul(long[] a, long[] b) {
        if (Math.min(a.length, b.length) <= 64) return naiveMul(a, b);
        return nttMul(a, b);
    }

    static long[] mulLinear(long[] f, long c) {
        c %= MOD;
        long[] g = new long[f.length+1];
        for (int i = 0; i < f.length; i++) {
            g[i] = (g[i]+f[i]*c)%MOD;
            g[i+1] = f[i];
        }
        return g;
    }

    static long[] shift(long[] f, int c) {
        int m = f.length - 1;
        long[] A = new long[m+1];
        for (int j = 0; j <= m; j++) A[j] = f[j]*fact[j]%MOD;
        for (int i = 0, j = m; i < j; i++, j--) { long t=A[i]; A[i]=A[j]; A[j]=t; }
        long[] C = new long[m+1];
        long ck = 1, cm = c % MOD;
        for (int k = 0; k <= m; k++) { C[k]=ck*ifact[k]%MOD; ck=ck*cm%MOD; }
        long[] D = polyMul(A, C);
        long[] g = new long[m+1];
        for (int i = 0; i <= m; i++) g[i] = D[m-i]*ifact[i]%MOD;
        return g;
    }

    static long[] compute(int n) {
        if (n == 1) return new long[]{0, 1};
        int m = n / 2;
        long[] half = compute(m);
        long[] shifted = shift(half, m);
        long[] result = polyMul(half, shifted);
        if (n % 2 == 1) result = mulLinear(result, n - 1);
        return result;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) st.nextToken();
        initFact(n + 1);
        long[] coef = compute(n);
        long[] E = new long[n + 3];
        for (int c = n; c >= 1; c--) {
            E[c] = coef[c];
            if (c+2 <= n) E[c] = (E[c]+E[c+2])%MOD;
        }
        StringBuilder sb = new StringBuilder();
        for (int k = 1; k < n; k++) {
            if (k > 1) sb.append(' ');
            sb.append(E[n-k]);
        }
        System.out.println(sb);
    }
}
알고리즘 풀이 블로그
728x90
반응형
Share Link
reply
«   2026/03   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31