「学习笔记」后缀数组

发布时间 2023-07-11 22:11:32作者: yi_fan0305

感谢 LB 学长的博文!

前置知识

后缀是指从某个位置 \(i\) 开始到整个串末尾结束的一个特殊子串,也就是 \(S[i \dots|S|-1]\)

计数排序 - OI Wiki (oi-wiki.org)

基数排序 - OI Wiki (oi-wiki.org)

变量

后缀数组最主要的两个数组是 sark

sa 表示将所有后缀排序后第 \(i\) 小的后缀的编号,即编号数组。

rk 表示后缀 \(i\) 的排名,即排名数组。

这两个数组满足一个重要性质: sa[rk[i]] = rk[sa[i]] = i

示例:

这个图很好理解。

做法

暴力的 \(O_{n^2 \log n}\) 做法

将所有的后缀数组都 sort 一遍,sort 复杂度为 \(O_{n \log n}\),字符串比较复杂度为 \(O_{n}\),总的复杂度 \(O_{n^2 \log n}\)

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 1e6 + 5;

int n;
char s[N];
string h[N];
pair<string, int> ans[N];

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    for (int i = 1; i <= n; ++ i) {
        for (int j = i; j <= n; ++ j) {
            h[i] += s[j];
        }
        ans[i] = {h[i], i};
    }
    sort(ans + 1, ans + n + 1);
    for (int i = 1; i <= n; ++ i) {
        cout << ans[i].second << ' ';
    }
    return 0;
}

倍增优化的 \(O_{n \log^2 n}\) 做法

先对长度为 \(1\) 的所有子串,即每个字符排序,得到排序后的 sa1rk1 数组。

之后倍增:

  1. 用两个长度为 \(1\) 的子串的排名,即 rk1[i]rk1[i + 1],作为排序的第一关键词和第二关键词,这样就可以对每个长度为 \(2\) 的子串进行排序,得到 sa2rk2

  2. 之后用两个长度为 \(2\) 的子串的排名,即 rk2[i]rk2[i + 2],来作为排序的第一关键词和第二关键词。(为什么是 \(i + 2\) 呢,因为 rk2[i]rk2[i + 1] 重复了 \(S_{i + 1}\))这样就可以对每个长度为 \(4\) 的子串进行排序,得到 sa4rk4

  3. 重复上面的操作,用两个长度为 \(\dfrac{w}{2}\) 的子串的排名,即 rk[i]rk[i + (w / 2)],来作为排序的第一关键词和第二关键词,直到 \(w \ge n\),最终得到的 sa 数组就是我们的答案数组。

示意图:

倍增的复杂度为 \(O_{\log n}\)sort 复杂度为 \(O_{n \log n}\),总的复杂度 \(O_{n \log ^ 2 n}\)

排序优化的 \(O_{n \log n}\) 的做法

发现后缀数组值域即为 \(n\),又是多关键字排序,考虑基数排序。
上面已经给出一个用于比较的式子:(A[i] < A[j] or (A[i] = A[j] and B[i] < B[j])),倍增过程中 A[i], B[i] 大小关系已知,先将 B[i] 作为第二关键字排序,再将 A[i] 作为第一关键字排序,两次计数排序实现即可。
单次计数排序复杂度 \(O_{n+w}\)\(w\) 为值域,显然最大与 \(n\) 同阶),总时间复杂度变为 \(O_{n \log n}\)

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 1e6 + 5;

int n, m;
int sa[N], oldsa[N], rk[N << 1], oldrk[N << 1], cnt[N];
// rk 第 i 个后缀的排名,sa 第 i 小的后缀的编号
char s[N];

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    m = 127;

    /*--------------------------------*/

    // 计数排序

    for (int i = 1; i <= n; ++ i) {
        ++ cnt[rk[i] = s[i]];
    }
    for (int i = 1; i <= m; ++ i) {
        cnt[i] += cnt[i - 1];
    }
    for (int i = n; i >= 1; -- i) {
        sa[cnt[rk[i]] --] = i;
    }
    memcpy(oldrk + 1, rk + 1, n * sizeof(int));

    /*--------------------------------*/

    // 判重

    for (int cur = 0, i = 1; i <= n; ++ i) {
        if (oldrk[sa[i]] == oldrk[sa[i - 1]]) {
            rk[sa[i]] = cur;
        }
        else {
            rk[sa[i]] = ++ cur;
        }
    }

    /*--------------------------------*/

    for (int w = 1; w < n; w <<= 1, m = n) {

        // 先按照第二关键词计数排序

        memset(cnt, 0, sizeof cnt);
        memcpy(oldsa + 1, sa + 1, n * sizeof(int));
        for (int i = 1; i <= n; ++ i) {
            ++ cnt[rk[oldsa[i] + w]];
        }
        for (int i = 1; i <= m; ++ i) {
            cnt[i] += cnt[i - 1];
        }
        for (int i = n; i >= 1; -- i) {
            sa[cnt[rk[oldsa[i] + w]] --] = oldsa[i];
        }

        /*--------------------------------*/

        // 再按照第一关键词计数排序

        memset(cnt, 0, sizeof cnt);
        memcpy(oldsa + 1, sa + 1, n * sizeof(int));
        for (int i = 1; i <= n; ++ i) {
            ++ cnt[rk[oldsa[i]]];
        }
        for (int i = 1; i <= m; ++ i) {
            cnt[i] += cnt[i - 1];
        }
        for (int i = n; i >= 1; -- i) {
            sa[cnt[rk[oldsa[i]]] --] = oldsa[i];
        }

        /*--------------------------------*/

        // 更新数组

        memcpy(oldrk + 1, rk + 1, n * sizeof(int));
        for (int cur = 0, i = 1; i <= n; ++ i) {
            if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) {
                rk[sa[i]] = cur;
            }
            else {
                rk[sa[i]] = ++ cur;
            }
        }
    }
    for (int i = 1; i <= n; ++ i) {
        printf("%d ", sa[i]);
    }
    return 0;
}

