CF135E

发布时间 2023-05-26 00:52:33作者: Neutral1sed

Key Observation:

若称一个前缀是 \(\texttt{distinct prefix}\) 当且仅当其中所有字符互不相同、且它是极长的满足这个性质的前缀,\(\texttt{distinct suffix}\) 同理,则 \(S\) 中最长的弱子串长度 \(f(S) = |S| - \min(|\texttt{distinct prefix}(S)|, |\texttt{distinct suffix}(S)|)\)

(以下将上面的定义简称为 \(\texttt{pre}\)\(\tt suf\)
考虑证明这个结论:
考虑最长的 \(\texttt{pre}\),设右端点是 \(i\),则 \(s_{i+1}\) 一定在它当中出现过,那么 \([i+1,|S|]\) 就是 \(S\) 的一个弱子串,并且它是所有后缀弱子串中最长的;我们先考虑 只通过 将左端点左移来判定是否是弱子串,那么可以发现不存在 \(l \in \texttt{pre}\) 的弱子串,那么 \([i+1,|S|]\) 就是最长的这样的弱子串;然后再考虑 只通过 将右端点右移,设 \(\tt suf\) 的左端点为 \(j\),同理可知不存在 \(r \in \tt suf\) 的弱子串,所以 \([1,j-1]\) 是最长的这样的弱子串;对两者取 \(\max\) 即得到了 可以通过 左端点左移 / 右端点右移判定 的最长的弱子串,于是原题得证。

于是肯定是枚举 \(\tt pre\) 或者 \(\tt suf\) 长度,但是由于有 \(\min\),所以直接做是不行的。
考虑容斥,\(f(x)\) 表示最长弱子串长度 \(\le x\) 的串个数(则答案为 \(f(n) - f(n-1)\)),也就是 \(\min(|\texttt{pre}|,|\texttt{suf}|) \ge |S|-x\)\(S\) 个数,那么我们只需要保证 \(|\texttt{pre}| \ge |S|-x \land |\texttt{suf}| \ge |S|-x\)
枚举 \(i = |S|-x\),则 \([1,i]\) 一定是 \(\tt pre\) 的前缀且 \([x+1,x+i]\) 一定是 \(\tt suf\) 的后缀(若有中间部分,则可以随便填):
\(i \lt x\),方案数即 \((k^{\underline{i}})^2 \times k^{x-i}\)
否则中间重合部分方案数 \(k^{\underline{i-x}}\),两边方案数都是 \((k-(i-x))^{\underline{x}}\),乘起来就行了。
复杂度 \(O(k \log n)\)

inline int Sqr(int x){ return 1ll*x*x%p; }
inline int P(int &x,int y){ return (x+=y)>=p&&(x-=p), x; }
inline int Pow(ll bs,ll b,ll rs = 1){ for(;b;bs=bs*bs%p,b>>=1) if(b&1) rs=rs*bs%p; return rs; }
inline int S(int a1,int d,int n){ return d==1 ? 1ll*a1*n%p : 1ll*a1*(Pow(d,n)-1+p)%p*Pow((d-1+p)%p,p-2)%p; }

int K, w, fac[N], invf[N];
inline int A(int n,int m){ return 1ll*fac[n]*invf[n-m]%p; }

__attribute__((constructor)) static void Prepare(){
	fac[0] = fac[1] = invf[0] = invf[1] = 1;
	for(int i=2;i<N;++i) fac[i] = 1ll*fac[i-1]*i%p;
	invf[N-1] = Pow(fac[N-1], p-2);
	for(int i=N-1;i>2;--i) invf[i-1] = 1ll*invf[i]*i%p;
}
inline int Calc(int n){
	int res = S(K, K, n);
	for(int i=1;i<=min(K,n-1);++i)
		P(res, 1ll*Pow(K,n-i)*Sqr(A(K,i))%p);
	for(int i=n;i<=K;++i)
		P(res, 1ll*A(K,i-n)*Sqr(A(K-(i-n),n))%p);
	return res;
}

main(){
	scanf("%d %d", &K, &w);
	printf("%d\n", (Calc(w)-Calc(w-1)+p)%p);
}