CF1486F Pairs of Paths 总结--zhengjun

发布时间 2023-07-12 20:34:10作者: A_zjzj

需要保持:

  • 写代码前先仔细考虑一下细节,分类讨论清楚再开始码。

警告:

  • namespace 里面写了个 n,想调用全局 n 的时候没加 2*冒号。

思路大概就是分类讨论然后计数就完事了。

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=3e5+10;
int n,m;
vector<int>to[N];
int fa[N],dep[N],bg[N],ed[N];
namespace T{
	int dft,pos[N],siz[N],son[N],top[N];
	void dfs1(int u){
		dep[u]=dep[fa[u]]+1,siz[u]=1;
		for(int v:to[u])if(v^fa[u]){
			fa[v]=u,dfs1(v);
			siz[u]+=siz[v];
			if(siz[v]>siz[son[u]])son[u]=v;
		}
	}
	void dfs2(int u,int t){
		top[u]=t,pos[bg[u]=++dft]=u;
		if(son[u])dfs2(son[u],t);
		for(int v:to[u])if(v^fa[u]&&v^son[u])dfs2(v,v);
		ed[u]=dft;
	}
	void init(){
		dfs1(1),dfs2(1,1);
	}
	int LCA(int u,int v){
		for(;top[u]^top[v];u=fa[top[u]]){
			if(dep[top[u]]<dep[top[v]])swap(u,v);
		}
		return dep[u]<dep[v]?u:v;
	}
	int kth(int u,int k){
		for(;dep[u]-dep[top[u]]<k;u=fa[top[u]])k-=dep[u]-dep[top[u]]+1;
		return pos[bg[u]-k];
	}
}
using T::LCA;
using T::kth;
int cnt[N];
struct path{
	int u,v,t,x,y;
}a[N];
ll ans;
namespace Solve1{
	void dfs(int u,int fa=0){
		cnt[u]+=cnt[fa];
		for(int v:to[u])if(v^fa)dfs(v,u);
	}
	void calc(){
		for(int i=1;i<=n;i++)ans+=cnt[i]*(cnt[i]-1ll)/2;
		dfs(1);
		for(int i=1;i<=m;i++){
			ans+=cnt[a[i].u]+cnt[a[i].v]-cnt[a[i].t]-cnt[fa[a[i].t]];
		}
	}
}
namespace Solve2{
	void dfs(int u,int fa=0){
		for(int v:to[u])if(v^fa){
			dfs(v,u),cnt[u]+=cnt[v];
		}
	}
	void calc(){
		fill(cnt,cnt+1+n,0);
		for(int i=1;i<=m;i++){
			if(a[i].u^a[i].t){
				a[i].x=kth(a[i].u,dep[a[i].u]-dep[a[i].t]-1);
				cnt[a[i].u]++,cnt[a[i].x]--;
			}
			if(a[i].v^a[i].t){
				a[i].y=kth(a[i].v,dep[a[i].v]-dep[a[i].t]-1);
				cnt[a[i].v]++,cnt[a[i].y]--;
			}
		}
		dfs(1);
	}
}
namespace Solve3{
	int tot[N];
	void calc(){
		for(int i=1;i<=m;i++)tot[a[i].t]++;
		for(int i=1;i<=m;i++){
			if(a[i].u^a[i].t)ans+=tot[a[i].u];
			if(a[i].v^a[i].t)ans+=tot[a[i].v];
		}
	}
}
namespace Solve4{
	int n;
	int sum[N];
	int f[N],g[N];
	void calc(){
		for(int i=1;i<=m;i++){
			if(a[i].x>a[i].y)swap(a[i].x,a[i].y),swap(a[i].u,a[i].v);
		}
		sort(a+1,a+1+m,[](path x,path y){
			return x.t^y.t?x.t<y.t:(x.x^y.x?x.x<y.x:x.y<y.y);
		});
		for(int i=1,j;i<=m;i=j){
			for(j=i;j<=m&&a[j].t==a[i].t;j++);
			n=0;
			for(int k=i;k<j;k++){
				if(a[k].u^a[k].t)sum[a[k].x]++;
				if(a[k].v^a[k].t)sum[a[k].y]++;
			}
			ans+=(j-i)*(j-i-1ll)/2;
			int u=a[i].t;
			for(int v:to[u])if(v^fa[u])ans-=sum[v]*(sum[v]-1ll)/2;
		}
		for(int i=1,j;i<=m;i=j){
			for(j=i;j<=m;j++){
				if(a[j].t^a[i].t||a[j].x^a[i].x||a[j].y^a[i].y)break;
			}
			if(a[i].x&&a[i].y)ans+=(j-i)*(j-i-1ll)/2;
		}
		for(int i=1;i<=m;i++){
			f[a[i].t]++;
			if(a[i].u^a[i].t)g[a[i].x]--;
			if(a[i].v^a[i].t)g[a[i].y]--;
		}
		for(int u=2;u<=::n;u++){
			ans+=1ll*cnt[u]*(f[fa[u]]+g[u]);
		}
	}
}
int main(){
	freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
	scanf("%d",&n);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		to[u].push_back(v),to[v].push_back(u);
	}
	T::init();
	int mm;
	scanf("%d",&mm);
	for(int u,v;mm--;){
		scanf("%d%d",&u,&v);
		if(u==v)cnt[u]++;
		else a[++m]={u,v,LCA(u,v)};
	}
	Solve1::calc();
	Solve2::calc();
	Solve3::calc();
	Solve4::calc();
	cout<<ans;
	return 0;
}