这是提供一种时间复杂度不那么优秀但十分好写也好理解的做法。
题目大意
给定一颗 \(n\) 个节点的树,每个节点拥有一个颜色,进行若干次询问,每次询问给出两种颜色 \(A,B\),求所有颜色为 \(A\) 的节点的子树中颜色为 \(B\) 的节点的个数的和。
思路分析
考虑根号分治。按颜色的节点数进行分类,节点数 \(>\sqrt n\) 的称为重颜色,节点数 \(\le \sqrt n\) 的称为轻颜色。
对询问进行分类讨论:
- \(A,B\) 均为轻颜色:
考虑到两者的节点数都不会很多,可以暴力标记每一个颜色为 \(B\) 的点,并枚举所有颜色为 \(A\) 的点,求出其子树中的标记数并累加入答案。这一过程可以用树状数组实现。单次实现的时间复杂度为 \(O(\sqrt n\log n)\)。
- \(A\) 为重颜色,\(B\) 为轻颜色:
设 \(f_{i,j}\) 表示点 \(i\) 到根的路径中颜色为重颜色 \(j\) 的节点的数量,那么答案显然为:
其中,\(\text{col}_X\) 表示所有颜色为 \(X\) 的点所构成的集合。
考虑计算 \(f\):
枚举每一个重颜色 \(K\),遍历 \(\text{col}_K\),将其中的每一个点的子树中的所有点的权值加 \(1\),再遍历每一个颜色不为 \(K\) 的点,其权值就是其对应的 \(f\)。这一过程可以用差分实现。时间复杂度为 \(O(n\sqrt n)\)。
- \(B\) 为重颜色,\(A\) 为轻颜色:
设 \(g_{i,j}\) 表示点 \(i\) 的子树中颜色为重颜色 \(j\) 的节点的数量,那么答案显然为:
\(g\) 的计算与 \(f\) 正好相反:
枚举每一个重颜色 \(K\),遍历 \(\text{col}_K\),将其中的每一个点的权值单点加 \(1\)。再遍历每一个颜色不为 \(K\) 的点,其子树中的权值和即为其对应的 \(g\),这一过程可以用前缀和实现。时间复杂度为 \(O(n\sqrt n)\)。
- \(A,B\) 均为重颜色:
容易发现,这种情况是上述两种情况的交集,上述两种情况的答案计算方式均适用于此情况。
实现细节
在实现中,为了优化空间复杂度,可以省去 \(f\) 和 \(g\) 直接累加答案。具体的说,设 \(\text{ans1}_{i,j}\) 表示所有颜色为 \(i\) 的点的子树中颜色为重颜色 \(j\) 的节点的数量的和,\(\text{ans2}_{i,j}\) 表示所有颜色为 \(i\) 的点到根的路径中颜色为重颜色 \(j\) 的节点的数量的和。(也就是第三,二种情况对应的答案)这样可以将空间复杂度由 \(O(n\sqrt n)\) 优化到 \(O(R\sqrt n)\)。
差分和前缀和可以直接在 dfs 序上进行,不需要放到树上。
总时间复杂度为 \(O(q\sqrt n\log n+n\sqrt n)\),瓶颈在于树状数组,卡常后可以通过。
空间复杂度为 \(O(R\sqrt n+n)\),完全可过。
代码
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
const int N=200100,M=450,R=25500;
int idx=1,n,m,q,B,cnt,in1,in2,cur,ans;
int w[N],to[N],nxt[N],head[N];
int c[N],dfn[N],siz[N],big[N],d[N];
int ans1[M][R],ans2[M][R];
vector<int> col[N];
int read(){//卡常用的快读快写
int x=0;char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x;
}
void write(int x){
int k=x/10;if(k) write(k);
putchar(x-(k<<3)-(k<<1)+'0');
}
void add(int u,int v){
idx++;to[idx]=v;nxt[idx]=head[u];head[u]=idx;
}
void dfs_1(int s){
dfn[s]=++cnt;siz[s]=1;
for(int i=head[s];i;i=nxt[i])
dfs_1(to[i]),siz[s]+=siz[to[i]];
}
struct BIT{
int a[N];
#define lowbit(x) ((x)&(-(x)))
void add(int x,int v){
for(;x<N;x+=lowbit(x)) a[x]+=v;
}
int query(int x){
if(!x) return 0;
int ans=0;
for(;x;x-=lowbit(x)) ans+=a[x];
return ans;
}
}tree;
int main(){
n=read();m=read();q=read();
B=sqrt(n);
w[1]=read();
col[w[1]].push_back(1);
for(int i=2;i<=n;i++){
in1=read();w[i]=read();
add(in1,i);
col[w[i]].push_back(i);
}
dfs_1(1);
for(int i=1;i<=m;i++){
if(col[i].size()<=B) continue;
big[i]=++cur;//对每一种重颜色重标号
for(int j=1;j<=n;j++) d[j]=0;
for(auto it:col[i]) d[dfn[it]]++;//单点加
for(int j=1;j<=n;j++) d[j]+=d[j-1];//前缀和
for(int j=1;j<=n;j++){
if(w[j]==i) continue;
ans1[cur][w[j]]+=d[dfn[j]+siz[j]-1]-d[dfn[j]-1];//直接累加答案
}
for(int j=1;j<=n;j++) d[j]=0;
for(auto it:col[i]){
d[dfn[it]]++;//差分
d[dfn[it]+siz[it]]--;//这里其实是 dfn[it]+siz[it]-1+1
}
for(int j=1;j<=n;j++) d[j]+=d[j-1];//对差分序列做前缀和得到原序列
for(int j=1;j<=n;j++){
if(w[j]==i) continue;
ans2[cur][w[j]]+=d[dfn[j]];
}
}
while(q--){
in1=read();in2=read();ans=0;
if(big[in2]){write(ans1[big[in2]][in1]);cout<<'\n'<<flush;continue;}
if(big[in1]){write(ans2[big[in1]][in2]);cout<<'\n'<<flush;continue;}
for(auto it:col[in2]) tree.add(dfn[it],1);//暴力加
for(auto it:col[in1])
ans+=tree.query(dfn[it]+siz[it]-1)-tree.query(dfn[it]-1);
write(ans);cout<<'\n'<<flush;
for(auto it:col[in2]) tree.add(dfn[it],-1);//记得清空
}
return 0;
}