牛客14612 string AC自动机 + 树状数组

发布时间 2023-03-22 21:13:51作者: qdhys

传送门

题目大意

  有T组测试数据,对于每组测试时局有一个n和m,n表示初始拥有的字符串数量,m表示操作数量。紧接着输入n个字符串,再读入m行操作,每行以x str的形式给出,如果x为1则是往所拥有的字符串内插入str,若x为2则是查询当前字符串包括了多少完整的字符串(重复出现也算)。

  如果要查询一个字符串被另一个字符串完整包含了多少次,可以想到AC自动机的FAIL结点的子树上的后缀都是以当前当前结点为前缀的。但是AC自动机是离线型数据结构,所以我们需要读入所有的插入字符串,然后考虑用其他数据结构动态维护每个字符串出现的次数。这里可以使用树状数组进行维护。

  我们知道如果AC自动机FAIL树上的一个结点数量增加了实际上是给当前子树内所有结点出现的次数都增加了,所有我们用树状数组进行差分动态维护子树信息。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
 
#define ll long long
#define fs first
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);

const long double eps = 1e-9;
const int N = 2e5 + 10, M = 2e5 + 10;

int n , m, _;

struct Fenwick{
	int maxm, cnt = 0;
	std::vector<int> tr;
	Fenwick(int n): tr(n + 1, 0) {maxm = n;}
	inline int lowbit(int x) {return x & -x;}
	
	inline void add(int x, int v){
       	for(int i = x; i <= maxm; i += lowbit(i))	tr[i] += v;
    	cnt += v;
	}
    
    inline int query(int x){
        int res = 0;
        for(int i = x; i; i -= lowbit(i)) res += tr[i];
        return res;
    }
	
	inline int query(int l, int r){
		return query(r) - query(l - 1);
	}
	
	inline int find_kmin(int k) {
        int ans = 0, cnt = 0;
        for (int i = 20; i >= 0; i--) {
            ans += (1 << i);
            if (ans >= maxm || cnt + tr[ans] >= k) ans -= (1 << i);
            else cnt += tr[ans];
        }
        return ans + 1;
    }

    inline int find_kmax(int k) {
        return find_kmin(cnt - k + 1);
    }
};

struct Aho_Corasick_Automaton{
	int cnt[N], tr[N][26], idx, q[N], nxt[N], mx[N];
	int pre[N];//跳到上一个有终止节点的位置
	int son[N];//
	int id[N], reid[N];//每个字符串原串位置的标记
	int sz[N], h[N], e[N], ne[N], dfn[N], tdl, idx1;//ac自动机fail树上建dfs序数组
	int vis[N * 2], used[N * 2];//找环
	int sigma = 26;
	int simple = 'a';

	inline void add(int a, int b){
		ne[idx1] = h[a], e[idx1] = b, h[a] = idx1 ++;
	}

	//多组测试清空操作
	inline void init(){
        for(int i = 0; i <= idx; i ++ ){
        	memset(tr[i], 0, sizeof tr[i]);
        	nxt[i] = cnt[i] = ne[i] = 0, h[i] = -1;
			vis[i] = used[i] = 0;
        } 
        idx = tdl = idx1 = 0;
    }

	inline bool findcycle(int u){//AC自动机找从0号是否可以有环(不能经过字符串被标记的地方)
		if(used[u] == 1)	return true;
		if(used[u] == -1)	return false;
		vis[u] = used[u] = true;
		for(int i = 0; i < 2; i ++)
			if(!son[tr[u][i]])	if(findcycle(tr[u][i]))	return true;
		used[u] = -1;
		return false;
	}

	inline void dfs(int u){//dfs序
		sz[u] = 1, dfn[u] = ++ tdl;
		for(int i = h[u]; ~i; i = ne[i]){
			int j = e[i];
			dfs(j);
			sz[u] += sz[j];
		}
	}
	
	inline int insert(std::string &s){//插入字符 和插入的是第几个字符
		int p = 0;
		for(char &ch : s){
			int u = ch - simple;
			if(!tr[p][u]){
				tr[p][u] = ++ idx;
			}
			p = tr[p][u];
		}
		//cnt[p] ++;
		return p;
	}

	inline int insert(std::string &s, int x){//插入字符 和插入的是第几个字符
		int p = 0;
		for(char &ch : s){
			int u = ch - simple;
			if(!tr[p][u]){
				tr[p][u] = ++ idx;
			}
			p = tr[p][u];
		}
		id[p] = x;//标记第x个字符的结尾
		reid[x] = p;
		cnt[p] ++;
		son[p] = 1;
		mx[p] = std::max(mx[p], (int)s.length());
		return p;
	}
	
	inline void build(){//建立ac自动机
		int p = 0;
		int hh = 0,tt = -1;
		for(int i = 0; i < sigma; i ++)
			if(tr[p][i])	q[++ tt] = tr[p][i];
		while(hh <= tt){
			int tq = q[hh ++];
			for(int i = 0; i < 26; i ++){
				int j = tr[tq][i];
				if(!tr[tq][i]){
					tr[tq][i] = tr[nxt[tq]][i];
				}
				else{
					q[++ tt] = tr[tq][i];
					nxt[j] = tr[nxt[tq]][i];
					if(cnt[nxt[j]])  pre[j] = nxt[j];
                   	else pre[j] = pre[nxt[j]];
					//son[j] |= son[nxt[j]];//标记能到达终止节点路径上的所有点
				}
			}
		}
	}

	//ac自动机fail树上建dfs序的建边
	inline void failtree(){
		memset(h, -1, sizeof h);
		for(int i = 1; i <= idx; i ++)	add(nxt[i], i);
		dfs(0);
	}

	inline int query(std::string &s){
		int res = 0, j = 0;
		for(char &ch : s){
			int u = ch - 'a';
			j = tr[j][u];
			int p = u;
			while(p){
				res += cnt[p];
				p = nxt[p];
			}
		}
		return res;
	}	

}acam;

inline void solve(){
	acam.init();
	std::cin >> n >> m;
	std::vector<std::string> start(n);
	std::vector<std::pair<int, std::string>> g(m + 1);
	std::string str;
	for(int i = 1; i <= n; i ++){
		std::cin >> str;
		start[i - 1] = str;
		acam.insert(str);
	}

	for(int i = 1; i <= m; i ++){
		int op;
		std::cin >> op >> str;
		g[i] = {op, str};
		if(op == 1)	acam.insert(str);
	}
	
	acam.build();
	acam.failtree();

	const auto tdl = acam.tdl;
	const auto &tr = acam.tr;
	const auto &sz = acam.sz, &dfn = acam.dfn;

	Fenwick fk(tdl);
	
	for(int i = 0; i < n; i ++){
		int p = 0;
		auto &s = start[i];
		for(int j = 0; s[j]; j ++)
			p = tr[p][s[j] - 'a'];
		
		fk.add(dfn[p], 1);
		fk.add(dfn[p] + sz[p], -1);
	}
	
	for(int i = 1; i <= m; i ++){
		auto &op = g[i].fs;
		auto &s = g[i].se;
		if(op == 1){
			int p = 0;
			for(int j = 0; s[j]; j ++)
				p = tr[p][s[j] - 'a'];
			fk.add(dfn[p], 1);
			fk.add(dfn[p] + sz[p], -1);
		}else{
			ll ans = 0;
			int p = 0;
			for(int j = 0; s[j]; j ++){
				p = tr[p][s[j] - 'a'];
				ans += fk.query(dfn[p]);
			}
			std::cout << ans << '\n';
		}
	}
}

signed AC{
   	

   	std::cin >> _;
	while(_ --)
    	solve();

    return 0;
}