后缀数组学习笔记

发布时间 2023-10-02 07:45:59作者: 下蛋爷
  • 基数排序

利用桶的单调性,从低位到高位依次将整数放到对应数位的桶中。

  • 后缀数组

定义:对于字符串 \(s\),定义 \(sa[i]\) 表示 \(s\)\(n\) 个后缀按字典序排序后的第 \(i\) 个后缀在 \(s\) 中的下标\(rk[i]\) 表示从 \(s_i\) 开始的后缀在后缀数组中的下标。

  • 倍增求 \(sa\)

不妨设 \(sa_{w,i}\) 表示只取每个后缀的前 \(w\) 个字符排序后第 \(i\) 个后缀在 \(s\) 中的下标。

考虑通过 \(sa_w\) 求出 \(sa_{2w}\)

容易发现如果 \(rk_{w,i}<rk_{w,j}\) 那么 \(s[i,\dots,i+w-1]\) 一定小于 \(s[j,\dots,j+w-1]\)

所以可以把每个长度为 \(2w\) 的子串拆成两个长度为 \(w\) 的子串。

那么判断两个长度为 \(2w\) 的子串 \(s[i,\dots,i+2w-1]\)\(s[j,\dots,j+2w-1]\) 的字典序就只要判断 \(rk_{w,i},rk_{w,i+w},rk_{w,j},rk_{w,j+w}\) 之间的关系。

如果 \(i\)\(j\) 小,那么 \(rk_{w,i}<rk_{w,j}\) 或者 \(rk_{w,i}=rk_{w,j}\) 并且 \(rk_{w,i+w}<rk_{w,j+w}\)

所以可以直接双关键字排序。

时间复杂度:\(O(n\log^2 n)\)

代码
#include <bits/stdc++.h>

// #define int int64_t

const int kMaxN = 1e6 + 5;

int n;
int sa[kMaxN << 2], rk[kMaxN << 2], nrk[kMaxN << 2];
std::string s;

void dickdreamer() {
  std::cin >> s;
  n = s.size();
  s = " " + s;
  for (int i = 1; i <= n; ++i) {
    sa[i] = i;
    rk[i] = s[i];
  }
  for (int w = 1; w <= n; w <<= 1) {
    auto cmp = [&] (const int x, const int y) {
      return rk[x] == rk[y] ? rk[x + w] < rk[y + w] : rk[x] < rk[y];
    };
    std::sort(sa + 1, sa + 1 + n, cmp);
    int c = 0;
    for (int i = 1; i <= n; ++i)
      nrk[sa[i]] = (rk[sa[i]] == rk[sa[i - 1]] && rk[sa[i] + w] == rk[sa[i - 1] + w] ? c : ++c);
    for (int i = 1; i <= n; ++i)
      rk[i] = nrk[i];
  }
  for (int i = 1; i <= n; ++i)
    std::cout << sa[i] << ' ';
}

int32_t main() {
#ifdef ORZXKR
  freopen("in.txt", "r", stdin);
  freopen("out.txt", "w", stdout);
#endif
  std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
  int T = 1;
  // std::cin >> T;
  while (T--) dickdreamer();
  // std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
  return 0;
}
  • 优化

容易发现像上面那样求 \(rk\),每个 \(rk\) 的值域是 \([1,n]\),所以直接先对第二关键字排序,再对第一关键字排序即可。

时间复杂度:\(O(n\log n)\)

