【题解】P5327 [ZJOI2019] 语言

发布时间 2023-04-21 11:26:40作者: flywatre

P5327 [ZJOI2019] 语言

题目描述

九条可怜是一个喜欢规律的女孩子。按照规律,第二题应该是一道和数据结构有关的题。

在一个遥远的国度,有 \(n\) 个城市。城市之间有 \(n - 1\) 条双向道路,这些道路保证了任何两个城市之间都能直接或者间接地到达。

在上古时代,这 \(n\) 个城市之间处于战争状态。在高度闭塞的环境中,每个城市都发展出了自己的语言。而在王国统一之后,语言不通给王国的发展带来了极大的阻碍。为了改善这种情况,国王下令设计了 \(m\) 种通用语,并进行了 \(m\) 次语言统一工作。在第 \(i\) 次统一工作中,一名大臣从城市 \(s_i\) 出发,沿着最短的路径走到了 \(t_i\),教会了沿途所有城市(包括 \(s_i, t_i\))使用第 \(i\) 个通用语。

一旦有了共通的语言,那么城市之间就可以开展贸易活动了。两个城市 \(u_i, v_i\) 之间可以开展贸易活动当且仅当存在一种通用语 \(L\) 满足 \(u_i\)\(v_i\) 最短路上的所有城市(包括 \(u_i, v_i\)),都会使用 \(L\)

为了衡量语言统一工作的效果,国王想让你计算有多少对城市 \((u, v)\)\(u < v\)),他们之间可以开展贸易活动。

测试点 \(n\) \(m\) 其他约定
\(1,2\) \(\le 300\) \(\le 300\)
\(3,4\) \(\le 5\times 10^3\) \(\le 5\times 10^3\)
\(5,6\) \(\le 10^5\) \(\le 10^5\) \(y_i=x_i+1\)
\(7\sim 10\) \(\le 10^5\) \(\le 10^5\)

对于 \(100\%\) 的数据,有 \(1 \le x_i, y_i, s_i, t_i \le n\leq 10 ^ 5\)\(1\leq m\leq 10 ^ 5\)\(x_i\neq y_i\)

题解

线段树合并喵喵题。
考虑对于每个点维护经过它的路径端点的集合,那么当前点能到达的点即为集合内的点构成的虚数的大小。
可以用树上差分,在端点处加入点,lca处减去点。
虚树大小即为在dfs序上设当前点为 u ,前一个点为 v,则虚树大小即为 \(\sum dep_u-dep_{lca(u,v)}\),于是可以线段树下标为dfs序动态维护虚树,差分就将子节点的线段树合并到自己,最终答案除二($ u< v$)即可。
题目的转化非常的妙,将链信息转化为点信息,还有线段树dfs序维护虚树的方法。
代码1h,调了2h,还是太菜了啊。

