又是一个Trie图上找递推关系的题~只不过这次是单字符串做一个Trie图(这样的话用扩展kmp也可以8
把样例研究一下再结合Trie图基本就能确定出等式关系了,设dp[i]为
现在在Trie图上的节点i,还需要走多少步(期望意义下)才能凑出给定的字符串
容易知道当走到尽头的时候dp[n]=0;其他的项地推关系可以表示为dp[u]=1+\sum i/n*dp[v] ,其中v是u走向的下一个节点。
以样例n=2,s=”ABA”为例,Tried图和方程组如下
解这个方程组可得dp[0]=5,就是我们要的结果了。
具体代码实现的时候,直接对每个结点都把转移记录下来,然后用高斯消元解方程即可。注意应该把n乘到等式左边,然后用整数的高斯消元,不然会有误差wa。
#include<bits/stdc++.h> #define eps 1e-15 using namespace std; const int maxn=200+15; const int INF=0x3f3f3f3f; const int mod=1e9+7; const int maxnode=500; typedef long long LL; typedef unsigned long long ULL; long double A[maxn][maxn],X[maxn]; int equ,var; void Gauess(int n){ for(int i = 0; i < n; ++i){ int r = i; while(r < n && !A[r][i]) ++r; if(r != i) for(int j = 0; j <= n; ++j) swap(A[r][j], A[i][j]); for(int k = i+1; k < n; ++k) if(A[k][i]){ LL f = A[k][i]; for(int j = i; j <= n; ++j) A[k][j] = A[k][j] * A[i][i] - f * A[i][j]; } } for(int i = n-1; i >= 0; --i){ for(int j = i+1; j < n; ++j) A[i][n] -= A[j][n] * A[i][j]; A[i][n] /= A[i][i]; } } struct Trie { int sz,root,up;//sz include root but root's number is 0 int ch[maxnode][27],isend[maxnode],fail[maxnode]; void reset(int n) { sz=0;root=newnode();memset(fail,-1,sizeof(fail));up=n; } int newnode() { memset(ch[sz],-1,sizeof(ch[sz])); isend[sz++]=0; return sz-1; } void insert(char* s) { int len=strlen(s),u=root,c; for(int i=0;i<len;i++) { c=s[i]-'A'; if(ch[u][c]==-1) ch[u][c]=newnode(); u=ch[u][c]; } isend[u]=1; } void build() { int u; queue<int> Q; fail[root]=root; for(int i=0;i<up;i++) { if(ch[root][i]==-1) ch[root][i]=root; else { fail[ch[root][i]]=root; Q.push(ch[root][i]); } } while(!Q.empty()) { u=Q.front();Q.pop(); for(int i=0;i<up;i++) { if(ch[u][i]==-1) ch[u][i]=ch[fail[u]][i]; else { fail[ch[u][i]]=ch[fail[u]][i]; Q.push(ch[u][i]); } } } } void bfs() { int nxt; //double mul=1.0/up; for(int i=0;i<sz-1;i++) { A[i][i]=up; for(int j=0;j<up;j++) { nxt=ch[i][j]; A[i][nxt]-=1; } } A[sz-1][sz-1]=1;//X[sz-1]=0; } }; Trie T; int main() { int t,cs=0,n; char str[15]; scanf("%d",&t); while(t--) { if(cs) puts(""); scanf("%d %s",&n,str); cs++; printf("Case %d:\n",cs); memset(A,0,sizeof(A)); T.reset(n); T.insert(str); T.build(); T.bfs(); int len=strlen(str); for(int i=0;i<len;i++) X[i]=n; X[len]=0; equ=var=len+1; for(int i=0;i<len;i++) A[i][len+1]=n; Gauess(len+1); printf("%lld\n",(LL)(A[0][len+1]+0.5)); } return 0; }