概率生成函数([CTSC2006] 歌唱王国 题解)

发布时间 2024-01-11 07:55:25作者: Liquefyx

定义

概率生成函数(Probability Generating Function,PGF,通常情况下以随机变量 \(X\) 在非负整数集合 \(\mathbb{N}\) 上为前提),顾名思义,即为概率的生成函数,或者说,如果数列 \(\{p_n\}\) 满足 \(P(X=i)=p_i\)(即 \(\{p_n \}\)\(X\) 的概率质量函数 PMF 所构成的数列),那么有概率生成函数:

\[F_X(x)=\sum^{+\infty}_{i=0}P(X=i)x^i \]

其中,\(P(X=i)\) 表示随机变量 \(X\)\(i\) 的概率。

一些性质

  1. 显然 \(F_X(1)=\sum\limits^{+\infty}_{i=0} P(X=i)=1\)(概率之和为 \(1\));

  2. \(F_X(x)\) 求导可以得到 \(F_X'(x)=\sum\limits^{+\infty}_{i=0}iP(X=i)x^{i-1}\),那么此时 \(F_X'(1)=\sum\limits^{+\infty}_{i=0} iP(X=i)=E(X)\)\(X\) 的期望为 \(E(X)\));

  3. \(F_X(x)\) 求二阶导可得 \(F_X'(x)=\sum\limits^{+\infty}_{i=0}i(i-1)P(X=i)x^{i-2}\),那么 \(F_X''(1)=\sum\limits^{+\infty}_{i=0} i^2P(X=i)-\sum\limits^{+\infty}_{i=0} iP(X=i)=E(X^2)-E(X)\) ,所以 \(E(X^2)=F_X''(1)+E(X)=F_X''(1)+F_X'(1)\),可得 \(D(X)=E(X^2)-(E(X))^2=F_X''(1)+F_X'(1)(1-F_X'(1))\)\(X\) 的方差为 \(D(X)\));

这些性质可以用来简化一些推导 QwQ。

P4548 [CTSC2006] 歌唱王国(芝士一道可用概率生成函数推导的字符串奇妙题目)

题意简述

字符集大小为 \(n\),处理 \(t\) 组数据,每组数据给出一个长度为 \(m\) 的字符集 \(S\),空串 \(T\) 每次随机加入一个字符,直到 \(T\) 中出现 \(S\) 停止,求 \(T\) 停止时长度的期望。

做法

