[比赛补题][FFT][组合数学][母函数][×]2019 xdu网络赛F XDOJ 1409: 背包弹夹平底锅

2019 xdu网络赛F XDOJ 1409: 背包弹夹平底锅

题目链接

网络赛最难的一个题,也是我唯一一个不会做的。

看了大大的题解发现我其实已经推到了倒数第二步T_T

但是眼拙没看出来可以写成卷积形式。。。所以按照那个推法无论如何都是O(n^2)的复杂度,变成卷积形式就可以用fft加速直接O(nlogn)一次性求出所有点的值了。。。

其实推前面的公式都不难,首先考虑从n里选出k种颜色,则有C_n^k 种抽法,然后考虑对于k种子弹在m个子弹中的分布,这是个可重集排列问题,可以用指数型母函数来求,用\[f(x)=x+1/2 x^2+1/3! x^3 +… \]来表示每种子弹取得情况(没有1是因为我已经抽好k种子弹了,他一定至少得出现一次),由泰勒公式可以知道\[f(x)= (e^x-1)\],然后有k种子弹所以分布情况的母函数就是\[f(x)^k = (e^x – 1)^k\], 把这个用二项式展开可以得到每一项为\[ C_k^i e^{ix}*(-1)^{k-i} \],然后\[e^(ix)= 1+(ix)+1/2*(ix)^2+1/6*(ix)^3… \]但是我们要求的只是其中幂次为m的项(因为总共的子弹数量为m),遍历每个二项式展开项,单独对x^m项的系数求和,就是对于一个特定的k的答案了,这样的话写出来是

\[ C_n^k * \sum_{j=0}^k C_k^j (j)^m * (-1)^{k-j} \]

然后这样对每个k都算一遍这个式子的话肯定不行,而且这个表达式没法递推(至少我找不到递推方法也可能是我太菜了)复杂度怎么都是n^2,但是把他改写一下,把与j无关的项目提出来可以变成

\[ \frac{n!}{(n-k)!}*\sum_{j=0}^k \frac{j^m}{j!}*\frac{(-1)^{k-j}}{(k-j)!} \]

这样就可以发现后面的项变成了一个卷积,就可以用FFT在O(nlogn)时间快速计算出所有k的答案了。。。
不过我直接用fft貌似会爆精度,于是改成NTT(感觉98244353这个模就是专用ntt emm)

#include<bits/stdc++.h>
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <math.h>
#define G 3 //998244353 的原根

typedef long long ll;
const int maxn=4e5+5;
const int mod=998244353;

using namespace std;

const double PI = acos(-1.0);
ll a[maxn*3+10],b[maxn*3+10];
int sum[maxn];
ll fac[maxn],inv[maxn],onv[maxn];
ll n ,m;
void init()
{
    inv[0] = fac[0] = inv[1] = onv[0]=onv[1]=1;
    for(int i=1;i<maxn;i++) fac[i] = fac[i-1]*i%mod;
    for(int i = 2;i<maxn;++i) inv[i] = (mod-mod/i)*inv[mod%i] % mod;
    for(int i=2;i<maxn;i++) onv[i] = inv[i]*onv[i-1]%mod;
}

ll C(int x,int y)
{
    return fac[x]*onv[y]%mod*onv[x-y]%mod;
}

ll po(int a,int k){
    ll res=1,base=a;
    while(k){
        if(k&1){
            res=(res*base)%mod;
        }
        k>>=1;
        base=(base*base)%mod;
    }
    return res;
}

void ntt(ll *a,int N,int f){
    int i,j=0,t,k;
    for(i=1;i<N-1;i++){
        for(t=N;j^=t>>=1,~j&t;);
        if(i<j){
            swap(a[i],a[j]);
        }
    }
    for(i=1;i<N;i<<=1){
        t=i<<1;
        ll wn=po(G,(mod-1)/t);
        for(j=0;j<N;j+=t){
            ll w=1;
            for(k=0;k<i;k++,w=1ll*w*wn%mod){
                ll x=a[j+k],y=1ll*w*a[j+k+i]%mod;
                a[j+k]=(x+y)%mod,a[j+k+i]=(x-y+mod)%mod;
            }
        }
    }
    if(f==-1){
        reverse(a+1,a+N);
        ll iinv=po(N,mod-2);
        for(i=0;i<N;i++)
            a[i]=1ll*a[i]*iinv%mod;
    }
}

int main()
{
    init();
    while(~scanf("%lld%lld",&n,&m)){
        int up=min(n,m);
        int len = 1;
        while(len < up*2 )len<<=1;
        for(int i = 0;i <= up;i++)
            a[i] =  (((i&1)?-1:1)*onv[i])%mod;
        for(int i = up+1;i < len;i++)
            a[i] = 0;
        for(int i = 0;i <= up;i++)
            b[i] = po(i,m)*onv[i]%mod;
        for(int i = up+1;i < len;i++)
            b[i] = 0;
        ntt(a,len,1);
        ntt(b,len,1);
        for(int i = 0;i < len;i++)
            a[i] = a[i]*b[i]%mod;
        ntt(a,len,-1);
        for(int i = 0;i <= up;i++)
            sum[i] = ((fac[n]*onv[n-i]%mod)*(a[i]))%mod;
        ll tt=po(inv[n],m);
        for(int i=1;i<=up;i++){
            printf("%lld",(mod+tt*sum[i]%mod)%mod);
            if(i!=up)printf(" ");
        }
        puts("");
    }
    return 0;
}

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

此站点使用 Akismet 来减少垃圾评论。了解我们如何处理您的评论数据