洛谷 P5669 [SDOI2018] 原题识别-改 题解--zhengjun

发布时间 2023-12-21 16:00:06作者: A_zjzj

题面

鉴于这题目前还没题解,提供一种时间 \(\Theta(n\sqrt{m})\),空间 \(\Theta(n+m)\) 的做法。

询问 1

可以直接上树分块或者树上莫队,见 P6177 Count on a tree II/【模板】树分块

但是因为本题询问 2 的做法,所以我采用了树上莫队的做法。

询问 2

方便起见:

  • \(\operatorname{path}(u,v)\) 表示 \(u,v\) 路径之间的所有点构成的集合;
  • \(f(u,v)\) 表示 \(u,v\) 路径上本质不同的颜色数。

这里直接考虑 \(u,v\) 不为祖先关系的情况(\(u,v\) 为祖先关系的情况显然严格弱于这个,特判一下即可)。

所以答案即为:

\[\sum\limits_{i\in \operatorname{path}(1,u)}\sum\limits_{j\in \operatorname{path}(1,v)}f(i,j) \]

因为我们发现,答案的形式非常像对于一个区间的所有子区间求和,那么我们引入新的函数:

\[F(u,v)=\sum\limits_{i,j\in \operatorname{path}(u,v)}f(i,j) \]

首先观察这个 \(F\),将 \(\operatorname{path}(u,v)\) 理解为一个区间 \([1,m]\)

它的实际意义就是 \([1,m]\) 的所有子区间的本质不同的颜色数之和。

但是这样并不好计算,我们考虑另外一种实际意义:对于每种颜色,计算 \([1,m]\) 的所有子区间中包含该颜色的个数和。

如果把颜色 \(c\) 删去,序列剩下来长度为 \(l_1,l_2,\cdots ,l_{k_c}\)\(k_c\) 段连续区间,那么该颜色的贡献就是 \(\binom{n+1}{2}-\sum \binom{l_i+1}{2}\)

