回文自动机(PAM)的简单应用

发布时间 2023-10-31 20:20:32作者: xhy666

记录回文自动机的一些应用实例

题目主要来源

模板

​ 跑\(PAM\)就是构建两棵字典树,字典树上(奇偶)根到不同节点都对应了一个原串中本质不同的回文串,同时维护了每个回文串对应的最长回文后缀。

​ 这个模板定义节点\(0\)为偶根,节点\(1\)为奇根(有些板子可能反过来)

\(next[i][j]\):当前节点\(i\)加上某个字符\(j\)后对应的回文串

\(cnt[i]\):对于原串的每一个前缀,当前节点作为其最长回文后缀出现的次数(\(count()\)沿\(fail\)树累加以后即为该回文串在原串中的总出现次数)

\(fail[i]\):节点\(i\)对应的最长回文后缀

\(trans[i]\):节点\(i\)对应的长度不超过它本身长度一半的最长回文后缀(个别题有用,一般不需要维护)

\(len[i]\):当前回文串的长度,定义偶根的长度为\(0\),奇根的长度为\(-1\)

\(s[i]\):原串第\(i\)个字符

\(num[i]\):节点\(i\)在\(fail\)树上的深度,即节点\(i\)对应回文串的回文后缀个数(含自身)

\(tot\):节点总个数

\(last\):当前已插入前缀的最长回文后缀对应的节点,下一次插入从它开始跳\(fail\)

\(n\):已经维护的长度

