2023年中国大学生程序设计竞赛女生专场 H. 字符串游戏 (AC自动机)

发布时间 2023-10-26 20:29:56作者: qdhys

解题思路:

对于每个询问串的查询可以改为以节点为后缀来统计有多少个查询串在里面然后来统计答案。拿下面这个例子来说:

3 1
a
bb
abb
aabb

首先对查询串(n个串)构建AC自动机,对于每个字符串结尾位置的状态p设置sum[p] = 1, 同时插入的时候维护每个状态的长度len[p]。

我们用for(int j = 0; j < str.length(); j ++)从前往后遍历aababc串,当我们遍历到字符串第1个字符后的时候在AC自动机上的状态为p。此时只需要查看p是不是终止状态(即sum[p]是否不为0), 此时sum[p]不为0,且len[p] = 1,答案统计就是前面能用的字符数量乘上后面能用的字符数量。我们可以得知此时以0个位置为结尾的后缀(a)在整个字符串出现的情况有(0 - len[p] + 2) * (4 - 0) = 6种,即a, aa, aab, aabb。此时ans = 4。

当p走完字符串的第2个字符之后。重复上述步骤, 因为我们查询串没有aa这个状态,我们仍会走到a这个状态,所以此时状态和上一个状态一致。此时sum[p]不为0,且len[p] = 1,我们可以得知此时以1个位置为结尾的后缀a( 其实应该是aa但是没有这个状态所以此时还是a )在整个字符串出现的情况有(1 - len[p] + 2) * (4 - 1) = 6种,即aa, aab, aabb, a, ab, abb。此时ans = 10。

当p走完字符串的第3个字符之后。 此时走到的状态并不是结尾状态。此时没有状态更新。

当p走完字符串的第4个字符之后。我们会走到abb这个状态,此时sum[p]不为0,且len[p] = 3。此时还是在整个字符串出现的情况有(3 - len[p] + 2) * (4 - 3) = 2,即aabb, abb。此时ans = 12。但是以第四个字符结尾的并非仅仅只有abb这个状态, 还有bb这个状态, 我们知道bb这个状态在abb状态的nxt指针上。所以我们只需要根据nxt指针上跳到bb这个状态即可。此时我们能得到sum[p] = 1, len[p] = 2, 所以bb这个状态的答案为(3 - len[p] + 2) * (4 - 3) = 3, 即aabb, abb, bb。所以ans = 15就是最终答案。

所以这题的解法就是走到每个状态后跳nxt指针,如果sum[p]不为0那么我们就统计答案。但是如果给长度为1e6的全a串我们这么慢慢跳时间肯定不够。接下来考虑如何优化。

我们看统计答案的式子(j - len[p] + 2) * (str.length() - j),遍历的时候str.length() - j是不变的,那么变化的我们只需要看左边的(j - len[p] + 2)如何优化,如果我们知道当前位置nxt路径上所有合法位置的长度和sumlen,还有数量sum的话,是不是可以将式子变化为(j * sum - sumlen + 2 * sum) * (str.length() - j)。所以我们只需要按照后缀自动机nxt路径从浅到深做一遍前缀和即可。而且我们已知构造AC自动机的q数组是严格按照bfs序递增的顺序构建的,所以我们只需要在这个q数组上进行前缀和即可。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
 
#define ll long long
#define fs first
#define se second
#define AC main(void)

const long double eps = 1e-9;
const int N = 2e5 + 10, M = 2e5 + 10;
const int MOD = 1e9 + 7;

