CF1806E

发布时间 2023-04-24 19:44:06作者: untitled0

题面

看起来是个 DS 题,事实上是个乱搞题,做法挺多的。由于它给的这个结构看起来就不好优化,所以考虑随机化。

由于两个点到达 LCA 后剩下的贡献就是 LCA 到根的每个点权值的平方,这部分可以 \(O(n)\) 预处理,所以只需要考虑两个点之间的路径所产生的贡献。

在树上随机撒 \(\sqrt n\) 个点,称这些点为关键点,那么两个关键点之间的期望距离应 \(\sqrt n\)。预处理任意两个关键点之间的贡献,时间复杂度 \(O(n\sqrt n)\)

对于一个询问 \(x,y\),从这两个点分别暴力向上跳,直到碰到关键点,然后用关键点间的贡献加上跳的时候计算的贡献即可。由于一个点到最近的关键点的期望距离为 \(\sqrt n\),所以这部分时间复杂度 \(O(q\sqrt n)\),总时间复杂度 \(O((n+q)\sqrt n)\)。实测常数略大,最大点用时约 2.6s。

代码:

#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) ((x > y) && (x = y))
#define gmax(x, y) ((x < y) && (x = y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
constexpr int N = 1e5 + 5, B = 350;
int n, q, a[N], s[N], f[N], dep[N], id[N], st[B], dis[B][B];
vector<int> to[N];
mt19937 gen(chrono::high_resolution_clock::now().time_since_epoch().count());
void dfs1(int p) {
    dep[p] = dep[f[p]] + 1;
    s[p] = a[p] * a[p] + s[f[p]];
    for(auto i : to[p]) dfs1(i);
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> q;
    rep(i, 1, n) cin >> a[i];
    rep(i, 2, n) {
        cin >> f[i];
        to[f[i]].push_back(i);
    }
    dfs1(1);
    uniform_int_distribution<> rnd(1, n);
    int num = sqrt(n);
    rep(i, 1, num) {
        int t = rnd(gen);
        st[i] = t;
        id[t] = i;
    }
    rep(i, 1, num) {
        dis[i][i] = s[i];
        rep(j, i + 1, num) {
            int u = st[i], v = st[j], res = 0;
            if(dep[u] < dep[v]) swap(u, v);
            while(dep[u] > dep[v]) u = f[u];
            while(u != v) {
                res += a[u] * a[v];
                u = f[u], v = f[v];
            }
            dis[i][j] = dis[j][i] = res + s[u];
        }
    }
    while(q--) {
        int u, v; cin >> u >> v;
        int fu = 0, fv = 0, res = 0;
        while((!fu || !fv) && u != v) {
            if(id[u]) fu = id[u];
            if(id[v]) fv = id[v];
            if(!fu || !fv) res += a[u] * a[v];
            u = f[u], v = f[v];
        }
        if(fu && fv) {
            cout << res + dis[fu][fv] << endl;
        }
        else cout << res + s[u] << endl;
    }
    return 0;
}