代码
void suffix_sort(std::string s, int *sa, int *rk) {
  static int cnt[kMaxN], ork[kMaxN << 1], id[kMaxN];
  memset(cnt, 0, sizeof(cnt));
  int n = static_cast<int>(s.size()) - 1;
  for (int i = 1; i <= n; ++i) {
    rk[i] = s[i];
    ++cnt[rk[i]];
  }
  for (int i = 1; i <= 128; ++i)
    cnt[i] += cnt[i - 1];
  for (int i = n; i; --i)
    sa[cnt[rk[i]]--] = i;
  for (int i = 1; i <= n; ++i)
    ork[i] = rk[i];
  int m = 0;
  for (int i = 1; i <= n; ++i) {
    if (ork[sa[i]] == ork[sa[i - 1]]) {
      rk[sa[i]] = m;
    } else {
      rk[sa[i]] = ++m;
    }
  }
  for (int w = 1; w < n; w <<= 1) {
    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= n; ++i)
      id[i] = sa[i];
    for (int i = 1; i <= n; ++i)
      ++cnt[rk[id[i] + w]];
    for (int i = 1; i <= m; ++i)
      cnt[i] += cnt[i - 1];
    for (int i = n; i; --i)
      sa[cnt[rk[id[i] + w]]--] = id[i];

    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= n; ++i)
      id[i] = sa[i];
    for (int i = 1; i <= n; ++i)
      ++cnt[rk[id[i]]];
    for (int i = 1; i <= m; ++i)
      cnt[i] += cnt[i - 1];
    for (int i = n; i; --i)
      sa[cnt[rk[id[i]]]--] = id[i];
    
    for (int i = 1; i <= n; ++i)
      ork[i] = rk[i];
    m = 0;
    for (int i = 1; i <= n; ++i) {
      if (ork[sa[i]] == ork[sa[i - 1]] && ork[sa[i] + w] == ork[sa[i - 1] + w]) {
        rk[sa[i]] = m;
      } else {
        rk[sa[i]] = ++m;
      }
    }
  }
}

但是这么做常数很大,并且实测还没直接 sort 快,因为做两次基数排序是很慢的。

注意到第一次排序相当于把 \(i+w>n\)\(i\) 放前面,并且把其他的 \(i\) 按照 \(rk[i+w]\) 的顺序排序。

又因为原来的 \(sa\) 数组就已经按照 \(rk\) 排好序了,所以直接从前往后扫 \(sa\) 数组,如果 \(sa[i]>w\) 就把 \(sa[i]-w\) 放到新数组中即可。

还有个小优化是如果当前的 \(rk\) 总共出现 \(n\) 次就说明已经排序完成,那么 break 就可以了。

这样做就比直接 sort 要快很多了。

代码
void getsa(std::string s, int *sa, int *rk) {
  static int cnt[kMaxN], ork[kMaxN << 1], id[kMaxN];
  memset(cnt, 0, sizeof(cnt));
  int n = static_cast<int>(s.size()) - 1, m = 0;
  for (int i = 1; i <= n; ++i) {
    rk[i] = s[i];
    ++cnt[rk[i]];
  }
  for (int i = 1; i <= 128; ++i)
    cnt[i] += cnt[i - 1];
  for (int i = n; i; --i)
    sa[cnt[rk[i]]--] = i;
  std::copy_n(rk + 1, n, ork + 1);
  for (int i = 1; i <= n; ++i) {
    if (ork[sa[i]] == ork[sa[i - 1]]) {
      rk[sa[i]] = m;
    } else {
      rk[sa[i]] = ++m;
    }
  }
  for (int w = 1; m < n; w <<= 1) {
    int p = 0;
    for (int i = n - w + 1; i <= n; ++i)
      id[++p] = i;
    for (int i = 1; i <= n; ++i)
      if (sa[i] > w)
        id[++p] = sa[i] - w;
    std::fill_n(cnt + 1, n, 0);
    for (int i = 1; i <= n; ++i)
      ++cnt[rk[id[i]]];
    for (int i = 1; i <= m; ++i)
      cnt[i] += cnt[i - 1];
    for (int i = n; i; --i)
      sa[cnt[rk[id[i]]]--] = id[i];
    
    m = 0;
    std::copy_n(rk + 1, n, ork + 1);
    for (int i = 1; i <= n; ++i) {
      if (ork[sa[i]] == ork[sa[i - 1]] && ork[sa[i] + w] == ork[sa[i - 1] + w]) {
        rk[sa[i]] = m;
      } else {
        rk[sa[i]] = ++m;
      }
    }
  }
}

\(height\) 数组

定义:\(sa[i-1]\)\(sa[i]\) 的最长公共前缀长度。

\(height\)

  • 引理:\(height[rk[i]]\geq height[rk[i-1]]-1\)
证明

\(height[rk[i]]\) 就是 \(s_i\)\(sa\)\(i\) 前面的后缀的 LCP。