// Palindromic tree (eertree) template. Node 0 is the even root (len 0), node 1 the odd root (len -1).
// MAXN (node capacity, also the size of s[]) and MAXK (alphabet size) must be defined by the program.
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN], trans[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links from x until x's palindrome can be wrapped by s[n] on both sides.
    // The odd root (len -1) always matches because it compares s[n] with itself.
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots. s[0] = -1 is a sentinel that never equals a mapped character.
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // Fold per-prefix counts up the fail tree. A node's fail target already exists when the node
    // is created (see add), so it has a smaller index and descending order visits children first.
    void count() {
        for(int i = tot - 1; i >= 0; --i) 
            cnt[fail[i]] += cnt[i];
    }
    // Append character c; c - 'a' must lie in [0, MAXK).
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if(!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Looked up before next[cur][c] is linked, so it can never return the node being created.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
            
            // trans: longest palindromic suffix whose length is at most len[now] / 2.
            if (len[now] <= 2) trans[now] = fail[now];
            else {
              int v = trans[cur];
              while ((len[v] + 2) * 2 > len[now] || s[n - 1 - len[v]] != s[n]) v = fail[v];
              trans[now] = next[v][c];	
            }
        }
        // last = longest palindromic suffix of the current prefix.
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

P5496 【模板】回文自动机(PAM)

​ 统计每个前缀有多少个后缀为回文串,即统计当前节点能跳多少次\(fail\)(不计根),也就是它在\(fail\)树上的深度\(num\)

// Node capacity (string length plus the two roots, with slack) and alphabet size 'a'..'z'.
constexpr int MAXN = 500000 + 10;
constexpr int MAXK = 26;

// Palindromic tree for P5496. Node 0 is the even root (len 0), node 1 the odd root (len -1).
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN], trans[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots; s[0] = -1 is a sentinel that never equals a mapped character.
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // Fold counts up the fail tree; fail targets always have smaller indices than their nodes.
    void count() {
        for(int i = tot - 1; i >= 0; --i) 
            cnt[fail[i]] += cnt[i];
    }
    // Append character c ('a'..'z').
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if(!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Computed before next[cur][c] is linked, so it cannot return the new node itself.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            // num = fail-tree depth = number of palindromic suffixes (used as the answer in work()).
            num[now] = num[fail[now]] + 1;
            
            // trans: longest palindromic suffix whose length is at most len[now] / 2.
            if (len[now] <= 2) trans[now] = fail[now];
            else {
              int v = trans[cur];
              while ((len[v] + 2) * 2 > len[now] || s[n - 1 - len[v]] != s[n]) v = fail[v];
              trans[now] = next[v][c];	
            }
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

// Input buffer; work() reads the string into it and rewrites characters in place while decoding.
char str[MAXN];

// P5496: for each prefix print the number of its palindromic suffixes. Each following character
// is shifted by the previous answer before it is inserted, so the string is decoded on the fly.
void work() {
    cin >> str;
    int len = strlen(str);

    pam.init();
    rep (i, 0, len - 1) {
        pam.add(str[i]);
        // Fail-tree depth of the longest palindromic suffix = number of palindromic suffixes.
        int k = pam.num[pam.last];
        // Only decode characters that exist; the original also rewrote str[len], the '\0' terminator.
        if (i + 1 < len) str[i + 1] = (str[i + 1] - 97 + k) % 26 + 97;
        cout << k << " \n"[i == len - 1];
    }
}

P1659 [国家集训队] 拉拉队排练

​ 统计长度为奇数的回文串信息,就从奇根开始\(dfs/bfs\)

// P1659 allows strings of length up to 1e6: add() writes s[n] with n == |s| and the tree holds up to
// |s| + 2 nodes, so a capacity of exactly 1e6 overflowed. Keep some slack above the bound.
const int MAXN = 1e6 + 10, MAXK = 26; 

// Palindromic tree for P1659 (no trans). Node 0 = even root (len 0), node 1 = odd root (len -1).
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots; s[0] = -1 is a sentinel that never equals a mapped character.
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // After this, cnt[i] is the total number of occurrences of node i's palindrome
    // (fail targets have smaller indices, so descending order processes children first).
    void count() {
        for(int i = tot - 1; i >= 0; --i) 
            cnt[fail[i]] += cnt[i];
    }
    // Append character c ('a'..'z').
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if(!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Computed before next[cur][c] is linked, so it cannot return the new node itself.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

// Binary exponentiation: a^b reduced modulo p (for non-negative a, b; p >= 1).
long long qpow(long long a, long long b, long long p) {
    long long result = 1 % p;  // 1 % p makes p == 1 yield 0 even when b == 0
    for (long long sq = a % p; b != 0; b >>= 1, sq = sq * sq % p) {
        if (b & 1) result = result * sq % p;
    }
    return result;
}

// Walk the trie below u and collect (length, occurrence count) for every real node.
// Started from the odd root, so every collected palindrome has odd length.
void dfs(int u, vector<PIL> &tmp) {
    if (pam.len[u] > 0) tmp.emplace_back(pam.len[u], pam.cnt[u]);
    for (int ch = 0; ch < 26; ++ch) {
        if (int child = pam.next[u][ch]) dfs(child, tmp);
    }
}

void work() {
    int n;
    LL k;
    string s;
    cin >> n >> k >> s;
    pam.init();
    rep (i, 0, s.length() - 1) {
        pam.add(s[i]);
    }
    pam.count();

    vector<PIL> tmp;
    dfs(1, tmp);
    
    sort(tmp.begin(), tmp.end(), [&](PIL &x, PIL &y) {
        return x.fr > y.fr;
    });

    int res = 1;
    LL tot = 0;
    for (auto &[x, y]: tmp) {
        LL d = min(k - tot, y);
        tot += d;
        assert(x % 2 == 1);
        res = 1ll * res * qpow(x, d, MOD) % MOD;
        if (tot == k) break;
    }
    
    if (tot < k) cout << "-1\n";
    else cout << res << "\n";
}

P4287 [SHOI2011] 双倍回文

​ 一个回文串如果是双倍回文,那么它在\(fail\)树上的祖先中一定有长度恰为它一半的回文串;反之,长度为\(4\)的倍数的回文串如果有长度恰为其一半的祖先,它就是双倍回文。所以把\(fail\)树建出来,跑一遍\(dfs\),维护根到当前节点路径上各偶数长度的出现次数就好了。也可以通过\(trans\)指针求。

constexpr int MAXN = 500000 + 10;
constexpr int MAXK = 26;

int res;                    // longest double palindrome found by dfs()
std::vector<int> g[MAXN];   // fail-tree child lists, filled by PalindromicTree::count()
int st[MAXN];               // st[L]: even palindromes of length L on the current dfs path

// Palindromic tree for P4287. Node 0 = even root (len 0), node 1 = odd root (len -1).
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots; s[0] = -1 is a sentinel that never equals a mapped character.
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // Accumulate counts and build the fail tree into the global g (edge fail[i] -> i).
    // For i = 1 (fail 0) and i = 0 (fail 1) this adds both 0->1 and 1->0; dfs() skips the
    // parent to avoid looping on that pair.
    void count() {
        for(int i = tot - 1; i >= 0; --i) {
            cnt[fail[i]] += cnt[i];
            g[fail[i]].push_back(i);
        }
    }
    // Append character c ('a'..'z').
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if(!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Computed before next[cur][c] is linked, so it cannot return the new node itself.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

// DFS over the fail tree. An even palindrome of length L whose ancestor path holds an even
// palindrome of length L / 2 is a double palindrome (L / 2 odd never matches since st only
// tracks even lengths).
void dfs(int u, int fu) {
    const int L = pam.len[u];
    const bool tracked = L >= 1 && L % 2 == 0;

    if (tracked) {
        if (st[L / 2] >= 1) res = max(res, L);
        ++st[L];
    }
    for (int v : g[u]) {
        if (v != fu) dfs(v, u);
    }
    if (tracked) --st[L];
}

// P4287: read n and the string, build the tree and its fail tree, then search for the
// longest double palindrome starting from the odd root.
void work() {
    int n;
    string s;
    cin >> n >> s;

    pam.init();
    for (int i = 0; i < n; ++i) pam.add(s[i]);
    pam.count();

    dfs(1, -1);
    cout << res << "\n";
}

P3649 [APIO2014] 回文串

​ 用回文自动机做就是大水题,\(count()\)的时候顺便统计答案即可

long long res;  // max over palindromes of length * occurrences, updated in PalindromicTree::count()

constexpr int MAXN = 1000000;
constexpr int MAXK = 26;

struct PalindromicTree {
		int next[MAXN][MAXK], cnt[MAXN], fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    void count() {
        for(int i = tot - 1; i >= 2; --i) {
            cnt[fail[i]] += cnt[i];
            res = max(res, 1ll * len[i] * cnt[i]);
        }
    }
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if(!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

// P3649: the answer is accumulated into the global res while counting occurrences.
void work() {
    string s;
    cin >> s;

    pam.init();
    for (char ch : s) pam.add(ch);
    pam.count();

    cout << res << "\n";
}

CF17E Palisection

​ 问原串中有多少对相交的回文子串。把每个回文子串看成一个区间,就是求有多少对相交的区间。从右往左遍历位置,维护右端点不小于当前位置、且左端点在当前位置左侧的区间个数\(m\);设当前位置上有\(t\)个区间的左端点,它们对答案的贡献就是\(t\)个区间两两配对(\(\binom{t}{2}\))加上\(t\)个区间与\(m\)个区间配对(\(t\cdot m\))。因此需要知道每个位置上有多少个区间右端点和左端点:右端点跑一遍原串统计即可(以某位置结尾的回文串个数就是此时最长回文后缀节点的\(num\));因为回文串对称,左端点跑一遍反串就行了。

const int MAXN = 2e6 + 10, MAXK = 26; 

// pre[i]: number of palindromes ending at position i (1-indexed), written by PalindromicTree::add.
// Indexed by positions up to the string length, so it is sized with this snippet's MAXN (chosen for
// 2e6-long strings) rather than the header constant N, which is 1e5 + 10 in the header shown here.
int pre[MAXN];

// Palindromic tree with edges stored as (char, child) lists instead of a dense table.
// Node 0 = even root (len 0), node 1 = odd root (len -1). No cnt[] is kept.
struct PalindromicTree {
    vector<PII> next[MAXN];  
    int fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with an empty edge list; returns its index.
    int newnode(int length) {
        next[tot].clear();
        len[tot] = length;
        fail[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset: drop edge lists left over from a previous build, then recreate the two roots.
    void init() {
        rep (i, 0, tot) next[i].clear();
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // Append character c ('a'..'z'). If st is non-zero, store in the global pre[n] the number of
    // palindromic suffixes of the current prefix (num of the longest one).
    void add(int c, int st) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        int id = -1;
        // rep casts its bound to int, so an empty list gives bound -1 and the loop is skipped.
        rep (i, 0, next[cur].size() - 1) {
            if (next[cur][i].fr == c) {
                id = i;
                break;
            }
        }
        if(id == -1) {
            int now = newnode(len[cur] + 2);
            // fail stays 0 (even root) when no matching edge exists; searched before the new
            // edge is pushed, so it cannot find the node being created.
            int fu = get_fail(fail[cur]);
            for (auto &u: next[fu]) {
                if (u.fr == c) {
                    fail[now] = u.se;
                    break;
                }
        		}
            next[cur].push_back({c, now});
            id = next[cur].size() - 1;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][id].se;
        if (st) pre[n] = num[last]; 
    }
} pam;

void work() {
    int n;
    string s;
    cin >> n >> s; 

    pam.init();
    for (auto &c: s) {
        pam.add(c, 1);
    }

    pam.init();
    LL m = 0, t = 0, res = 0;
    rrep (i, n - 1, 0) {
        m += pre[i + 1];
        pam.add(s[i], 0);
        t = pam.num[pam.last];
        m -= t;
        res = (res + m * t % MOD + t * (t - 1) / 2 % MOD) % MOD;
    }
    cout << res << "\n";

}

P5685 [JSOI2013] 快乐的 JYY

​ 统计两个字符串中的相同回文子串有多少对。从奇偶根开始跑dfs,每次转移相同边,那么两个节点对应的一定是相同的回文串,个数相乘即可。和icpc14西安G基本一样

// Node capacity and alphabet size ('a'..'z').
constexpr int MAXN = 500000 + 10;
constexpr int MAXK = 26;

struct PalindromicTree {
		vector<PII> next[MAXN];
    int fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN], cnt[MAXN];
    int tot, last, n;
    int newnode(int length) {
    		next[tot].clear();
        len[tot] = length;
        cnt[tot] = fail[tot] = num[tot] = 0;
        return tot++;
    }
    int get_fail(int x) {
        while(s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    void init() {
    		rep (i, 0, tot) next[i].clear();
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    void count() {
        for(int i = tot - 1; i >= 0; --i) 
            cnt[fail[i]] += cnt[i];
    }
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        int id = -1;
        rep (i, 0, next[cur].size() - 1) {
            if (next[cur][i].fr == c) {
            id = i;
            break;
            }
        }
        if(id == -1) {
            int now = newnode(len[cur] + 2);
            int fu = get_fail(fail[cur]);
            for (auto &u: next[fu]) {
                if (u.fr == c) {
                  fail[now] = u.se;
                  break;
                }
            }
            next[cur].push_back({c, now});
            id = next[cur].size() - 1;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][id].se;
        cnt[last]++; 
    }
} pam1, pam2;

long long res;  // sum over common palindromes of (occurrences in str1) * (occurrences in str2)

// Walk both tries in lockstep along equally-labelled edges: u in pam1 and v in pam2 always
// denote the same palindrome, so their occurrence counts multiply into the answer.
void dfs(int u, int v) {
    if (u > 1) res += 1ll * pam1.cnt[u] * pam2.cnt[v];  // nodes 0 and 1 are the roots
    for (const auto &[ch1, to1] : pam1.next[u]) {
        for (const auto &[ch2, to2] : pam2.next[v]) {
            if (ch1 == ch2) dfs(to1, to2);
        }
    }
}

void work() {
    string str1, str2;
    cin >> str1 >> str2;

    pam1.init(), pam2.init();
    for (auto &c: str1) pam1.add(c);
    for (auto &c: str2) pam2.add(c);
    pam1.count(), pam2.count();

    dfs(0, 0), dfs(1, 1);
    cout << res << "\n";
}

P5555 秩序魔咒

​ 求两个字符串中的最长公共回文子串的长度和个数,做法还是跟上题一样

// Node capacity per tree and alphabet size ('a'..'z').
constexpr int MAXN = 300000;
constexpr int MAXK = 26;

// Palindromic tree for P5555 (two instances). Node 0 = even root (len 0), node 1 = odd root (len -1).
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    // fail is reset too, as in every other copy of this template, so the odd root's fail link
    // no longer depends on the object being zero-initialised.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while (s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots; s[0] = -1 is a sentinel that never equals a mapped character.
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
    }
    // Turn per-prefix counts into total occurrence counts (fail targets have smaller indices).
    void count() {
        for (int i = tot - 1; i >= 0; --i)
            cnt[fail[i]] += cnt[i];
    }
    // Append character c ('a'..'z').
    void add(int c) {
        c -= 'a';
        s[++n] = c;
        int cur = get_fail(last);
        if (!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Computed before next[cur][c] is linked, so it cannot return the new node itself.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam_1, pam_2;


int res_len = 0;  // length of the longest common palindrome seen so far
int res_cnt = 0;  // number of distinct common palindromes with that length
void dfs(int u, int v) {
    if (u != 0 && u != 1) {
        if (pam_1.len[u] > res_len) {
            res_len = pam_1.len[u];
            res_cnt = 1;
        }
        else if (pam_1.len[u] == res_len) res_cnt++;
    }
    rep (i, 0, 25) {
        int x = pam_1.next[u][i], y = pam_2.next[v][i];
        if (x && y) dfs(x, y);
    }
}

void work() {
    int n, m;
    string str_1, str_2;
    cin >> n >> m >> str_1 >> str_2;

    pam_1.init(), pam_2.init();
    for (auto &c: str_1) pam_1.add(c);
    for (auto &c: str_2) pam_2.add(c);
    pam_1.count(), pam_2.count();

    dfs(0, 0);
    dfs(1, 1);

    cout << res_len << " " << res_cnt << "\n";
}

P4762 [CERC2014] Virus synthesis

\(dp[i]\)表示形成当前回文串的最小操作数,\(ans=\min(ans,Len - len[i]+dp[i])\),其中\(Len\)是原串的长度。

\(dp\)转移:

\(dp[i]=len[i]\)(只进行操作一)

对于长度为\(n\)(\(n\)为偶数)的回文串\(S\),记\(S_{half}=S[\frac{n}{2}+1,n]\),设\(S_{half}\)的最长回文后缀为\(T\)。可以证明对于操作二,\(S\)只需两种转移:

第一种:从\(S_{half}\)非本身的最长前缀转移(加一个字符并翻倍),即从回文树的父亲节点转移。

第二种:从\(T\)转移,先对\(T\)进行操作一,直到\(T=S_{half}\),再翻倍。即从回文树的\(trans\)节点转移。

​ 直观来讲,对于操作二,\(S\)可以从所有\(S_{half}\)的回文子串转移,为什么只有两种情况?

​ 考虑\(S_{half}\)的最后一个字符\(c\),假设去掉\(c\),所有情况都被第一种囊括了;而对于以\(c\)结尾的回文后缀,显然只需要考虑最长的那个。也就是说,只需考虑添加\(c\)以后增加的本质不同的后缀回文子串中最长的那个。

​ 假如\(T\)是奇串呢?那么如果存在一个比\(T\)更短的偶(后缀回文)串,那么他肯定是\(S_{half}\)某个前缀的最长回文后缀,已经考虑过了。

​ 所以对于所有的奇串,本质不承担任何转移操作二的作用,也不可能从操作二转移过来,\(dp[i]=len[i]\)

#include<bits/stdc++.h>
using namespace std;

// Shorthand macros shared by all snippets.
#define fr first
#define se second
#define et0 exit(0);
#define rep(i, a, b) for(int i = (int)(a); i <= (int)(b); i ++)
#define rrep(i, a, b) for(int i = (int)(a); i >= (int)(b); i --)
#define IO ios::sync_with_stdio(false),cin.tie(0);

using LL = long long;
using PII = std::pair<int, int>;
using PIP = std::pair<int, PII>;
using ULL = unsigned long long;

constexpr int INF = 0x3f3f3f3f;
constexpr int N = 100000 + 10;
constexpr int MOD = 998244353;
constexpr LL LLINF = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-7;
const double pi = std::acos(-1.0);

constexpr int MAXN = 100000 + 10;  // node capacity; same value as the header's N
constexpr int MAXK = 4;            // alphabet {A, T, C, G}, mapped to 0..3 through mp[]

int dp[MAXN];   // dp[node]: minimum operations to build that palindrome
int mp[MAXN];   // character -> transition index, filled in main()
int str_len;    // length of the current input string
int res;        // best total operation count for the current test

// Palindromic tree for P4762 over {A, T, C, G}, computing dp[] for each new node as it is created.
// Node 0 = even root (len 0), node 1 = odd root (len -1).
struct PalindromicTree {
    int next[MAXN][MAXK], cnt[MAXN], fail[MAXN], trans[MAXN];
    int len[MAXN], s[MAXN], num[MAXN];
    int tot, last, n;
    // Allocate a node of the given length with no outgoing edges; returns its index.
    // fail is reset as in the other copies of this template; init() runs once per test case.
    int newnode(int length) {
        memset(next[tot], 0, sizeof next[tot]);
        len[tot] = length;
        fail[tot] = cnt[tot] = num[tot] = 0;
        return tot++;
    }
    // Follow fail links until x's palindrome can be wrapped by s[n]; the odd root always matches.
    int get_fail(int x) {
        while (s[n - len[x] - 1] != s[n]) x = fail[x];
        return x;
    }
    // Reset to the two roots. dp[0] = 1 makes the parent transition cost 2 for a length-2
    // palindrome "cc" (type c, then double).
    void init() {
        last = tot = n = 0;
        newnode(0);
        newnode(-1);
        fail[0] = 1;
        s[n] = -1;
        dp[0] = 1;
    }
    // Turn per-prefix counts into total occurrence counts (fail targets have smaller indices).
    void count() {
        for (int i = tot - 1; i >= 0; --i)
            cnt[fail[i]] += cnt[i];
    }
    // Append character c (one of 'A', 'T', 'C', 'G', mapped via mp[]).
    void add(int c) {
        c = mp[c];
        s[++n] = c;
        int cur = get_fail(last);
        if (!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            // Computed before next[cur][c] is linked, so it cannot return the new node itself.
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;

            // trans: longest palindromic suffix whose length is at most len[now] / 2.
            if (len[now] <= 2) trans[now] = fail[now];
            else {
                int v = trans[cur];
                while ((len[v] + 2) * 2 > len[now] || s[n - 1 - len[v]] != s[n]) v = fail[v];
                trans[now] = next[v][c];
            }

            // Baseline: type every character.
            dp[now] = len[now];
            if (len[now] % 2 == 0) {
                // Parent: extend the half of cur by c, then double.
                dp[now] = min(dp[now], dp[cur] + 1);
                // trans: build trans, type up to the half, then double.
                dp[now] = min(dp[now], len[now] / 2 - len[trans[now]] + 1 + dp[trans[now]]);
            }
            // Build this palindrome, then type the remaining characters of the whole string.
            res = min(res, str_len - len[now] + dp[now]);
        }
        last = next[cur][c];
        cnt[last]++;
    }
} pam;

// P4762: one test case. res starts at typing every character and is lowered inside pam.add().
void work() {
    string str;
    cin >> str;

    str_len = static_cast<int>(str.size());
    res = str_len;

    pam.init();
    for (char ch : str) pam.add(ch);

    cout << res << "\n";
}
/*

*/

signed main() {
    IO
    mp['A'] = 0; mp['T'] = 1; mp['C'] = 2; mp['G'] = 3;
    
    int test = 1;
    cin >> test;
    while (test--) work();
    
    return 0;
}