Educational Codeforces Round 97 (Rated for Div 2) G. Death DBMS

发布时间 2023-09-23 10:55:26作者: muxingchengfeng

Problem - G - Codeforces

题意

给定n个字符串,每个字符串有一个值val,n次询问,每次给一个字符串,询问给定n个字符串中是询问字符串子串的值的最大值

分析

多模式匹配,从中找到给定串的子串,想到建立ac自动机,对于给定字符串,在自动机上面匹配时,沿fail指针向上跳并求最大值即可,由于n个字符串总长度M=3E5,所有长度不同的字符串最多有$\sqrt{M}$,所以fai树中的链长最多不超过$\sqrt{M}$,由于匹配字符串长度不超过M=3E5,所以总跳数小于M$\sqrt{M}$,对于每个结点的可能存在多个给定串,使用muityset维护val,总复杂度O(nlogn + M$\sqrt{M}$)

#include <bits/stdc++.h>

using i64 = long long;

struct ACA
{
    std::vector<std::array<int, 26>> tr;
    std::vector<int> ne;
    std::vector<std::set<int>> id;
    std::vector<int> invid;
    std::vector<std::multiset<int>> val;
    std::vector<int> cnt;
    std::vector<int> q;
    std::vector<std::vector<int>> adj;
    int idx;

    ACA(int n) : tr(n + 1), ne(n + 1), invid(n + 1), id(n + 1), cnt(n + 1), q(n + 1), val(n + 1)
    {
        adj.assign(n + 1, std::vector<int>());
        idx = 0;
    }

    void insert(std::string s, int x)
    {
        int p = 0;
        int sz = s.size();
        for (int i = 0; i < sz; i++)
        {
            int t = s[i] - 'a';
            if (!tr[p][t])
                tr[p][t] = ++idx;
            p = tr[p][t];
        }
        id[p].insert(x); // p是x结尾的数据
        val[p].insert(0);
        invid[x] = p;
    }

    void build()
    {
        int hh = 0, tt = -1;
        for (int i = 0; i < 26; i++)
            if (tr[0][i])
                q[++tt] = tr[0][i];
        while (hh <= tt)
        {
            int p = q[hh++];
            adj[ne[p]].push_back(p);
            for (int i = 0; i < 26; i++)
                if (!tr[p][i])
                    tr[p][i] = tr[ne[p]][i];
                else
                {
                    ne[tr[p][i]] = tr[ne[p]][i];
                    q[++tt] = tr[p][i];
                }
        }
    }

    void query(std::string str)
    {
        int sz = str.size();
        for (int i = 0, now = 0; i < sz; i++)
        {
            int t = str[i] - 'a';
            now = tr[now][t];
            int p = now;
            cnt[p]++; // 对应的节点加一,在最后bfs是保证前节点全部加一
        }
    }

    void dfs(int u)
    {
        for (auto v : adj[u])
        {
            dfs(v);
            cnt[u] += cnt[v];
        }
    }
};

constexpr int N = 3E5;

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    ACA aca(N);

    int n, q;
    std::cin >> n >> q;
    std::vector<int> a(n + 1, 0);
    for (int i = 1; i <= n; i++)
    {
        std::string s;
        std::cin >> s;
        aca.insert(s, i);
    }

    aca.build();

    std::vector<int> up(aca.idx + 1);

    auto dfs1 = [&](auto self, int u, int last) -> void
    {
        up[u] = last;
        if (aca.id[u].size())
            last = u;
        for (auto v : aca.adj[u])
            self(self, v, last);
    };

    dfs1(dfs1, 0, 0);

    while (q--)
    {
        int op;
        std::cin >> op;
        if (op == 1)
        {
            int x, y;
            std::cin >> x >> y;
            int id = aca.invid[x];
            aca.val[id].erase(aca.val[id].find(a[x]));
            aca.val[id].insert(y);
            a[x] = y;
        }
        else
        {
            std::string str;
            std::cin >> str;
            int ans = -1;
            for (int i = 0, now = 0; i < str.size(); i++)
            {
                int t = str[i] - 'a';
                now = aca.tr[now][t];
                int p = now;
                while (p)
                {
                    if (aca.val[p].size())
                        ans = std::max(ans, *(--aca.val[p].end()));
                    p = up[p];
                }
            }
            std::cout << ans << "\n";
        }
    }
    return 0;
}