ABC306F 题解

发布时间 2023-07-04 18:13:32作者: zzafanti

题目链接

题目大意

对于 \(S_1 \cap S_2 = \emptyset\),

定义长度为 \(|S_1|+|S_2|\) 的序列 \(A\),为 \(S_1\cup S_2\) 排序后的结果。

定义二元函数 \(f(S_1,S_2)=\sum\limits_{1\leq i\leq |S_1|+|S_2|} i\times[A_i\in S_1]\)

给定 \(n\) 个大小为 \(m\) 的正整数集合 \(S\),求 \(\sum\limits_{1\leq i<j\leq n} f(S_i,S_j)\)

\(n\leq 1\times 10^4\)

\(m\leq 1\times 10^2\)

题目分析

注意到答案中要求 \(i<j\)

考虑对于计算每个集合中的每个数的对答案的贡献。

可以发现,元素 \(s\)\(f(S_i,S_j)\) 的贡献就是 \(S_1\cup S_2\) 中小于 \(s\) 的数的个数+1。

所以对于共 \(n\times m\) 个元素,计算它们在每种合法的 \(f(S_i,S_j) (1\leq i<j\leq n)\) 的贡献即可。

具体来说,将所有元素从小到大排序并保存该元素属于第几个集合,从前到后依次扫描这些元素。

设第 \(i\) 个元素属于第 \(p\) 个集合,设前 \(i-1\) 个元素中属于集合 \(k\) 的元素个数有 \(cnt_k\)个,那么第 \(i\) 个元素对答案的贡献就是 \(\sum\limits_{p<j \leq n}{cnt_j+cnt_p+1}\)

这个可以拆成关于 \(\sum\limits_{p<j\leq n} cnt_j\)\(cnt_p\) 的形式。

这两个都可以用树状数组或线段树维护,支持区间求和单点加即可。

就做完了。

时间复杂度 \(\mathcal{O}(nm\log n)\)

场上直接拿线段树写了,树状数组常数更小。

参考代码

#include<bits/stdc++.h>

using namespace std;

template<typename T>
void read(T &x){
	x=0;
	int sgn=0;
	char c=getchar();
	while(!isdigit(c)) sgn|=(c=='-'),c=getchar();
	while(isdigit(c)) x=x*10-'0'+c,c=getchar();
	if(sgn) x=-x;
}

const int N=1000010;

struct segment{
	struct node{
		int l,r;
		long long sum;
	};
	
	node tr[N<<2];
	
	#define ls u<<1
	#define rs u<<1|1
	
	void build(int u,int l,int r){
		tr[u].l=l,tr[u].r=r,tr[u].sum=r-l+1; //init 1,1,1,....
		if(l==r) return ;
		int mid=(l+r)>>1;
		build(ls,l,mid),build(rs,mid+1,r);
	}
	
	void add(int u,int pos){
		int l=tr[u].l,r=tr[u].r;
		if(l==r){
			tr[u].sum++;
			return ;
		}
		int mid=(l+r)>>1;
		if(pos<=mid) add(ls,pos);
		else add(rs,pos);
		tr[u].sum=tr[ls].sum+tr[rs].sum;
	}
	
	long long query(int u,int L,int R){
		int l=tr[u].l,r=tr[u].r;
		if(L<=l&&r<=R) return tr[u].sum;
		int mid=(l+r)>>1;
		long long ret=0;
		if(L<=mid) ret+=query(ls,L,R);
		if(mid<R) ret+=query(rs,L,R);
		return ret;
	}
	
	#undef ls
	#undef rs
}sgt;

int n,m,idx;

pair<int,int> p[N];
long long ans=0;

int main(){
	
	read(n),read(m);
	for(int i=1; i<=n; i++){
		for(int j=1,a; j<=m; j++){
			read(a);
			p[++idx]={a,i};
		}
	}
	
	sort(p+1,p+1+n*m);
	int t=n*m;
	sgt.build(1,1,n);
	for(int i=1; i<=t; i++){
		ans+=sgt.query(1,p[i].second+1,n)+((sgt.query(1,p[i].second,p[i].second)-1)*(n-p[i].second));
//		cout<<ans<<' ';
		sgt.add(1,p[i].second);
	}//cout<<endl;
	
	cout<<ans;
	
	return 0;
}