题解 P9196【[JOI Open 2016] 销售基因链】

发布时间 2023-06-13 13:07:47作者: rui_er

套路题,来讲个套路解法。

如果没有后缀的要求,答案就是 trie 树的子树内字符串数量。现在加上了后缀,尝试继续使用 trie 树解决问题。

我们建立两棵 trie 树 \(T_1,T_2\),其中 \(T_1\) 是正常的 trie 树,\(T_2\) 是每个字符串翻转后的 trie 树。这样的话,包含给定后缀的字符串在 \(T_2\) 上构成一棵子树。例如,样例一中的两棵 trie 树如下:

考虑一个询问意味着什么,其实就是查询同时在 \(T_1\) 一棵子树和 \(T_2\) 一棵子树内的字符串数。例如 \((\texttt{AU},\texttt{C})\) 就是查询同时在 \(T_1\) 的节点 \(2\) 的子树和 \(T_2\) 的节点 \(1\) 的子树内的字符串数,只有 \(\texttt{AUGC}\) 符合要求,因此答案为 \(1\)

把每个字符串抽象为坐标 \((x,y)\) 的点,其中 \(x,y\) 分别是 \(T_1,T_2\) 上字符串终止节点的 dfs 序,那么一次询问就是查询一个矩形内部点的数量,是二维数点问题,扫描线维护即可。

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(int x=(y);x<=(z);x++)
#define per(x,y,z) for(int x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
using namespace std;
typedef long long ll;

mt19937 rnd(std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
int randint(int L, int R) {
    uniform_int_distribution<int> dist(L, R);
    return dist(rnd);
}

template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}

const int N = 2e6+5;

int n, m, ans[N];
string s, t;
vector<tuple<int, int>> dots;
vector<tuple<int, int, int, int>> queries;

int f(char c) {
    if(c == 'A') return 0;
    if(c == 'G') return 1;
    if(c == 'U') return 2;
    if(c == 'C') return 3;
    return -1;
}

struct Trie {
    int son[N][4], tot, sz[N], dfn[N], tms;
    int insert(string s) {
        int u = 0;
        for(char c : s) {
            int k = f(c);
            if(!son[u][k]) son[u][k] = ++tot;
            u = son[u][k];
        }
        return u;
    }
    int find(string s) {
        int u = 0;
        for(char c : s) {
            int k = f(c);
            if(!son[u][k]) return -1;
            u = son[u][k];
        }
        return u;
    }
    void dfs(int u) {
        dfn[u] = tms++;
        sz[u] = 1;
        rep(i, 0, 3) {
            if(son[u][i]) {
                dfs(son[u][i]);
                sz[u] += sz[son[u][i]];
            }
        }
    }
}pre, suf;

struct BIT {
    int c[N];
    int lowbit(int x) {return x & (-x);}
    void add(int x, int k) {++x; for(; x < N; x += lowbit(x)) c[x] += k;}
    int ask(int x) {++x; int k = 0; for(; x; x -= lowbit(x)) k += c[x]; return k;}
}bit;

int main() {
    // freopen("debug.log", "w", stderr);
    scanf("%d%d", &n, &m);
    rep(i, 1, n) {
        cin >> s;
        int x = pre.insert(s);
        reverse(s.begin(), s.end());
        int y = suf.insert(s);
        dots.emplace_back(x, y);
    }
    pre.tms = suf.tms = 0;
    pre.dfs(0);
    suf.dfs(0);
    for(auto& [x, y] : dots) {
        x = pre.dfn[x];
        y = suf.dfn[y];
    }
    rep(i, 1, m) {
        cin >> s >> t;
        reverse(t.begin(), t.end());
        int u = pre.find(s), v = suf.find(t);
        if(u == -1 || v == -1) ans[i] = 0;
        else {
            // debug("# %d : %d %d %d %d\n", i, pre.dfn[u], suf.dfn[v], pre.dfn[u]+pre.sz[u]-1, suf.dfn[v]+suf.sz[v]-1);
            queries.emplace_back(pre.dfn[u]-1, suf.dfn[v]-1, i, +1);
            queries.emplace_back(pre.dfn[u]-1, suf.dfn[v]+suf.sz[v]-1, i, -1);
            queries.emplace_back(pre.dfn[u]+pre.sz[u]-1, suf.dfn[v]-1, i, -1);
            queries.emplace_back(pre.dfn[u]+pre.sz[u]-1, suf.dfn[v]+suf.sz[v]-1, i, +1);
        }
    }
    sort(dots.begin(), dots.end());
    sort(queries.begin(), queries.end());
    // for(auto [x, y, id, mul] : queries) debug("* %d %d %d %d\n", x, y, id, mul);
    int ds = (int)dots.size(), qs = (int)queries.size(), did = 0, qid = 0;
    // debug("= %d %d\n", pre.tot, suf.tot);
    rep(i, 0, pre.tot) {
        // debug("? %d\n", i);
        while(did < ds) {
            auto [x, y] = dots[did];
            if(x != i) break;
            // debug("+ %d/%d : %d %d\n", did, ds, x, y);
            bit.add(y, 1);
            ++did;
        }
        while(qid < qs) {
            auto [x, y, id, mul] = queries[qid];
            if(x != i) break;
            // debug("! %d/%d : %d %d %d %d -> %d\n", qid, qs, x, y, id, mul, bit.ask(y));
            ans[id] += mul * bit.ask(y);
            ++qid;
        }
    }
    rep(i, 1, m) printf("%d\n", ans[i]);
    return 0;
}