CF241B Friends

发布时间 2023-12-22 19:24:46作者: blueparrot

异或粽子的加强版,时间复杂度是 \(O(n log^2 w)\) ,其中 \(w\) 是值域 \(2^{30}\) ,原来的是和 \(k\) 有关的,相当于是 CF241B 的代码通过不了异或粽子,异或粽子的代码通过不了 CF241B(雾

先考虑一个整体的思路,求前 \(k\) 大,先需要求第 \(k\) 大,第 \(k\) 大直接先二分,然后判断 \(mid\) 的排名,所以子问题是在字典树上判断两两异或值大于等于 \(mid\) 的个数

要求大于等于 \(mid\) ,也就是说二进制下存在一个位置它是 1 且 \(mid\) 是 0 ,它之前的都和 \(mid\) 一样,那直接枚举每个数,然后在字典树上跳,如果 \(mid\) 这位是 1 的话就直接跳,如果是 0 的话,就统计是 1 的答案,然后跳 0 ,统计一下,因为贡献会互相算一次,所以算了两次,要除一下

那么现在求完了第 \(k\) 大,我们要求前 \(k\) 大的 \(a_i\oplus a_j\) 和,直接枚举 \(a_i\) ,现在就是要求大于 \(kth\) ,如果这位 \(kth\) 是 1 就直接跳,如果是 0 就先统计再跳,跟之前差不多,但是统计的时候要统计这样一个东西:以 \(i\) 为根节点的子树内第 \(k\) 位为 1 的数的个数。再统计直接枚举位数就好了,如果这 \(a_i\) 这一位是1的话,要容斥一下,统计的话就是算贡献,乘 \(2^k\) ,很常见的算贡献方式。记得最后要除 2

但是要注意,如果有数和第 \(k\) 大相同就要减掉那些数

#include<bits/stdc++.h>
#define il inline 
#define maxn 50001
using namespace std;
typedef long long ll;
const ll mod=1e9+7;
const ll inv2=500000004ll;
il int read(){
	char c;int x=0,f=0;
	while(!isdigit(c=getchar()))f|=(c=='-');
	while(isdigit(c))x=(x*10)+(c^48),c=getchar();
	return f?-x:x;
}
int n,a[maxn],tot=0;
ll k,kth;
int trie[maxn*30][2],cnt[maxn*30];
int res[maxn*30][31]; // 以i为子树里,第k位是1的数的个数
il void insert(int id){
	int u=0;
	for(int i=30;i>=0;i--){
		int c=(a[id]>>i)&1;
		if(!trie[u][c])trie[u][c]=++tot;
		u=trie[u][c],cnt[u]++;
		for(int k=0;k<=30;k++) 
			if((a[id]>>k)&1)res[u][k]++;
	}
}
il ll check(ll x){
	ll tot=0;
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){	
			int c1=(a[i]>>j)&1,c2=(x>>j)&1;
			if(c2)u=trie[u][c1^1];
			else tot+=cnt[trie[u][c1^1]],u=trie[u][c1];
			if(!u)break;
		}
		tot+=cnt[u];
	}
	return tot/2;
}
ll ans=0;
il void calc(){
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){
			int c1=(a[i]>>j)&1,c2=(kth>>j)&1;
			if(c2)u=trie[u][c1^1];
			else{
				for(int k=0;k<=30;k++){
					int c3=(a[i]>>k)&1;
					if(c3)ans=(ans+(1ll<<k)*(cnt[trie[u][c1^1]]-res[trie[u][c1^1]][k])%mod)%mod;
					else ans=(ans+(1ll<<k)*res[trie[u][c1^1]][k]%mod)%mod;
				}
				u=trie[u][c1];
			}
			if(!u)break;
		}
		ans=(ans+cnt[u]*1ll*kth%mod)%mod;
	}
	ans=(ans*inv2)%mod,ans=((ans-(check(kth)-k)*kth%mod)%mod+mod)%mod;
}
int main(){
	n=read(),k=read();
	for(int i=1;i<=n;i++){
		a[i]=read();
		insert(i);
	}
	ll lt=0,rt=(1ll<<30)+1;
	while(lt+1<rt){
		ll mid=(lt+rt)>>1;
		if(check(mid)>=k)lt=mid;
		else rt=mid;
	}
	kth=lt,calc();
	printf("%lld\n",ans);
	return 0;
}