Luogu 5439 XR-2永恒

发布时间 2023-07-21 08:29:57作者: Ender_32k

\(T\) 是节点数为 \(n\) 的那棵树,\(T'\) 是 Trie 树。带 \('\) 的,比如 \(\text{dep}'_u\),表示 Trie 上的信息(注意到 \(\text{dep}'\) 要从 \(0\) 开始),不带的表示原树。\([u,v]\) 表示 \(u\to v\) 的路径,\(S\) 是原树上无序点对的全集。

那么答案就是:

\[\sum\limits_{(s,t)\in S}\sum\limits_{[u,v]\subseteq [s,t]}\text{dep}'_{\text{lca}(u,v)} \]

考虑枚举 \((u,v)\) 分情况计算贡献:

  • \(\text{lca}(u,v)\neq u,v\):显然包含 \([u,v]\) 的路径数为 \(\text{siz}_u\text{siz}_v\)。
  • 否则设 \(u\) 是 \(v\) 的祖先,\(w\) 是 \(u\) 向 \(v\) 方向延伸的那个儿子,那么经过 \([u,v]\) 的路径数为 \((n-\text{siz}_w)\text{siz}_v\)。

考虑第一种贡献咋算,根据某道经典题的套路,一个简单的想法是枚举 \(u\),然后 \([d_u\to \text{root}']\) 上面所有点增加 \(\text{siz}_u\),然后枚举 \(v\),对 \([d_v\to\text{root}']\) 求和乘上 \(\text{siz}_v\) 即可。但是要除掉根的贡献,因为 \(\text{dep}'_{d_\text{root}}=0\)

然后对于第二种贡献,我们直接对 \(T\) 搜一遍,到了一个点 \(v\),考虑它到 \(\text{root}\) 的路径上的所有祖先对 \(v\) 的贡献和。我们到一个点 \(u\),枚举出边 \(u\to w\),直接给 \([d_u\to\text{root}']\) 加上 \(n-\text{siz}_w\),然后往 \(w\) 走,回溯的时候再减去即可。

但是这样会算重,发现在第一种贡献中,在 \(u\) 子树内的 \(v\) 也被计算了。冷静分析一下,因为算重的 \((u,v)\) 均满足 \(v\) 在 \(u\) 子树内,所以我们做第二种贡献的时候将增量 \(n-\text{siz}_w\) 变成 \(n-\text{siz}_w-\text{siz}_u\) 即可。

然后线段树随便做了,复杂度 \(O(n\log^2 n)\)。自认为写得很清新。

#include <bits/stdc++.h>
using namespace std;

namespace vbzIO {
    // Input buffer plus the [iS, iT) cursor pair consumed by the fread-based gh().
    char ibuf[(1 << 20) + 1], *iS, *iT;
    // gh() yields the next input byte, or EOF once the input is exhausted.
    // With ONLINE_JUDGE set it refills ibuf through fread; otherwise it is getchar().
    #if ONLINE_JUDGE
    #define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
    #else
    #define gh() getchar()
    #endif
    // File-wide shorthands; they are plain macros, so they leak out of this namespace.
    #define pc putchar
    #define pi pair<int, int>
    #define tu3 tuple<int, int, int>
    #define tu4 tuple<int, int, int, int>
    #define mp make_pair
    #define mt make_tuple
    #define fi first
    #define se second
    #define pb push_back
    #define ins insert
    #define era erase

    // Reads the next decimal integer from the input.
    // Every byte before the first digit is skipped; if any skipped byte is '-',
    // the result is negated. The byte right after the last digit is consumed too.
    // Returns 0 when the input ends before a digit is found (the previous version
    // stored EOF in a char and spun forever in the skip loop).
    inline int read () {
        int ch = gh();  // int, so EOF stays distinguishable from data bytes
        bool negative = false;
        while (ch != EOF && (ch < '0' || ch > '9')) {
            negative |= (ch == '-');
            ch = gh();
        }
        int value = 0;
        while (ch >= '0' && ch <= '9') {
            value = value * 10 + (ch - '0');
            ch = gh();
        }
        return negative ? -value : value;
    }

    // Writes x in decimal to stdout, with a leading '-' for negative values.
    // No separator or newline is emitted.
    inline void write(int x) {
        // Work on the unsigned magnitude so INT_MIN does not overflow on negation.
        unsigned int magnitude = static_cast<unsigned int>(x);
        if (x < 0) {
            putchar('-');
            magnitude = 0u - magnitude;
        }
        char digits[12];
        int len = 0;
        do {
            digits[len++] = static_cast<char>('0' + magnitude % 10);
            magnitude /= 10;
        } while (magnitude != 0);
        while (len > 0) putchar(digits[--len]);
    }
}
using vbzIO::read;
using vbzIO::write; 

// Modulus applied to every accumulated value.
const int mod = 998244353;
// Capacity of the per-node arrays (300000 plus some slack).
const int maxn = 300000 + 300;
// (mod + 1) / 2, i.e. the multiplicative inverse of 2 modulo mod.
const int inv2 = (mod + 1) / 2;

// One segment-tree node: range sum and pending range-add tag.
struct seg {
	int sum;
	int tag;
};
seg tr[maxn << 2];

// n, m: node counts read in main; rt: root of t1; dfc: counter handing out id[];
// ans: running answer.
int n, m, rt, dfc, ans;
// fa[i]: parent of i in t1 as read from input; a[i]: per-node value read last in main.
int fa[maxn], a[maxn];
// sz[]: subtree sizes in t1 (dfs3).
int sz[maxn];
// Per-node data for t2: depth, subtree size, parent, heavy child, chain top,
// visiting index (dfs1 / dfs2), and the accumulator d[] used by dfs4 / dfs5.
int dep[maxn], siz[maxn], fath[maxn], son[maxn], top[maxn], id[maxn], d[maxn];
// Child lists: t1 is built from fa[], t2 from fath[].
vector<int> t1[maxn], t2[maxn];

// Segment-tree shorthands. They expand textually, so they are only usable inside
// a function whose node index is named `x` (ls / rs) and whose range bounds are
// named `l` and `r` (mid).
// ls: index of the left child of node x.
#define ls x << 1
// rs: index of the right child of node x.
#define rs x << 1 | 1
// mid: midpoint of [l, r]; the left child covers [l, mid], the right (mid, r].
#define mid ((l + r) >> 1)
// Recomputes the sum stored at node x from its two children (2x and 2x + 1),
// reduced modulo mod.
void pushup(int x) {
	const int lc = x << 1;
	const int rc = lc | 1;
	tr[x].sum = (tr[lc].sum + tr[rc].sum) % mod;
}
void pushtag(int x, int c, int l, int r) { (tr[x].sum += 1ll * c * (r - l + 1) % mod) %= mod, (tr[x].tag += c) %= mod; }
// Pushes the pending add-tag of node x (covering [l, r]) down to both children
// and clears it. Does nothing when no tag is pending.
void pushdown(int x, int l, int r) {
	const int pending = tr[x].tag;
	if (pending == 0) return;
	const int half = (l + r) >> 1;
	pushtag(x << 1, pending, l, half);
	pushtag(x << 1 | 1, pending, half + 1, r);
	tr[x].tag = 0;
}

// Range add on the segment tree: adds c to every position in [s, t].
// x is the current node and [l, r] the range it covers.
void upd(int l, int r, int s, int t, int c, int x) {
	// Node fully inside the target range: tag it and stop descending.
	if (l >= s && t >= r) {
		pushtag(x, c, l, r);
		return;
	}
	pushdown(x, l, r);
	const int half = (l + r) >> 1;
	if (half >= s) upd(l, half, s, t, c, x << 1);
	if (half < t) upd(half + 1, r, s, t, c, x << 1 | 1);
	pushup(x);
}

// Range sum on the segment tree: returns the sum over [s, t] modulo mod.
// x is the current node and [l, r] the range it covers.
int qry(int l, int r, int s, int t, int x) {
	// Node fully inside the queried range: its stored sum is the answer.
	if (l >= s && t >= r) return tr[x].sum;
	pushdown(x, l, r);
	const int half = (l + r) >> 1;
	int total = 0;
	if (half >= s) {
		total += qry(l, half, s, t, x << 1);
		total %= mod;
	}
	if (half < t) {
		total += qry(half + 1, r, s, t, x << 1 | 1);
		total %= mod;
	}
	return total;
}

// Adds c (reduced modulo mod) to every indexed t2 node on the path from u up to
// the top of its decomposition tree, one heavy chain at a time.
//
// Nodes without an index (id == 0) are skipped. main leaves the t2 root out of
// the decomposition, so its id stays 0; per the write-up its depth counts as 0
// and it must contribute nothing. Previously updp(root, c) wrote c into
// segment-tree slot 0, which qryp(root) would later read back as a bogus
// contribution.
void updp(int u, int c) {
	if (!id[u]) return;
	c %= mod;
	while (u) {
		upd(0, m, id[top[u]], id[u], c, 1);
		u = fath[top[u]];  // jump to the parent of the current chain's top
	}
}

// Returns, modulo mod, the sum of the values stored on every indexed t2 node on
// the path from u up to the top of its decomposition tree.
//
// Nodes without an index (id == 0) yield 0. main leaves the t2 root out of the
// decomposition, so its id stays 0; per the write-up its depth counts as 0, so
// a query on it must be 0 rather than whatever sits in segment-tree slot 0.
int qryp(int u) {
	if (!id[u]) return 0;
	int res = 0;
	while (u) {
		(res += qry(0, m, id[top[u]], id[u], 1)) %= mod;
		u = fath[top[u]];  // jump to the parent of the current chain's top
	}
	return res;
}

// First decomposition pass over t2: fills dep[], siz[] and the heavy child son[]
// for the subtree rooted at u. fat is u's parent (0 for a top-level call).
void dfs1(int u, int fat) {
	dep[u] = dep[fat] + 1;
	siz[u] = 1;
	for (const int child : t2[u]) {
		dfs1(child, u);
		siz[u] += siz[child];
		// Keep the child with the strictly largest subtree; the first one wins ties.
		if (siz[son[u]] < siz[child]) son[u] = child;
	}
}

// Second decomposition pass over t2: assigns the visiting index id[u] and the
// chain top top[u]. pre is the top of the chain that u belongs to.
void dfs2(int u, int pre) {
	id[u] = ++dfc;
	top[u] = pre;
	const int heavy = son[u];
	// The heavy child goes first so that each chain gets consecutive indices.
	if (heavy != 0) dfs2(heavy, pre);
	for (const int child : t2[u]) {
		if (child != heavy) dfs2(child, child);  // a light child starts its own chain
	}
}

// Computes sz[u], the number of nodes in u's subtree of t1.
void dfs3(int u) {
	sz[u] = 1;
	for (const int child : t1[u]) {
		dfs3(child);
		sz[u] += sz[child];
	}
}

// Bottom-up pass over t2: after the call, d[u] holds the sum (mod mod) of the
// initial d values over u's whole subtree.
void dfs4(int u) {
	for (const int child : t2[u]) {
		dfs4(child);
		d[u] += d[child];
		d[u] %= mod;
	}
}

// Top-down pass over t2: adds the parent's (already accumulated) d value into
// d[u], then continues into the children. fat is u's parent (0 at the root).
void dfs5(int u, int fat) {
	d[u] += d[fat];
	d[u] %= mod;
	for (const int child : t2[u]) {
		dfs5(child, u);
	}
}

// Walks t1 from u downwards. At each node it adds sz[u] * qryp(a[u]) to ans.
// Before descending into a child it adds (n - sz[u] - sz[child]) along the
// path of a[u], and removes it again on the way back, so only the current
// ancestors' increments are visible to the query.
void dfs6(int u) {
	const int here = a[u];
	ans += 1ll * sz[u] * qryp(here) % mod;
	ans %= mod;
	for (const int child : t1[u]) {
		const int delta = n - sz[u] - sz[child];
		updp(here, mod + delta);
		dfs6(child);
		updp(here, mod - delta);  // undo the increment when backtracking
	}
}

int main() {
	n = read(), m = read();
	for (int i = 1; i <= n; i++) {
		fa[i] = read(); 
		if (!fa[i]) rt = i;
		else t1[fa[i]].pb(i);
	}
	for (int i = 1; i <= m; i++) {
		fath[i] = read();
		if (fath[i]) t2[fath[i]].pb(i);
	}
	scanf("%*s");
	for (int i = 1; i <= n; i++) a[i] = read();
	for (int u : t2[1]) dep[u] = 1, fath[u] = 0, dfs1(u, 0), dfs2(u, u);
	dfs3(rt);
	for (int i = 1; i <= n; i++) (d[a[i]] += sz[i]) %= mod;
	dfs4(1), d[1] = 0, dfs5(1, 0);
	for (int i = 1; i <= n; i++) {
		ans = (ans + 1ll * sz[i] * d[a[i]] % mod) % mod;
		ans = (ans + mod - 1ll * sz[i] * sz[i] % mod * dep[a[i]] % mod) % mod;
	}
	ans = 1ll * ans * inv2 % mod;
	dfs6(rt);
	write(ans);
	return 0;
}