题解 BZOJ 3586 字符串生成器

题意:

  有一个字符串生成器,初始时生成的字符串为空串,它每次按照给定概率随机生成一个小写字母,加在当前已生成字符串的后面。
  给定N个长度为L的字符串,每个字符串由小写字母组成。
  如果在某个时候,发现每个给定字符串都在当前已生成的字符串中作为子串出现过,生成器就会停下来,将当前生成的字符串作为输出。
  求输出字符串长度的期望值。

题解:

考虑状压dp。

分层dp,同层间有环,不同层间是DAG。

设$f[i][S] $为从trie图的i节点开始,还需要出现$S$这些字符串的期望长度。

初始条件:

$f[i][0]=0$

转移:

$ f[i][S]=f[i][S-i] $   $(if i为结尾节点且i∈S)$

$ f[i][S] = \sum_{j} f[trie[i][j]][S]*p[j]+1$  $else$

$trie[i][j]$为trie图上的边。(补出fail节点的后继)
然后${S}$从小到大枚举就好了。

(高斯消元掉了发精度。。囧 ..

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<cmath>
using namespace std;
typedef long double ld;
int tt,n,l,t;
char s[10];
int trie[55][26],pos[55],tot,fail[55];
ld f[55][356],p[26],a[55][55];
queue<int> q;
void mkfail(){
q.push(0);
while(!q.empty()){
int x=q.front();q.pop();
for(int i=0;i<t;i++){
if(trie[x][i]){
if(x) fail[trie[x][i]]=trie[fail[x]][i];
q.push(trie[x][i]);
}else trie[x][i]=trie[fail[x]][i];
}
}
}
void Gauss(int n){
for(int i=0;i<=n;i++){
ld mx=0;int pos;
for(int j=i;j<=n;j++){
if(fabs(a[j][i])>mx){
mx=fabs(a[j][i]);
pos=j;
}
}
for(int j=i;j<=n+1;j++) swap(a[i][j],a[pos][j]);
for(int j=i+1;j<=n;j++){
ld f=a[j][i]/a[i][i];
for(int k=i;k<=tot+1;k++){
a[j][k]-=f*a[i][k];
}
}
}
for(int i=n;i>=0;i--){
for(int j=i+1;j<=n;j++){
a[i][tot+1]-=a[i][j]*a[j][tot+1];
}
a[i][tot+1]/=a[i][i];
}
}

void calc(int x){
memset(a,0,sizeof a);
for(int i=0;i<=tot;i++){
a[i][i]=1;
if(pos[i]!=-1&&((1<<pos[i])&x)) a[i][tot+1]=f[i][x^(1<<pos[i])];
else{
a[i][tot+1]=1;
for(int j=0;j<t;j++){
a[i][trie[i][j]]-=p[j];
}
}
}
Gauss(tot);
for(int i=0;i<=tot;i++) f[i][x]=a[i][tot+1];
}
int main(){
scanf("%d",&tt);
while(tt--){
memset(trie,0,sizeof trie);
memset(pos,-1,sizeof pos);
memset(fail,0,sizeof fail);
memset(f,0,sizeof f);
tot=0;
scanf("%d%d%d",&n,&l,&t);
for(int i=0;i<n;i++){
scanf("%s",s);
int id=0;
for(int j=0;s[j];j++){
if(trie[id][s[j]-'a']) id=trie[id][s[j]-'a'];
else id=trie[id][s[j]-'a']=++tot;
}
pos[id]=i;
}
for(int i=0,x;i<t;i++){
scanf("%d",&x);
p[i]=x/10000.0;
}
mkfail();
for(int i=1;i<(1<<n);i++){
calc(i);
}
printf("%.10f\n",(double)f[0][(1<<n)-1]);
}
return 0;
}
评论小助手

评论正在加载...

Tip: 点击下方链接切换到 Disqus 评论框可以获得邮件提醒哦
🗣️ 加载 Disqus 评论框