int n , m, _;

  struct Aho_Corasick_Automaton{
	int cnt[N], tr[N][26], idx, q[N], nxt[N], mx[N];
	int pre[N];//跳到上一个有终止节点的位置
	int son[N];//
	int id[N], reid[N];//每个字符串原串位置的标记
	int sz[N], h[N], e[N], ne[N], dfn[N], tdl, idx1;//ac自动机fail树上建dfs序数组
	int vis[N * 2], used[N * 2];//找环
	int sigma = 26;
	int simple = 'a';
	int len[N];
	int sum[N], sumlen[N];

	inline void add(int a, int b){
		ne[idx1] = h[a], e[idx1] = b, h[a] = idx1 ++;
	}

	//多组测试清空操作
	inline void init(){
        for(int i = 0; i <= idx; i ++ ){
        	memset(tr[i], 0, sizeof tr[i]);
        	nxt[i] = cnt[i] = ne[i] = 0, h[i] = -1;
			vis[i] = used[i] = 0;
        } 
        idx = tdl = idx1 = 0;
    }

	inline bool findcycle(int u){//AC自动机找从0号是否可以有环(不能经过字符串被标记的地方)
		if(used[u] == 1)	return true;
		if(used[u] == -1)	return false;
		vis[u] = used[u] = true;
		for(int i = 0; i < 2; i ++)
			if(!son[tr[u][i]])	if(findcycle(tr[u][i]))	return true;
		used[u] = -1;
		return false;
	}

	inline void dfs(int u){//dfs序
		sz[u] = 1, dfn[u] = ++ tdl;
		for(int i = h[u]; ~i; i = ne[i]){
			int j = e[i];
			dfs(j);
			sz[u] += sz[j];
		}
	}

	inline int insert(std::string &s, int x){//插入字符 和插入的是第几个字符
		int p = 0;
		for(char &ch : s){
			int u = ch - simple;
			if(!tr[p][u]){
				tr[p][u] = ++ idx;
				len[idx] = len[p] + 1;
			}
			p = tr[p][u];
		}
		id[p] = x;//标记第x个字符的结尾
		reid[x] = p;
		cnt[p] ++;
		sumlen[p] = len[p];
		son[p] = 1;
		sum[p] = 1;
		mx[p] = std::max(mx[p], (int)s.length());
		return p;
	}
	
	inline void build(){//建立ac自动机
		int p = 0;
		int hh = 0,tt = -1;
		for(int i = 0; i < sigma; i ++)
			if(tr[p][i])	q[++ tt] = tr[p][i];
		while(hh <= tt){
			int tq = q[hh ++];
			for(int i = 0; i < 26; i ++){
				int j = tr[tq][i];
				if(!tr[tq][i]){
					tr[tq][i] = tr[nxt[tq]][i];
				}
				else{
					q[++ tt] = tr[tq][i];
					nxt[j] = tr[nxt[tq]][i];
					if(cnt[nxt[j]])  pre[j] = nxt[j];
                   	else pre[j] = pre[nxt[j]];
                   	if (son[nxt[j]])
					son[j] |= son[nxt[j]];//标记能到达终止节点路径上的所有点
				}
			}
		}
	}

	//ac自动机fail树上建dfs序的建边
	inline void failtree(){
		memset(h, -1, sizeof h);
		for(int i = 1; i <= idx; i ++)	add(nxt[i], i);
		dfs(0);
	}

	inline int query(std::string &s){
		int res = 0, j = 0;
		for(char &ch : s){
			int u = ch - 'a';
			j = tr[j][u];
			int p = u;
			while(p){
				res += cnt[p];
				p = nxt[p];
			}
		}
		return res;
	}	

}acam;

inline void solve(){
	std::cin >> n >> m;
	
	for(int i = 1; i <= n; i ++){
		std::string str;
		std::cin >> str;
		acam.insert(str, i);
	}
	
	acam.build();
	
	for (int i = 1; i <= acam.idx; i ++) {
		int t = acam.q[i];
		acam.sumlen[t] += acam.sumlen[acam.nxt[t]];
		acam.sum[t] += acam.sum[acam.nxt[t]];
	}
	
	for (int i = 1; i <= m; i ++) {
		std::string str;
		std::cin >> str;
		ll ans = 0;
		int p = 0;
		for (int j = 0; str[j]; j ++) {
			p = acam.tr[p][str[j] - 'a'];
			if (acam.sum[p]) ans += 1LL * (1ll * j * acam.sum[p] - acam.sumlen[p] + 2 * acam.sum[p]) % MOD * (str.length() - j) % MOD;
		}
		std::cout << ans % MOD << '\n';
	}
}

int main(void){
   	std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);

	_ = 1;
   	//std::cin >> _;
	while(_ --)
    	solve();

    return 0;
}