2019 xdu网络赛F XDOJ 1409: 背包弹夹平底锅
网络赛最难的一个题,也是我唯一一个不会做的。
看了大大的题解发现我其实已经推到了倒数第二步T_T
但是眼拙没看出来可以写成卷积形式。。。所以按照那个推法无论如何都是O(n^2)的复杂度,变成卷积形式就可以用fft加速直接O(nlogn)一次性求出所有点的值了。。。
其实推前面的公式都不难,首先考虑从n里选出k种颜色,则有C_n^k 种抽法,然后考虑对于k种子弹在m个子弹中的分布,这是个可重集排列问题,可以用指数型母函数来求,用\[f(x)=x+1/2 x^2+1/3! x^3 +… \]来表示每种子弹取得情况(没有1是因为我已经抽好k种子弹了,他一定至少得出现一次),由泰勒公式可以知道\[f(x)= (e^x-1)\],然后有k种子弹所以分布情况的母函数就是\[f(x)^k = (e^x – 1)^k\], 把这个用二项式展开可以得到每一项为\[ C_k^i e^{ix}*(-1)^{k-i} \],然后\[e^(ix)= 1+(ix)+1/2*(ix)^2+1/6*(ix)^3… \]但是我们要求的只是其中幂次为m的项(因为总共的子弹数量为m),遍历每个二项式展开项,单独对x^m项的系数求和,就是对于一个特定的k的答案了,这样的话写出来是
然后这样对每个k都算一遍这个式子的话肯定不行,而且这个表达式没法递推(至少我找不到递推方法也可能是我太菜了)复杂度怎么都是n^2,但是把他改写一下,把与j无关的项目提出来可以变成
\[ \frac{n!}{(n-k)!}*\sum_{j=0}^k \frac{j^m}{j!}*\frac{(-1)^{k-j}}{(k-j)!} \]这样就可以发现后面的项变成了一个卷积,就可以用FFT在O(nlogn)时间快速计算出所有k的答案了。。。
不过我直接用fft貌似会爆精度,于是改成NTT(感觉98244353这个模就是专用ntt emm)
#include<bits/stdc++.h> #include <stdio.h> #include <string.h> #include <iostream> #include <algorithm> #include <math.h> #define G 3 //998244353 的原根 typedef long long ll; const int maxn=4e5+5; const int mod=998244353; using namespace std; const double PI = acos(-1.0); ll a[maxn*3+10],b[maxn*3+10]; int sum[maxn]; ll fac[maxn],inv[maxn],onv[maxn]; ll n ,m; void init() { inv[0] = fac[0] = inv[1] = onv[0]=onv[1]=1; for(int i=1;i<maxn;i++) fac[i] = fac[i-1]*i%mod; for(int i = 2;i<maxn;++i) inv[i] = (mod-mod/i)*inv[mod%i] % mod; for(int i=2;i<maxn;i++) onv[i] = inv[i]*onv[i-1]%mod; } ll C(int x,int y) { return fac[x]*onv[y]%mod*onv[x-y]%mod; } ll po(int a,int k){ ll res=1,base=a; while(k){ if(k&1){ res=(res*base)%mod; } k>>=1; base=(base*base)%mod; } return res; } void ntt(ll *a,int N,int f){ int i,j=0,t,k; for(i=1;i<N-1;i++){ for(t=N;j^=t>>=1,~j&t;); if(i<j){ swap(a[i],a[j]); } } for(i=1;i<N;i<<=1){ t=i<<1; ll wn=po(G,(mod-1)/t); for(j=0;j<N;j+=t){ ll w=1; for(k=0;k<i;k++,w=1ll*w*wn%mod){ ll x=a[j+k],y=1ll*w*a[j+k+i]%mod; a[j+k]=(x+y)%mod,a[j+k+i]=(x-y+mod)%mod; } } } if(f==-1){ reverse(a+1,a+N); ll iinv=po(N,mod-2); for(i=0;i<N;i++) a[i]=1ll*a[i]*iinv%mod; } } int main() { init(); while(~scanf("%lld%lld",&n,&m)){ int up=min(n,m); int len = 1; while(len < up*2 )len<<=1; for(int i = 0;i <= up;i++) a[i] = (((i&1)?-1:1)*onv[i])%mod; for(int i = up+1;i < len;i++) a[i] = 0; for(int i = 0;i <= up;i++) b[i] = po(i,m)*onv[i]%mod; for(int i = up+1;i < len;i++) b[i] = 0; ntt(a,len,1); ntt(b,len,1); for(int i = 0;i < len;i++) a[i] = a[i]*b[i]%mod; ntt(a,len,-1); for(int i = 0;i <= up;i++) sum[i] = ((fac[n]*onv[n-i]%mod)*(a[i]))%mod; ll tt=po(inv[n],m); for(int i=1;i<=up;i++){ printf("%lld",(mod+tt*sum[i]%mod)%mod); if(i!=up)printf(" "); } puts(""); } return 0; }