#include<bits/stdc++.h>
using namespace std;
inline int rd(){
	int f=1,j=0;
	char w=getchar();
	while(!isdigit(w)){
		if(w=='-')f=-1;
		w=getchar();
	}
	while(isdigit(w)){
		j=j*10+w-'0';
		w=getchar();
	}
	return f*j;
}
const int N=100010;
int head[N],to[N*2],fro[N*2],tail;
int n,m,ansn[N];
long long ans;
int fr[N],dfn[N],bel[N],siz[N],dep[N],dfn_cnt;
struct Lca{
	int dfn[N*2],bel[N],st[N*2][21],cnt,tw[N*2];
	inline int cmp(int x,int y){return (dep[x]<dep[y])?x:y;}
	void init(){
		for(int i=2;i<=cnt;i++)tw[i]=((1<<(tw[i-1]+1))==i)?tw[i-1]+1:tw[i-1];
		for(int i=1;i<=cnt;i++)st[i][0]=dfn[i];
		for(int k=1;k<=20;k++){
			for(int i=1;i+(1<<k)-1<=cnt;i++)st[i][k]=cmp(st[i][k-1],st[i+(1<<(k-1))][k-1]);
		}
		return ;
	}
	int get(int x,int y){
		x=bel[x],y=bel[y];
		if(x>y)swap(x,y);
		int k=tw[y-x+1];
		return cmp(st[x][k],st[y-(1<<k)+1][k]);
	}
}lca;
inline void adlin(int x,int y){
	to[++tail]=y,fro[tail]=head[x],head[x]=tail;
	return ;
}
void dfs(int u,int fa){
	fr[u]=fa;
	dep[u]=dep[fa]+1,bel[u]=++dfn_cnt,dfn[dfn_cnt]=u,siz[u]=1;
	lca.bel[u]=++lca.cnt,lca.dfn[lca.cnt]=u;
	for(int k=head[u];k;k=fro[k]){
		int x=to[k];
		if(x==fa)continue;
		dfs(x,u);
		siz[u]+=siz[x];
		lca.dfn[++lca.cnt]=u;
	}
	return ;
}
int rt[N],ls[N*100],rs[N*100],sum[N*100],cnt;
struct node{
	int l,r,sum;
}tr[N*100];
node operator +(node a,node b){
	node c;
	c.sum=a.sum+b.sum;
	c.l=a.l?a.l:b.l,c.r=b.r?b.r:a.r;
	if(a.r&&b.l)c.sum-=dep[lca.get(a.r,b.l)];
	return c;
}
void update(int u){
	tr[u]=tr[ls[u]]+tr[rs[u]];
	sum[u]=sum[ls[u]]+sum[rs[u]];
	return ;
}
void modify(int &u,int l,int r,int aim,int k){
	if(!u)u=++cnt;
	if(l==r){
		sum[u]+=k;
		if(sum[u]>0)tr[u]=(node){dfn[aim],dfn[aim],dep[dfn[aim]]};
		else tr[u]=(node){0,0,0};
		return ;
	}
	int mid=(l+r)/2;
	if(aim<=mid)modify(ls[u],l,mid,aim,k);
	else modify(rs[u],mid+1,r,aim,k);
	update(u);
	return ;
}
void merge(int &u,int a,int b,int l,int r){
	if(!a||!b)return u=a+b,void(0);
	u=a;
	if(l==r){
		sum[u]+=sum[b];
		if(sum[u]>0)tr[u]=(node){dfn[l],dfn[l],dep[dfn[l]]};
		else tr[u]=(node){0,0,0};
		return ;
	}
	int mid=(l+r)/2;
	merge(ls[u],ls[a],ls[b],l,mid);
	merge(rs[u],rs[a],rs[b],mid+1,r);
	update(u);
	return ; 
}
vector<int>del[N];
void work(int u,int fa){
	for(int k=head[u];k;k=fro[k]){
		int x=to[k];
		if(x==fa)continue;
		work(x,u);
	}
	for(int k=head[u];k;k=fro[k]){
		int x=to[k];
		if(x==fa)continue;
		merge(rt[u],rt[u],rt[x],1,n);
	}
	for(int nw:del[u])modify(rt[u],1,n,bel[nw],-1);
//	modify(rt[u],1,n,bel[u],1);
//	ans+=tr[rt[u]].sum-dep[fr[lca.get(tr[rt[u]].l,tr[rt[u]].r)]]-1;
	ansn[u]=max(0,tr[rt[u]].sum-dep[fr[lca.get(tr[rt[u]].l,tr[rt[u]].r)]]-1);
	ans+=ansn[u];
//	cout<<u<<":"<<tr[rt[u]].sum-dep[fr[lca.get((tr[rt[u]].l),(tr[rt[u]].r))]]<<"-"<<tr[rt[u]].l<<" "<<tr[rt[u]].r<<"\n";
//	modify(rt[u],1,n,bel[u],-1);
//	if(u==1)return ;
//	merge(rt[fa],rt[fa],rt[u],1,n);
	return ;
}
signed main(){
//	freopen("language1.in","r",stdin);
//	freopen("ans.out","w",stdout);
	n=rd(),m=rd();
	for(int i=1;i<n;i++){
		int x=rd(),y=rd();
		adlin(x,y),adlin(y,x);
	}
	dfs(1,0),lca.init();
	for(int i=1;i<=m;i++){
		int s=rd(),t=rd(),x=lca.get(s,t);
//		cout<<s<<" "<<t<<"-"<<x<<"\n";
		modify(rt[s],1,n,bel[s],1),modify(rt[s],1,n,bel[t],1);
//		cout<<s<<":"<<tr[rt[s]].sum<<"-"<<tr[rt[s]].l<<" "<<tr[rt[s]].r<<"\n"; 
		modify(rt[t],1,n,bel[t],1),modify(rt[t],1,n,bel[s],1);
		del[x].push_back(t),del[x].push_back(s);
//		modify(rt[x],1,n,bel[s],-1),modify(rt[x],1,n,bel[t],-1);
		if(x!=1)del[fr[x]].push_back(t),del[fr[x]].push_back(s);
//		if(x!=1)modify(rt[fr[x]],1,n,bel[fr[s]],-1),modify(rt[fr[x]],1,n,bel[fr[t]],-1);
	}
	work(1,0);
	printf("%lld\n",ans/2);
//	for(int i=1;i<=n;i++)cout<<i<<":"<<ansn[i]<<"\n";
	return 0;
}