那么,如果在序列的后面加入一个元素,那么答案的增量就是 \(\sum\limits_{c}n'-l'_{c,k'_c}\)

所以,我们如果我们维护出了 \(suf_c=l_{c,k_c}\) 以及它的和,那么我们就可以 \(\Theta(1)\) 向右边扩展了。

注意,我们同时可以 \(\Theta(1)\) 删除最后一个位置。

使用链表维护相同颜色的位置,并实时记录每个颜色的起始位置,维护出 \(\sum suf_c\)\(\sum pre_c\),这样左右端点都能够 \(\Theta(1)\) 左右移动了。

现在,我们就可以使用树上莫队来计算 \(F(u,v)\) 了。

接下来考虑怎么计算答案。

设询问的两个节点分别为 \(u,v\)

\(t\)\(u,v\) 的最近公共祖先 \((t\ne u , t\ne v)\)

\(p,q\)\(t\) 的两个不同的儿子且 \(p\in \operatorname{path}(u,t),q\in \operatorname{path}(v,t)\)

考虑对答案进行转化,这里直接给出结果:

\[\begin{aligned} ans & = \sum\limits_{i\in \operatorname{path}(1,u)}\sum\limits_{j\in \operatorname{path}(1,v)}f(i,j)\\ & = F(1,u)+F(1,v)-|\operatorname{path}(1,t)|+F(u,v)-F(u,p)-F(v,q)-F(u,t)-F(v,t)+1 \end{aligned} \]

其中后面一大坨的尾巴是 $ \operatorname{path}(u,p)$ 和 $ \operatorname{path}(v,q)$ 之间的贡献,即:

\[\sum\limits_{i\in \operatorname{path}(u,p)}\sum\limits_{j\in \operatorname{path}(v,q)}f(i,j)=F(u,v)-F(u,p)-F(v,q)-F(u,t)-F(v,t)+1 \]

最后加一是因为 \(f(t,t)\) 被两边都减了一遍,类似于容斥。

而剩下的贡献就是 \(F(1,u)+F(1,v)-|\operatorname{path}(1,t)|\)

  • 减去 \(|\operatorname{path}(1,t)|\) 的就是 \(\forall i\in \operatorname{path}(1,t),f(i,i)\) 都算了两遍;
  • 而其余的 \(\forall i,j\in \operatorname{path}(1,t) \land i\ne j,f(i,j)\) 本身就应该算两遍。

做到这里似乎已经做完了……

细节处理:

  • 维护一条链的情况时,需要使用循环队列,因为左右端点都可能移动很远,但是任意时刻序列长度都不超过 \(n\)

本人直接写完后不卡常最大点用时 4.05s,经过调整块长、对莫队的排序进行奇偶优化过后,最大点用时 1.96s,效率还行,毕竟询问 2 有个 \(5\) 倍常数。

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
#ifdef DEBUG
template<class T>
ostream& operator << (ostream &out,vector<T> a){
	out<<'[';
	for(T x:a)out<<x<<',';
	return out<<']';
}
template<class T>
vector<T> ary(T *a,int l,int r){
	return vector<T>{a+l,a+1+r};
}
template<class T>
void debug(T x){
	cerr<<x<<endl;
}
template<class T,class...S>
void debug(T x,S...y){
	cerr<<x<<' ',debug(y...);
}
#else
#define debug(...) void()
#endif
const int N=1e5+10,V=N*2,M=2e5+10;
int n,q,a[N];
int dft,B,id[V],pos[V],dfn[N];
vector<int>to[N];
struct ques{
	int l,r,id,w;
	bool operator < (const ques &a)const{
		return ::id[l]^::id[a.l]?::id[l]<::id[a.l]:(::id[l]&1?r<a.r:r>a.r);
	}
}o1[M],o2[M*5];
int m1,m2;
void make(int u,int fa=0){
	pos[dfn[u]=++dft]=u;
	for(int v:to[u])if(v^fa){
		make(v,u);
		pos[++dft]=u;
	}
}
namespace Path{
	int top[N],fa[N],dep[N],siz[N],son[N];
	void dfs1(int u){
		siz[u]=1,dep[u]=dep[fa[u]]+1;
		for(int v:to[u])if(v^fa[u]){
			fa[v]=u,dfs1(v);
			siz[u]+=siz[v];
			if(siz[v]>siz[son[u]])son[u]=v;
		}
	}
	int dft,dfn[N],pos[N];
	void dfs2(int u,int t){
		top[u]=t,pos[dfn[u]=++dft]=u;
		if(son[u])dfs2(son[u],t);
		for(int v:to[u])if(v^fa[u]&&v^son[u])dfs2(v,v);
	}
	void init(){
		dfs1(1),dfs2(1,1);
	}
	int LCA(int u,int v){
		for(;top[u]^top[v];u=fa[top[u]]){
			if(dep[top[u]]<dep[top[v]])swap(u,v);
		}
		return dep[u]<dep[v]?u:v;
	}
	int jump(int u,int k){
		for(;k>dep[u]-dep[top[u]];u=fa[top[u]])k-=dep[u]-dep[top[u]]+1;
		return pos[dfn[u]-k];
	}
}
using Path::dep;
namespace DS1{
	int now,cnt[N];
	void insert(int x){
		now+=!cnt[x]++;
	}
	void erase(int x){
		now-=!--cnt[x];
	}
	int query(){
		return now;
	}
}
namespace DS2{
	struct Queue{
		int a[N];
		const int& operator [] (const int &x)const{
			return a[(x%N+N)%N];
		}
		int& operator [] (const int &x){
			return a[(x%N+N)%N];
		}
	}col,pre,nex;
	int s,t;
	int now,cnt[N],bg[N],ed[N];
	ll s1,s2,ans;
	void init(){
		s=1e9,t=s-1,s1=s2=now=ans=0;
		memset(bg,0,sizeof bg);
		memset(ed,0,sizeof ed);
		memset(cnt,0,sizeof cnt);
	}
	void push_back(int x){
		// debug("push_back",x);
		col[++t]=x,now+=!cnt[x]++;
		s1+=n-now;
		s2+=n-(ed[x]?t-ed[x]:t-s+1);
		ans+=(t-s+1ll)*n-s2;
		pre[t]=ed[x],nex[t]=0;
		if(ed[x])nex[ed[x]]=t;
		ed[x]=t;
		if(!bg[x])bg[x]=t;
	}
	void pop_back(){
		// debug("pop_back");
		int x=col[t];
		ed[x]=pre[t];
		if(!ed[x])bg[x]=0;
		else nex[ed[x]]=0;
		ans-=(t-s+1ll)*n-s2;
		s2-=n-(ed[x]?t-ed[x]:t-s+1);
		s1-=n-now;
		now-=!--cnt[col[t--]];
	}
	void push_front(int x){
		// debug("push_front",x);
		col[--s]=x,now+=!cnt[x]++;
		s1+=n-(bg[x]?bg[x]-s:t-s+1);
		s2+=n-now;
		ans+=(t-s+1ll)*n-s1;
		nex[s]=bg[x],pre[s]=0;
		if(bg[x])pre[bg[x]]=s;
		bg[x]=s;
		if(!ed[x])ed[x]=s;
	}
	void pop_front(){
		// debug("pop_front");
		int x=col[s];
		bg[x]=nex[s];
		if(!bg[x])ed[x]=0;
		else pre[bg[x]]=0;
		ans-=(t-s+1ll)*n-s1;
		s2-=n-now;
		s1-=n-(bg[x]?bg[x]-s:t-s+1);
		now-=!--cnt[col[s++]];
	}
	ll query(){
		return ans;
	}
}
ll f[N],ans[M];
void dfs(int u,int fa=0){
	DS2::push_back(a[u]);
	f[u]=DS2::query();
	for(int v:to[u])if(v^fa){
		dfs(v,u);
	}
	DS2::pop_back();
}
int vis[N];
void solve1(){
	for(int i=1;i<=m1;i++){
		if(o1[i].l>o1[i].r)swap(o1[i].l,o1[i].r);
	}
	B=dft/max(1.0,sqrt(m1))*3;
	for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
	sort(o1+1,o1+1+m1);
	auto go=[&](int u,int v){
		if(!vis[v])DS1::insert(a[v]),vis[v]=1;
		else DS1::erase(a[u]),vis[u]=0;
	};
	int l=1,r=0;
	for(int i=1;i<=m1;i++){
		for(;r<o1[i].r;r++)go(pos[r],pos[r+1]);
		for(;r>o1[i].r;r--)go(pos[r],pos[r-1]);
		for(;l<o1[i].l;l++)go(pos[l],pos[l+1]);
		for(;l>o1[i].l;l--)go(pos[l],pos[l-1]);
		ans[o1[i].id]+=DS1::query()*o1[i].w;
	}
}
void solve2(){
	memset(vis,0,sizeof vis);
	for(int i=1;i<=m2;i++){
		if(o2[i].l>o2[i].r)swap(o2[i].l,o2[i].r);
	}
	B=dft/max(1.0,sqrt(m2))*3;
	for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
	sort(o2+1,o2+1+m2);
	auto go_t=[&](int u,int v){
		if(!vis[v])DS2::push_back(a[v]),vis[v]=1;
		else DS2::pop_back(),vis[u]=0;
	};
	auto go_s=[&](int u,int v){
		if(!vis[v])DS2::push_front(a[v]),vis[v]=1;
		else DS2::pop_front(),vis[u]=0;
	};
	int l=1,r=0;
	DS2::init();
	for(int i=1;i<=m2;i++){
		for(;r<o2[i].r;r++)go_t(pos[r],pos[r+1]);
		for(;r>o2[i].r;r--)go_t(pos[r],pos[r-1]);
		for(;l<o2[i].l;l++)go_s(pos[l],pos[l+1]);
		for(;l>o2[i].l;l--)go_s(pos[l],pos[l-1]);
		ans[o2[i].id]+=DS2::query()*o2[i].w;
	}
}
int main(){
	freopen(".in","r",stdin);
	// freopen(".out","w",stdout);
	scanf("%d%d",&n,&q);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		to[u].push_back(v),to[v].push_back(u);
	}
	make(1),Path::init(),DS2::init(),dfs(1);
	debug("f",ary(f,1,n));
	for(int i=1,op,u,v;i<=q;i++){
		scanf("%d%d%d",&op,&u,&v);
		if(op==1){
			o1[++m1]={dfn[u],dfn[v],i,1};
		}else{
			int t=Path::LCA(u,v);
			ans[i]=f[u]+f[v]-dep[t];
			if(u^t)o2[++m2]={dfn[u],dfn[Path::jump(u,dep[u]-dep[t]-1)],i,-1};
			if(v^t)o2[++m2]={dfn[v],dfn[Path::jump(v,dep[v]-dep[t]-1)],i,-1};
			if(u^t&&v^t){
				ans[i]++;
				o2[++m2]={dfn[u],dfn[v],i,1};
				o2[++m2]={dfn[u],dfn[t],i,-1};
				o2[++m2]={dfn[v],dfn[t],i,-1};
			}
		}
	}
	solve1(),solve2();
	for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
	return 0;
}