【loj3396】novel(AC自动机维护文本串子串的匹配信息)

发布时间 2023-06-05 22:33:38作者: ez_lcw

设当前询问的串为 \(s_i\) 记为 \(t\)。考虑 \(r\) 右移,维护每个 \(l\) 对应的 \(g(l,r)\)\(\max_{l}\frac{g(l,r)}{r-l+1}\) 即可。

最基本的观察是:当 \(r\) 右移后,考虑 \(t_{1..r}\) 在 AC 自动机上匹配到的点 \(p\),那么对于 \(p\) 的任意祖先(包含 \(p\)\(u\),都会给 \(l\leq r-len_u+1\)\(g(l,r)\) 造成加上 \(w_u\) 的影响。

但由于 \(p\)\(r\) 右移过程中的飘忽不定,所以并不好直接维护这个过程。

注意到 \(t_{r-len_{fa_p}+1..r}\) 一定是某一个 \(s_{i'}\) 的前缀,这部分信息是计算过的。

\(2\leq l\leq r-len_{fa_p}\) 的更改很简单,相当于一个整体加。对于 \(l=1\) 单独维护。

所以有一个初步的做法:将所有的串同时处理同时 \(r\) 右移,然后用可持久化平衡树维护每个 \(l\) 对应的答案,那么只需要支持区间加和区间复制。

但由于我们只需要知道 \(\max_l\frac{g(l,r)}{r-l+1}\),所以可以做到线性:

首先发现 \(r-len_{fa_p}\) 是单调不降的,那么对于 \(2\leq l\leq r-len_{fa_p}\) 的部分,就只需要支持 push_back、全局加、用全局 \(\max \frac{val_i}{x-i}\) 和历史最大值 upmax。这里考虑用斜率优化:把 \(\max\frac{val_i}{x-i}\) 转为 \(-\min\frac{0-val_i}{x-i}\),看成 \((i,val_i)\)\((x,0)\) 的斜率,那么维护 \((i,val_i)\) 的右上凸壳,每次二分找到最陡的那条切线即可。事实上,由于 \(x\) 只会右移,凸壳只会整体上移且会在后面追加点,所以历史答案会变大的必要条件是切到的 \(i\) 更靠后了。那么就不用每次二分而是维护一个指针即可。

对于 \(r-len_{fa_p}+1\leq l\leq r\) 的部分,我们可以直接调用那个 \(s_{i'}\) 对应前缀的答案。

还有一个问题:设 \(r+1\) 对应的点为 \(q\),那么现在我们要把 \(l\in [r-len_{fa_p}+1,r+1-len_{fa_q}]\)\(g(l,r)\) 计算出来 push_back 进前半部分,直接到对应的 \(s_{i'}\) 递归计算即可。

不需要建反串 AC 自动机,但好像跑得最慢(

#include<bits/stdc++.h>

#define N 200011
#define NN 5000011

using namespace std;

using ll=long long;
using i128=__int128;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
	inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
	inline void Mul(int &x,int y){x=1ll*x*y%mod;}
	inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;

template<typename T> void upmin(T &x,const T &y){if(y<x)x=y;}
template<typename T> void upmax(T &x,const T &y){if(x<y)x=y;}

template<typename T> void read(T &x)
{
	x=0; bool f=0; char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+(ch^'0');ch=getchar();}
	if(f) x=-x;
}
template<> void read<string>(string &s)
{
	s=""; char ch=getchar();
	while(ch<'a'||ch>'d') ch=getchar();
	while(ch>='a'&&ch<='d') s+=ch,ch=getchar();
}

struct frac
{
	ll a; int b;
	frac():a(0),b(1){}
	frac(ll _a,int _b):a(_a),b(_b){if(b<0)a=-a,b=-b;}
	frac operator - () const {return frac(-a,b);}
	bool operator < (const frac &x) const {return (i128)a*x.b<(i128)x.a*b;}
	operator double () const {return 1.0*a/b;}
	operator int () const {return mul(a%mod,poww(b,mod-2));}
};

struct point
{
	int x; ll y;
	point(int _x,ll _y):x(_x),y(_y){}
	point addy(ll t) const {return point(x,y+t);}
};
frac slope(const point &a,const point &b){return frac(a.y-b.y,a.x-b.x);}

int n,type,w[N],nn[N];
int node,ch[NN][4],len[NN],fail[NN];
frac ans[NN];
string s[N];
ll sw[NN],g[N];

vector<int> p[N];

int *buc=fail; int ord[NN];
vector<pair<int,int>> ij[NN];
vector<ll> q[NN];

struct convex
{
	frac lans;
	vector<ll> g;
	vector<point> sta;
	ll tag;
	int pt;
	void push(point a)
	{
		g.push_back(a.y-tag);
		while(sta.size()>1&&slope(a,sta[sta.size()-2].addy(tag))<slope(a,sta[sta.size()-1].addy(tag))) sta.pop_back();
		if(!sta.empty()&&pt==sta.size()&&lans<slope(sta[sta.size()-1].addy(tag),a)) pt++;
		sta.push_back(a.addy(-tag)),upmin<int>(pt,sta.size());
	}
	void add(ll w){tag+=w;}
	frac query(int x)
	{
		auto p=point(x,0);
		if(pt) upmin(lans,slope(p,sta[pt-1].addy(tag)));
		for(frac v;pt<sta.size()&&(v=slope(p,sta[pt].addy(tag)))<lans;lans=v,pt++);
		return -lans;
	}
	ll qg(int i){return g[i]+tag;}
}conv[N];

int main()
{
	read(n),read(type);
	for(int i=1;i<=n;i++)
	{
		read(s[i]),read(w[i]);
		nn[i]=s[i].size();
		p[i].resize(nn[i]);
		p[i].shrink_to_fit();
		int u=0;
		ij[u].push_back({i,-1});
		for(int j=0;j<nn[i];j++)
		{
			const int v=s[i][j]-'a';
			if(!ch[u][v])
			{
				ch[u][v]=++node;
				buc[len[node]=len[u]+1]++;
			}
			p[i][j]=u=ch[u][v];
			ij[u].push_back({i,j});
		}
		sw[u]+=w[i];
	}
	for(int i=1;i<=node;i++) ij[i].shrink_to_fit(),buc[i]+=buc[i-1];
	for(int i=1;i<=node;i++) ord[buc[len[i]]--]=i;
	memset(fail,0,sizeof(fail));
	for(int i=1;i<=node;i++)
	{
		const int u=ord[i];
		sw[u]+=sw[fail[u]];
		for(int j=0;j<4;j++)
			(ch[u][j]?fail[ch[u][j]]:ch[u][j])=ch[fail[u]][j];
	}
	for(int ii=1;ii<=node;ii++)
	{
		const int u=ord[ii],lv=len[fail[u]];
		int k=0;
		for(auto pr:ij[u])
		{
			const int i=pr.first,j=pr.second;
			if(j+1<nn[i]) upmax(k,min(lv+1-len[fail[p[i][j+1]]],lv));
		}
		int v=fail[u];
		for(int t;k;k-=t,v=fail[v])
		{
			t=min(k,len[v]-len[fail[v]]);
			if(q[v].size()<t) q[v].resize(t);
		}
	}
	for(int i=1;i<=node;i++) q[i].shrink_to_fit();
	ans[0]=frac(0,1);
	for(int ii=0;ii<=node;ii++)
	{
		const int u=ord[ii];
		bool calced=0;
		for(auto pr:ij[u])
		{
			const int i=pr.first,j=pr.second;
			if(ii) conv[i].add(sw[fail[u]]);
			g[i]+=sw[u];
			if(!calced)
			{
				if(q[u].size()) q[u][0]=g[i];
				for(int k=1;k<(int)q[u].size();k++) q[u][k]=conv[i].qg(k);
				if(ii) upmax(ans[u],max({frac(g[i],len[u]),conv[i].query(j+1),ans[fail[u]]}));
				calced=1;
			}
			if(j+1<nn[i])
			{
				const int t=min(j,j+1-len[fail[p[i][j+1]]]);
				int v=fail[u],pt=0;
				for(int l=j-len[fail[u]]+1;l<=t;l++)
				{
					if(pt>=len[v]-len[fail[v]]) v=fail[v],pt=0;
					conv[i].push(point(l,q[v][pt])),pt++;
				}
				if(!fail[p[i][j+1]]) conv[i].push(point(j+1,0));
				ans[p[i][j+1]]=ans[u];
			}
		}
	}
	if(!type) printf("%.15lf\n",double(ans[p[1].back()]));
	else
	{
		int sum=0;
		for(int i=1;i<=n;i++)
			for(int j=0;j<nn[i];j++)
				sum^=int(ans[p[i][j]]);
		printf("%d\n",sum);
	}
	return 0;
}