显然不能直接硬算,先设变量 \(Y\) 表示停止时 \(T\) 的长度,再引入两个概率生成函数 \(f(x)\)\(g(x)\) 分别表示 \(Y=i\) 的概率和 \(Y>i\) 的概率,即 \(f(x)=\sum^{+\infty}_{i=0}P(Y=i)x^i\)\(g(x)=\sum^{+\infty}_{i=0}P(Y>i)x^i\),为了方便我们分别用 \(f_i\)\(g_i\) 表示 \(f\)\(g\)\(i\) 次项的系数,那么 \(f_i\) 就表示在第 \(i\) 次添加停止的概率,\(g_i\) 就表示在第 \(i\) 次添加仍未停止的概率,最终我们要求的期望 \(E(Y)=f'(1)\),接下来的任务便是找式子求期望了。

第一个式子——转化期望

首先我们可以根据 \(f\)\(g\) 的定义来确定一个递推式,因为 \(Y\) 一定是一个正整数,所以 \(Y>i\) 等价于 \(Y\ge i+1\),手动分类讨论一下(等于和大于的情况)可得 \(g_i=f_{i+1}+g_{i+1}\),整理成生成函数得到 \(xg(x)+1=f(x)+g(x)\)\(+1\) 是补上 \(g_0=1\) 的项),由于我们要求 \(f'(1)\),那么就可以用 \(g\) 来表示 \(f\),即

\[f(x)=(x-1)g(x)+1 \]

求个导可得 \(f'(x)=g(x)+(x-1)g'(x)\),最终我们有 \(f'(1)=g(1)\),皆大欢喜 ?,这意味着只要我们能求出 \(g(1)\) 就能得到答案了。

第二个式子——求 \(g(1)\)

接下来我们又设一个数列 \(h_i\) 表示在第 \(i\) 次操作后仍未结束,强制性地给 \(T\) 随机加上长度为 \(m\) 的串 \(S\) 的概率(强制表示可能有尚未加满 \(m\) 个字符就已经出现了 \(S\) 的情况,但仍然继续加字符直到加满了 \(m\) 个字符,并且根据定义,加入的 \(m\) 个字符能组成串 \(S\)),那么就相当于我们在 \(g_i\) 的情况下强制加了 \(m\) 个字符,于是能得到 \(h_i=g_i\times n^{-m}\)

但同时我们能发现此时能够满足 \(h_i\) 条件的 \(Y\) 应该有 \(i<Y\le i+m\),对于 \(Y\) 的每一种值我们可以单独讨论,假设 \(Y=y\)\(t=y-i\)\(0<t\le m\)),前提概率便是 \(f_y\),我们只需求再补上串 \(S\) 的长为 \(m-t\) 的后缀使得加入这 \(m-t\) 个字符后最后 \(m\) 个字符能形成串 \(S\)\(f_y\times n^{t-m}\) 之和就能得到 \(h_i\),由于在第 \(y\) 次加入时已经满足出现长度为 \(m\) 的子串 \(S\),那么显然第 \(i+1\sim y\) 次加入的字符构成的串同时为串 \(S\) 的前、后缀,换种说法即为 \(t\) 属于串 \(S\) 的 Border 或者 \(t=m\)(这么说是因为 Border 的定义要求字符串本身不为该字符串的 Border),形式化地:\(h_i=\sum\limits^m_{t=1} pd(t)f_y\times n^{t-m}\),其中 \(pd(t)\) 取值为 \(0\)\(1\),表示判断 \(t\) 是否满足 \(t=m\)\(t\)\(S\) 的 Border 的条件。

结合这两种 \(h_i\) 的求法,可得到 \(g_i\times n^{-m}=\sum\limits^m_{t=1} pd(t)f_y\times n^{t-m}=n^{-m}\times\big(\sum\limits^m_{t=1} pd(t)f_y\times n^t\big)\),又知道 \(y=i+t\),于是 \(g_i=\sum\limits^m_{t=1} pd(t)f_{i+t}\times n^t\),转化为生成函数形式:

\[x^{m}g(x)=\sum^m_{t=1}x^{m-t}f(x)pd(t)\times n^t \]

代入可得 \(g(1)=\sum\limits^m_{t=1}f(1)pd(t)\times n^t=\sum\limits^m_{t=1}pd(t)\times n^t\),综上,最终的答案便为 \(\sum\limits^m_{t=1}pd(t)\times n^t\),直接上 KMP 就做完啦,完结撒花✿✿ヽ(°▽°)ノ✿。

上代码
#include <bits/stdc++.h>
#define mod 10000
#define File(xxx) freopen(xxx".in","r",stdin),freopen(xxx".out","w",stdout)
using namespace std;
typedef long long LL;
const int N = 1e5+5;
int k, t, n, nxt[N], a[N], vis[N], pw[N], ans;

template <typename T> inline void debug(T x) { cerr<<x<<'\n'; }
template <typename T, typename ...T_> inline void debug(T x, T_ ...p) { cerr<<x<<' ', debug(p...); }
template <typename T> void read(T& x) {
	x = 0; int f = 0; char c = getchar();
	while(c < '0' || c > '9') f |= (c == '-'), c=getchar();
	while(c >= '0' && c <= '9') x=(x<<1)+(x<<3)+(c^48), c=getchar();
	x=(f ? -x : x);
}
int lne; char put[105];
template <typename T> void write(T x, char ch) {
	lne = 0; if(x < 0) putchar('-'), x=-x;
	do { put[++lne]=x%10, x/=10; } while(x);
	while(lne) putchar(put[lne--]^48);
	putchar(ch);
}
int upd(int x) {
	return (x >= mod ? x-mod : x);
}

signed main() {
	read(k), read(t);
	pw[0]=1;
	for(int i = 1; i <= 100000; ++i)
		pw[i]=1LL*pw[i-1]*k%mod;
	while(t--) {
		read(n);
		ans=0;
		for(int i = 1; i <= n; ++i)
			read(a[i]), vis[i]=0;
		for(int i = 2, j = 0; i <= n; ++i) {
			while(j && a[j+1] != a[i]) j=nxt[j];
			if(a[j+1] == a[i]) ++j;
			nxt[i]=j;
		}
		int now = n;
		while(now) 
			ans=upd(ans+pw[now]), now=nxt[now];
		if(ans < 1000) putchar('0');
		if(ans < 100) putchar('0');
		if(ans < 10) putchar('0');
		write(ans, '\n');
	}
	return 0;
}