F. Unique Occurrences(线段树分治+可撤销并查集)

发布时间 2023-11-01 17:10:21作者: gan_coder

F. Unique Occurrences
假如我们删除所有权值为x的边,那么所有权值为x的边对答案的贡献就是
\(\sum sz[u]*sz[v]\) sz表示两个联通块的大小,且(u,v)的边权为x

我们可以用可撤销并查集来进行处理,简单来说就是将一条边的存在时间看作区间,然后挂到线段树上,然后遍历到每个叶子的时候进行计算,需要注意的是,因为并查集需要支持撤销,所以不能路径压缩,而要采用按秩合并。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#define A puts("YES")
#define B puts("NO")
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int mo=998244353;
const int N=5e5+5;
struct node{
	int x,y,z;
};
struct key{
	int x,y,z,op;
};
node a[N];
int n,f[N],b[N],tot,top,f1,f2;
ll sz[N],ans;
pair<int,int> st[N];
int x,y,z;
key k;

vector<key> t[N*4];
int find(int x){
	return x==f[x] ? x:find(f[x]);
}
void merge(int x,int y){
	f1=find(x);
	f2=find(y);
	if (f1==f2) return;
	if (sz[f1]>sz[f2]) swap(f1,f2);
	sz[f2]+=sz[f1];
	f[f1]=f2;
	st[++top]=mk(f1,f2);
}
void undo(){
	f1=st[top].first;
	f2=st[top].second;
	sz[f2]-=sz[f1];
	f[f1]=f1;
	top--;
}
void upd(int o,int l,int r){
	if (x>y) return;
	if (x<=l && r<=y) {
		t[o].push_back(k);
		return;
	}
	int m=(l+r)>>1;
	if (x<=m) upd(lc,l,m);
	if (m<y) upd(rc,m+1,r);
}
void calc(int o,int l,int r){
	int now=top;
	for (auto i:t[o]) {
		if (!i.op) {
			merge(i.x, i.y);
		}	
	}
	
	if (l==r) {
		for (auto i:t[o]) {
			if (i.op) {
				ans+=sz[find(i.x)]*sz[find(i.y)];
			}
		}
		while (top>now) undo();
		return;
	}
	
	int m=(l+r)>>1;
	calc(lc,l,m);
	calc(rc,m+1,r);
	
	while (top>now) undo();
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);

	scanf("%d",&n);
	fo(i,1,n-1) {
		scanf("%d %d %d",&a[i].x, &a[i].y, &a[i].z);
		b[++tot]=a[i].z;
	}
	
	sort(b+1,b+tot+1);
	int m=unique(b+1,b+tot+1)-(b+1);
	fo(i,1,n-1) a[i].z=lower_bound(b+1,b+m+1,a[i].z)-b;
	
	fo(i,1,n) f[i]=i,sz[i]=1;
	
//	fo(i,1,n-1) {
//		printf("%d %d %d\n",a[i].x, a[i].y, a[i].z);
//	}

	fo(i,1,n-1) {
		x=1; y=a[i].z-1; k=(key){a[i].x, a[i].y, a[i].z, 0}; 
		upd(1,1,m);
		
		x=a[i].z; y=a[i].z; k=(key){a[i].x, a[i].y, a[i].z, 1};
		upd(1,1,m);
	
		x=a[i].z+1; y=m; k=(key) {a[i].x, a[i].y, a[i].z, 0};
		upd(1,1,m);
	}
	
	calc(1,1,m);
	printf("%lld",ans);
	return 0;
}