「解题报告」CF809E Surprise me!

发布时间 2023-06-05 09:23:41作者: APJifengc

好像是典题。

简单莫反一下。注意 \(a\) 是 \(1 \sim n\) 的排列,下面记 \(b_{a_i} = i\),即 \(b_x\) 表示权值为 \(x\) 的点。

\[\begin{aligned} & \sum_{i=1}^n \sum_{j=1}^n \varphi(a_i \cdot a_j) \operatorname{dis}(i, j)\\ =& \sum_{i=1}^n \sum_{j=1}^n \frac{\varphi(a_i) \varphi(a_j) \gcd(a_i, a_j)}{\varphi(\gcd(a_i, a_j))} \operatorname{dis}(i, j)\\ =& \sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^n \sum_{j=1}^n \varphi(a_i) \varphi(a_j) \operatorname{dis}(i, j) [\gcd(a_i, a_j) = d]\\ =& \sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^n \sum_{j=1}^n \varphi(i) \varphi(j) \operatorname{dis}(b_i, b_j) [\gcd(i, j) = d] & (b_{a_i} := i)\\ =& \sum_{d=1}^n\frac{d}{\varphi(d)} \sum_{i=1}^{n} \sum_{j=1}^{n} \varphi(i) \varphi(j) \operatorname{dis}(b_{i}, b_{j}) \sum_{kd | \gcd(i, j)} \mu(k)\\ =& \sum_{T=1}^n \sum_{d | T} \frac{d}{\varphi(d)} \mu\left(\frac{T}{d}\right) \sum_{i=1}^{\lfloor\frac{n}{T}\rfloor}\sum_{j=1}^{\lfloor\frac{n}{T}\rfloor} \varphi(iT) \varphi(jT) \operatorname{dis}(b_{iT}, b_{jT})\\ \end{aligned} \]

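前面那个东西可以直接暴力算出来,考虑后面的东西。后面的东西实际上是在求一个集合中的点两两距离乘两个端点的点权之和,而容易发现集合总大小是 \(O(n \log n)\) 的,所以可以直接对每个 \(T\) 建出虚树,在虚树上统计答案即可。统计答案很简单,直接枚举每一条边,看有多少路径经过这条边即可:若边长为 \(w\),边下方子树中的点权和为 \(s\),集合点权总和为 \(S\),则这条边对无序点对的贡献为 \(w \cdot s \cdot (S - s)\),有序点对再乘 \(2\)。最后把总和除以 \(n(n-1)\) 即为所求期望。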

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;      // capacity for vertex indices / values (n plus slack)
const int P = 1000000007;     // prime modulus for all arithmetic
int n;                        // number of vertices (also the value range 1..n)
int a[MAXN];                  // a[x]: value written on vertex x
int b[MAXN];                  // b[v]: the vertex x with a[x] == v (filled in main)
vector<int> e[MAXN];          // adjacency lists of the input tree
int phi[MAXN];                // Euler's totient, filled by sieve()
int mu[MAXN];                 // Mobius function, filled by sieve()
int dphi[MAXN];               // d / phi(d) mod P, filled by sieve()
int pri[MAXN];                // primes found by sieve(), 1-based
int pcnt;                     // number of entries used in pri[]
bool vis[MAXN];               // sieve flag: true once an index is known composite
int g[MAXN];                  // g[T] = sum_{d | T} dphi[d] * mu[T / d] mod P
// Modular exponentiation by repeated squaring: returns a^b mod P (1 when b == 0).
// Callers pass 0 <= a < P; with b = P - 2 this yields a modular inverse (Fermat).
int qpow(int a, int b) {
    long long result = 1;
    long long base = a;
    for (; b > 0; b >>= 1, base = base * base % P) {
        if (b & 1) result = result * base % P;
    }
    return static_cast<int>(result);
}
// Linear (Euler) sieve over [1, n]. It fills:
//   phi[x]  - Euler's totient
//   mu[x]   - Mobius function (non-square-free entries are never written, so they keep
//             the zero value of the global array)
//   dphi[x] - x / phi(x) mod P, i.e. the product of p / (p - 1) over the primes p | x
//   pri[1..pcnt] - the primes up to n
// Afterwards it builds the Dirichlet convolution g[T] = sum_{d | T} dphi[d] * mu[T / d] mod P.
void sieve() {
    phi[1] = mu[1] = dphi[1] = 1;
    for (int x = 2; x <= n; ++x) {
        if (!vis[x]) {
            pri[++pcnt] = x;
            phi[x] = x - 1;
            mu[x] = -1;
            dphi[x] = 1ll * x * qpow(x - 1, P - 2) % P;
        }
        for (int k = 1; k <= pcnt; ++k) {
            const int q = pri[k];
            if (x * q > n) break;
            const int y = x * q;
            vis[y] = 1;
            if (x % q == 0) {
                // q already divides x: the prime support is unchanged, so x/phi(x) is
                // unchanged and phi picks up a full factor q; mu[y] stays 0.
                dphi[y] = dphi[x];
                phi[y] = phi[x] * q;
                break;
            }
            // q is a new prime factor: every table is multiplicative here.
            dphi[y] = 1ll * dphi[x] * dphi[q] % P;
            phi[y] = phi[x] * (q - 1);
            mu[y] = -mu[x];
        }
    }
    for (int d = 1; d <= n; ++d) {
        // T walks the multiples d, 2d, 3d, ...; k is the cofactor T / d.
        for (int T = d, k = 1; T <= n; T += d, ++k) {
            g[T] = (g[T] + 1ll * dphi[d] * mu[k] + P) % P;
        }
    }
}
// Heavy-light decomposition state, filled by dfs1 / dfs2.
int dfn[MAXN];   // DFS entry index (heavy child visited first)
int dcnt;        // last dfn value handed out
int fa[MAXN];    // parent vertex (0 for the root)
int dep[MAXN];   // depth, root = 1 because dep[0] stays 0
int top[MAXN];   // head of the heavy chain containing the vertex
int siz[MAXN];   // subtree size
int son[MAXN];   // heavy child, 0 when the vertex is a leaf
// First heavy-light pass over the subtree of u (entered from pre). It records the parent,
// the depth (dep[pre] + 1) and the subtree size. It also picks son[u]: the first child
// whose subtree is strictly the largest. son[u] stays 0 for a leaf, relying on siz[0] == 0.
// This is recursive, so the recursion depth equals the tree height.
void dfs1(int u, int pre) {
    fa[u] = pre;
    dep[u] = dep[pre] + 1;
    siz[u] = 1;
    for (size_t k = 0; k < e[u].size(); ++k) {
        const int v = e[u][k];
        if (v == pre) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v]) son[u] = v;
    }
}
// Second heavy-light pass. It assigns dfn[] in DFS order with the heavy child first, so each
// heavy chain occupies a contiguous dfn range. It also assigns top[]: chain head t for u and
// its heavy child; every light child starts a new chain headed by itself.
void dfs2(int u, int pre, int t) {
    dfn[u] = ++dcnt;
    top[u] = t;
    const int heavy = son[u];
    if (heavy != 0) dfs2(heavy, u, t);
    for (const int v : e[u]) {
        if (v == pre || v == heavy) continue;
        dfs2(v, u, v);
    }
}
// Lowest common ancestor using the heavy-light chains. While u and v are on different
// chains, jump the one whose chain head is deeper to the parent of that head. Once they
// share a chain, the shallower vertex is the LCA (u on ties).
int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) {
            u = fa[top[u]];
        } else {
            v = fa[top[v]];
        }
    }
    return dep[u] <= dep[v] ? u : v;
}
// Number of edges on the tree path between u and v.
int dis(int u, int v) {
    const int w = lca(u, v);
    return (dep[u] - dep[w]) + (dep[v] - dep[w]);
}
int val[MAXN];                     // val[x] = phi(a[x]): weight of tree vertex x (set in main)
int m;                             // number of vertices currently stored in p[1..m]
int p[MAXN << 1];                  // vertex list for the current virtual tree, with room for LCAs
vector<pair<int, int>> t[MAXN];    // virtual-tree children of a vertex: (child, edge length)
int sval[MAXN];                    // subtree weight sum computed by dfs3
bool mark[MAXN];                   // true for key vertices, false for LCA-only helpers
int tot;                           // total weight of the key vertices, mod P
// Post-order walk of the virtual tree below u. It sets sval[u] to the total weight of the
// marked vertices in u's subtree. It returns, mod P, the sum over every virtual edge
// (parent -> child, length len) in that subtree of len * sval[child] * (tot - sval[child]).
// Summed from the root, this is the weighted distance sum over unordered pairs of marked
// vertices.
int dfs3(int u) {
    sval[u] = mark[u] ? val[u] : 0;
    long long acc = 0;
    for (const auto &edge : t[u]) {
        const int child = edge.first;
        const long long len = edge.second;
        acc = (acc + dfs3(child)) % P;
        const long long inside = sval[child];
        const long long rest = tot - inside + P;  // weight outside, shifted to stay positive
        acc = (acc + len * inside % P * rest) % P;
        sval[u] = (sval[u] + sval[child]) % P;
    }
    return static_cast<int>(acc);
}
// Builds the virtual tree of the key vertices p[1..m], adding the adjacent-pair LCAs so the
// set is LCA-closed. It then runs dfs3 from the root and resets mark[] and t[] for the next
// call. It returns the sum of val[x] * val[y] * dis(x, y) mod P over unordered key pairs.
// Side effects: reorders p[] and overwrites m with the virtual-tree size.
// Assumes m >= 1 on entry (main always supplies at least one vertex).
int solve() {
    tot = 0;
    for (int k = 1; k <= m; ++k) {
        mark[p[k]] = 1;
        tot = (tot + val[p[k]]) % P;
    }
    auto by_dfn = [](int x, int y) { return dfn[x] < dfn[y]; };
    sort(p + 1, p + m + 1, by_dfn);
    const int keys = m;
    // The LCAs of dfn-adjacent key vertices are exactly the extra nodes the virtual tree needs.
    for (int k = 1; k < keys; ++k) p[keys + k] = lca(p[k], p[k + 1]);
    m = 2 * keys - 1;
    sort(p + 1, p + m + 1, by_dfn);
    m = static_cast<int>(unique(p + 1, p + m + 1) - (p + 1));
    // In dfn order, the virtual parent of p[k] is lca(p[k - 1], p[k]).
    for (int k = 2; k <= m; ++k) {
        const int child = p[k];
        const int parent = lca(p[k - 1], child);
        t[parent].emplace_back(child, dis(parent, child));
    }
    const int answer = dfs3(p[1]);
    for (int k = 1; k <= m; ++k) {
        mark[p[k]] = 0;
        t[p[k]].clear();
    }
    return answer;
}
// Reads n, the vertex values a[1..n] and the n - 1 tree edges. It then accumulates
//   sum_T g[T] * solve(T-multiples),
// where solve() returns the unordered-pair sum of phi(a_x) * phi(a_y) * dis(x, y) over the
// vertices whose values are multiples of T. It prints 2 * total / (n * (n - 1)) mod P,
// presumably the required average over ordered pairs x != y.
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", a + i);
        b[a[i]] = i;  // assumes a is a permutation of 1..n, so b is its inverse
    }
    sieve();
    for (int k = 1; k < n; ++k) {
        int x, y;
        scanf("%d%d", &x, &y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    for (int x = 1; x <= n; ++x) val[x] = phi[a[x]];
    long long total = 0;
    for (int T = 1; T <= n; ++T) {
        m = 0;
        for (int mult = T; mult <= n; mult += T) p[++m] = b[mult];
        total = (total + 1ll * g[T] * solve()) % P;
    }
    const long long pairs = 1ll * n * (n - 1) % P;
    printf("%lld\n", 2 * total * qpow(static_cast<int>(pairs), P - 2) % P);
    return 0;
}