UVA1401 Remember the Word

发布时间 2023-06-11 11:43:05作者: ForgotDream

思路

首先有一个比较朴素的 DP 就是记 \(f_i\)\(s\) 的从第 \(i\) 个字符开始到字符串结尾的划分方案数,记模板串的集合为 \(T\)\(s\) 从第 \(i\) 个字符开始到字符串结尾的子串为 \(s(i)\),那么不难写出方程:

\[f_i = \sum f_{i + \operatorname{len}(t)}[t \in T \land t 是 s(i) 的前缀] \]

初始状态为 \(f_{\operatorname{len}(s)} = 1\),表示总存在一种方法划分空串,答案即为 \(f_0\)

但是直接按这个方程转移的话需要枚举每一个 \(t\) 还要判断 \(t\) 是否为 \(s(i)\) 的前缀,那么复杂度肯定会炸,于是考虑优化。

跟前缀有关的数据结构,不难想到字典树。把每一个 \(t\) 都插入字典树中,然后对于每一个 \(s(i)\) 直接在这棵字典树里查询 \(s(i)\) 即可,每经过一个带有结束标记的节点,就意味着找到了一个 \(t\)\(s(i)\) 的前缀,然后按上边那个方程转移即可。

由于题目保证了 \(t\) 的长度不超过 \(100\),那么总的时间复杂度约为 \(O(100\operatorname{len}(s))\),已经足够好了。

代码

有一个小技巧就是在单词结束的节点内保存这个单词的长度,不过直接在查询时累加长度当然也可行。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int mod = 20071027;
template<size_t SIZ = int(1e5)>
struct Trie {
  std::vector<std::vector<int>> next;
  std::vector<int> f, end;
  Trie() { next.reserve(SIZ), next.emplace_back(26), end.push_back(0); }
  void insert(const std::string &s) {
    int p = 0;
    for (int i = 0; i < s.length(); i++) {
      int d = s[i] - 'a';
      if (!next[p][d]) { 
        next.emplace_back(26), end.push_back(0);
        next[p][d] = next.size() - 1;
      }
      p = next[p][d];
    }
    end[p] = s.length();
  }
  void find(const std::string &s, int st) {
    int p = 0;
    for (int i = st; i < s.length(); i++) {
      int d = s[i] - 'a';
      if (!next[p][d]) { return; }
      p = next[p][d];
      if (end[p]) { (f[st] += f[st + end[p]]) %= mod; }
    }
  }
  int solve(const std::string &s) {
    f.clear(), f.resize(s.length() + 1), f[s.length()] = 1;
    for (int i = s.length() - 1; ~i; i--) { find(s, i); }
    return f[0];
  }
};

void solve(int id, const std::string &s) {
  int n;
  std::cin >> n;
  Trie<> trie;
  for (int i = 0; i < n; i++) {
    std::string t;
    std::cin >> t;
    trie.insert(t);
  }
  std::cout << "Case " << id << ": " << trie.solve(s) << "\n";
}

signed main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  std::string s;
  int cnt = 0;
  while (std::cin >> s) { solve(++cnt, s); }

  return 0;
}