[Ynoi2016] 这是我自己的发明(根号分治+分块/莫队)

发布时间 2023-08-03 21:48:07作者: 暗蓝色的星空

题目传送门

soltion

简单题

换根显然可以拆成 \(O(1)\) 个区间,这里先不管。

直接做法是莫队,把双子树拆成 \(dfs\) 序上的双前缀,可以直接莫队,但是常数比较大。

另一种做法是根分,对颜色出现次数分治,大于的求出 \(dfs\) 序的前缀和即可,小于的因为一共只有 \(O(n\sqrt n)\) 个点对,所以变成二维数点扫描线做,用分块平衡一下即可。

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+7;
const int M = 5e5+7;
int n,m;
struct edge 
{
	int y,next;
}e[2*N];
int flink[N],t=0;
void add(int x,int y)
{
	e[++t].y=y;
	e[t].next=flink[x];
	flink[x]=t;
}
int st[N],ed[N],seq[N],tot=0;
int jump[N][20],dep[N];
void dfs(int x,int pre)
{
	dep[x]=dep[pre]+1;
	seq[st[x]=++tot]=x;
	jump[x][0]=pre;
	for(int k=1;jump[x][k-1];k++)jump[x][k]=jump[jump[x][k-1]][k-1];
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		dfs(y,x); 
	}
	ed[x]=tot;
}
int find(int x,int y)
{
	for(int k=19;k>=0;k--)
	if(dep[jump[x][k]]>dep[y])x=jump[x][k];
	return x;
}
bool Anc(int x,int y)
{
	return st[x]<=st[y]&&st[y]<=ed[x];
}
int a[N],dct[N],num=0;
int B;
vector<int> pos[N];
#define PII pair<int,int>
#define mk(x,y) make_pair(x,y)
#define X(x) x.first
#define Y(x) x.second
vector<PII> SubTree(int r,int x)
{
	if(!Anc(x,r)) return {mk(st[x],ed[x])};
	if(r==x) return {mk(1,n)};
	int y=find(r,x);
	return {mk(1,st[y]-1),mk(ed[y]+1,n)};
}
typedef long long LL;
LL Ans[M];
int op[M];
vector<PII> Sx[M],Sy[M];
int s[N];
inline int query(int l,int r){return s[r]-s[l-1];}
struct Query 
{
	int id,l,r,v;
};
vector<Query> Q[N];  
inline void Apply(int u,int l,int r,int L,int R)
{
	if(l>r||L>R)return;
	Q[l-1].push_back((Query){u,L,R,-1});
	Q[r].push_back((Query){u,L,R,1});
}
int sum[N],w[N],bel[N],L[N],R[N];
inline void upd(int x)
{
	//printf("upd(%d)\n",x);
	w[x]++;
	sum[bel[x]]++;
}
int ask(int l,int r)
{
	//printf("ask(%d %d)\n",l,r);
	int res=0;
	if(bel[l]==bel[r])
	{
		for(int i=l;i<=r;i++)res+=w[i];
		return res;
	}
	for(int i=l;i<=R[bel[l]];i++)res+=w[i];
	for(int i=L[bel[r]];i<=r;i++)res+=w[i];
	for(int i=bel[l]+1;i<bel[r];i++)res+=sum[i];
	return res;
}
int main()
{
	cin>>n>>m;B=sqrt(2*n);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]),dct[++num]=a[i];
	sort(dct+1,dct+num+1);
	num=unique(dct+1,dct+num+1)-dct-1;
	for(int i=1;i<=n;i++)
	{
		a[i]=lower_bound(dct+1,dct+num+1,a[i])-dct;
		pos[a[i]].push_back(i);
	}
	for(int i=2;i<=n;i++)
	{
		int x,y;
		scanf("%d %d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs(1,0);
	int r=1;
//	for(int i=1;i<=n;i++)
//	cout<<seq[i]<<endl;
	for(int i=1;i<=m;i++)
	{
		int x,y;
		scanf("%d",&op[i]);
		if(op[i]==1)
		{
			scanf("%d",&x);
			r=x;
		}
		else 
		{
			scanf("%d %d",&x,&y);
			Sx[i]=SubTree(r,x);
			Sy[i]=SubTree(r,y);
			for(auto u:Sx[i])
			for(auto v:Sy[i])
			Apply(i,X(u),Y(u),X(v),Y(v));
		}
	}
	for(int c=1;c<=num;c++)if(pos[c].size()>B)
	{
		for(int i=1;i<=n;i++)s[i]=s[i-1]+(a[seq[i]]==c);
		for(int i=1;i<=m;i++)
		{
			if(op[i]==2)
			{
				for(auto x:Sx[i])
				for(auto y:Sy[i])
				Ans[i]+=1ll*query(X(x),Y(x))*query(X(y),Y(y));
			}
		}
	}
	int O=sqrt(n);
	for(int i=1;i<=n;i++)
	{
		bel[i]=(i-1)/O+1;
		R[bel[i]]=i;
		if(!L[bel[i]])L[bel[i]]=i; 
	}
	for(int i=1;i<=n;i++)
	{
		if(pos[a[seq[i]]].size()<=B)for(int x:pos[a[seq[i]]])upd(st[x]);
		for(auto U:Q[i]) Ans[U.id]+=1ll*U.v*ask(U.l,U.r);
	}
	for(int i=1;i<=m;i++)
	if(op[i]==2)printf("%lld\n",Ans[i]);
	return 0;
}