各种常数优化

  1. 考虑我们按照第二关键词排序的实质,就是将超出 \(n\) 范围的空字符串放在 sa 的最前面,在本次排序中,\(S[sa_i \dots sa_i+2^k−1]\) 是长度为 \(2^k\) 的子串 \(S[sai−2^k−1 \dots sai+2^k−1]\) 的后半截,sa[i] 的排名将作为排序的关键字。
    \(S[sa_i,sa_i+2^k−1]\) 的排名为 \(i\),则第一次计排\(S[sa_i−2^k−1 \dots sa_i+2^k−1]\) 的排名必为 \(i\),考虑直接赋值。
for (p = 0, i = n; i > n - w; -- i) {
    oldsa[++ p] = i;
}
for (int i = 1; i <= n; ++ i) {
    if (sa[i] > w) { // 保证 sa[i] 是后半截的编号
        oldsa[++ p] = sa[i] - w; // sa[i] 一定是后半截的编号,而我们要存的是前半截的开始编号
    }
}
  1. 减小值域,每次对 rk 进行更新之后,我们都计算了一个 \(p\),这个 \(p\) 即是 rk 的值域,将值域改成它即可。

  2. rk[id[i]] 存下来,减少不连续内存访问。

  3. 用函数 cmp 来计算是否重复。

  4. 若排名都不相同可直接生成后缀数组。

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 1e6 + 5;

int n, m;
int sa[N], oldsa[N], rk[N << 1], oldrk[N << 1], cnt[N], key[N];
// rk 第 i 个后缀的排名,sa 第 i 小的后缀的编号
char s[N];

inline bool cmp(int x, int y, int w) {
    return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
}

int main() {
    int i, p = 0;
    scanf("%s", s + 1);
    n = strlen(s + 1);
    m = 127;
    for (int i = 1; i <= n; ++ i) {
        ++ cnt[rk[i] = s[i]];
    }
    for (int i = 1; i <= m; ++ i) {
        cnt[i] += cnt[i - 1];
    }
    for (int i = n; i >= 1; -- i) {
        sa[cnt[rk[i]] --] = i;
    }
    for (int w = 1; ; w <<= 1, m = p) {
        for (p = 0, i = n; i > n - w; -- i) {
            oldsa[++ p] = i;
        }
        for (int i = 1; i <= n; ++ i) {
            if (sa[i] > w) { // 保证 sa[i] 是后半截的编号
                oldsa[++ p] = sa[i] - w; // sa[i] 一定是后半截的编号,而我们要存的是前半截的开始编号
            }
        }
        memset(cnt, 0, sizeof cnt);
        for (i = 1; i <= n; ++ i) {
            ++ cnt[key[i] = rk[oldsa[i]]];
        }
        for (i = 1; i <= m; ++ i) {
            cnt[i] += cnt[i - 1];
        }
        for (i = n; i >= 1; -- i) {
            sa[cnt[key[i]] --] = oldsa[i];
        }
        memcpy(oldrk + 1, rk + 1, n * sizeof(int));
        for (p = 0, i = 1; i <= n; ++ i) {
            rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++ p;
        }
        if (p == n) {
            break ;
        }
    }
    for (int i = 1; i <= n; ++ i) {
        printf("%d ", sa[i]);
    }
    return 0;
}

参考资料

后缀数组简介 - OI Wiki (oi-wiki.org)

「笔记」后缀数组 - Luckyblock - 博客园 (cnblogs.com)