2019 xdu网络赛F XDOJ 1409: 背包弹夹平底锅
网络赛最难的一个题,也是我唯一一个不会做的。
看了大大的题解发现我其实已经推到了倒数第二步T_T

但是眼拙没看出来可以写成卷积形式。。。所以按照那个推法无论如何都是O(n^2)的复杂度,变成卷积形式就可以用fft加速直接O(nlogn)一次性求出所有点的值了。。。
其实推前面的公式都不难,首先考虑从n里选出k种颜色,则有C_n^k 种抽法,然后考虑对于k种子弹在m个子弹中的分布,这是个可重集排列问题,可以用指数型母函数来求,用\[f(x)=x+1/2 x^2+1/3! x^3 +… \]来表示每种子弹取得情况(没有1是因为我已经抽好k种子弹了,他一定至少得出现一次),由泰勒公式可以知道\[f(x)= (e^x-1)\],然后有k种子弹所以分布情况的母函数就是\[f(x)^k = (e^x – 1)^k\], 把这个用二项式展开可以得到每一项为\[ C_k^i e^{ix}*(-1)^{k-i} \],然后\[e^(ix)= 1+(ix)+1/2*(ix)^2+1/6*(ix)^3… \]但是我们要求的只是其中幂次为m的项(因为总共的子弹数量为m),遍历每个二项式展开项,单独对x^m项的系数求和,就是对于一个特定的k的答案了,这样的话写出来是
然后这样对每个k都算一遍这个式子的话肯定不行,而且这个表达式没法递推(至少我找不到递推方法也可能是我太菜了)复杂度怎么都是n^2,但是把他改写一下,把与j无关的项目提出来可以变成
\[ \frac{n!}{(n-k)!}*\sum_{j=0}^k \frac{j^m}{j!}*\frac{(-1)^{k-j}}{(k-j)!} \]这样就可以发现后面的项变成了一个卷积,就可以用FFT在O(nlogn)时间快速计算出所有k的答案了。。。
不过我直接用fft貌似会爆精度,于是改成NTT(感觉98244353这个模就是专用ntt emm)
#include<bits/stdc++.h>
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <math.h>
#define G 3 //998244353 的原根
typedef long long ll;
const int maxn=4e5+5;
const int mod=998244353;
using namespace std;
const double PI = acos(-1.0);
ll a[maxn*3+10],b[maxn*3+10];
int sum[maxn];
ll fac[maxn],inv[maxn],onv[maxn];
ll n ,m;
void init()
{
inv[0] = fac[0] = inv[1] = onv[0]=onv[1]=1;
for(int i=1;i<maxn;i++) fac[i] = fac[i-1]*i%mod;
for(int i = 2;i<maxn;++i) inv[i] = (mod-mod/i)*inv[mod%i] % mod;
for(int i=2;i<maxn;i++) onv[i] = inv[i]*onv[i-1]%mod;
}
ll C(int x,int y)
{
return fac[x]*onv[y]%mod*onv[x-y]%mod;
}
ll po(int a,int k){
ll res=1,base=a;
while(k){
if(k&1){
res=(res*base)%mod;
}
k>>=1;
base=(base*base)%mod;
}
return res;
}
void ntt(ll *a,int N,int f){
int i,j=0,t,k;
for(i=1;i<N-1;i++){
for(t=N;j^=t>>=1,~j&t;);
if(i<j){
swap(a[i],a[j]);
}
}
for(i=1;i<N;i<<=1){
t=i<<1;
ll wn=po(G,(mod-1)/t);
for(j=0;j<N;j+=t){
ll w=1;
for(k=0;k<i;k++,w=1ll*w*wn%mod){
ll x=a[j+k],y=1ll*w*a[j+k+i]%mod;
a[j+k]=(x+y)%mod,a[j+k+i]=(x-y+mod)%mod;
}
}
}
if(f==-1){
reverse(a+1,a+N);
ll iinv=po(N,mod-2);
for(i=0;i<N;i++)
a[i]=1ll*a[i]*iinv%mod;
}
}
int main()
{
init();
while(~scanf("%lld%lld",&n,&m)){
int up=min(n,m);
int len = 1;
while(len < up*2 )len<<=1;
for(int i = 0;i <= up;i++)
a[i] = (((i&1)?-1:1)*onv[i])%mod;
for(int i = up+1;i < len;i++)
a[i] = 0;
for(int i = 0;i <= up;i++)
b[i] = po(i,m)*onv[i]%mod;
for(int i = up+1;i < len;i++)
b[i] = 0;
ntt(a,len,1);
ntt(b,len,1);
for(int i = 0;i < len;i++)
a[i] = a[i]*b[i]%mod;
ntt(a,len,-1);
for(int i = 0;i <= up;i++)
sum[i] = ((fac[n]*onv[n-i]%mod)*(a[i]))%mod;
ll tt=po(inv[n],m);
for(int i=1;i<=up;i++){
printf("%lld",(mod+tt*sum[i]%mod)%mod);
if(i!=up)printf(" ");
}
puts("");
}
return 0;
}
