F. Trees and XOR Queries Again

发布时间 2023-12-06 13:20:26 作者: gan_coder

首先容易想到 LCA + 线性基:倍增预处理每个点向上 \(2^j\) 个点的线性基,查询时把路径拆成若干段合并。复杂度 \(O(n\log n\cdot B^2+q\log n\cdot B^2)\),其中 \(B\) 为值的二进制位数(代码中约 20 位)。显然 T 飞了(TLE)。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over non-negative values that fit in bits 0..20.
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	// Reset to the empty basis.
	void init(){
		memset(a,0,sizeof(a));
	}
	// Insert x: eliminate its high bits against the stored vectors; the first
	// free bit position reached keeps the reduced x as a new basis vector.
	void add(int x){
		for (int i=20;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	bool ask(int x) const {
		for (int i=20;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables, node-major layout ([node][level]).
// g[v][k]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[v][k] (when the jump
//          is clamped at the root, the root's value may be included too).
// ans:     basis accumulated for the current query path.
lb g[N][21],ans;
// f[v][k]: 2^k-th ancestor of v, clamped at root 1 (f[1][0] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[N][21],d[N],a[N];
// Forward-star adjacency list: to[e] is edge e's target, nex[e] the next edge
// out of the same node, head[v] v's first edge (0 = none), tot = edges used.
// n, x, y, k, q are input scratch variables.
int to[N*2],nex[N*2],head[N],tot,n,x,y,k,q;

void merge(lb &x,lb y){
	fo(i,0,20) {
		x.add(y.a[i]);
	}
}
// Push directed edge x -> y onto x's forward-star edge list.
// Edge ids start at 1, so id 0 marks the end of a list.
void add(int x,int y){
	const int id=++tot;
	to[id]=y;
	nex[id]=head[x];
	head[x]=id;
}
// Depth-first traversal from x, whose parent is y (0 for the root call).
// Seeds the level-0 basis g[x][0] with a[x] and records each child's depth
// and parent f[child][0] before descending into it.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[x][0].add(a[x]);
	for (int e=head[x];e;e=nex[e]){
		const int child=to[e];
		if (child!=y){
			d[child]=d[x]+1;
			f[child][0]=x;
			dfs(child,x);
		}
	}
}
void ask(int x,int y){
	if (d[x]<d[y]) swap(x,y);
	fd(k,20,0) {
		if (d[f[x][k]]>=d[y]) {
			merge(ans, g[x][k]);
			x=f[x][k];
		}	
	}
	if (x==y) {
		ans.add(a[x]);
		return;
	}
	fd(k,20,0) {
		if (f[x][k]^f[y][k]) {
			merge(ans, g[x][k]);
			merge(ans, g[y][k]);
			
			x=f[x][k]; y=f[y][k];
		}
	}
	ans.add(a[x]);
	ans.add(a[y]);
	ans.add(a[f[x][0]]);
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);
	

	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	scanf("%d",&n);
	fo(i,1,n) scanf("%d",&a[i]);
	
	fo(i,1,n-1){
		scanf("%d %d",&x,&y);
		add(x,y); add(y,x);
	}
	
	f[1][0]=1;
	g[1][0].add(a[1]);
	dfs(1,0);

	fo(j,1,20) fo(i,1,n) {
		f[i][j]=f[f[i][j-1]][j-1];

		merge(g[i][j], g[i][j-1]);
		merge(g[i][j], g[f[i][j-1]][j-1]);
	}
	
	scanf("%d",&q);
	while (q--){
		scanf("%d %d %d",&x,&y,&k);
		ans.init();
		
		ask(x,y);
		if (ans.ask(k)) A; else B;
	}
	return 0;
	
} 
 
  
 

来点优化:预处理时 g[i][j] 一开始是空的,所以第一次合并(并入 g[i][j-1])完全没必要 merge,直接赋值就行。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over non-negative values that fit in bits 0..20.
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	// Reset to the empty basis.
	void init(){
		memset(a,0,sizeof(a));
	}
	// Insert x by Gaussian elimination from the top bit down.
	void add(int x){
		for (int i=20;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	bool ask(int x) const {
		for (int i=20;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables, node-major layout ([node][level]).
// g[v][k]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[v][k] (when the jump
//          is clamped at the root, the root's value may be included too).
// ans:     basis accumulated for the current query path.
lb g[N][21],ans;
// f[v][k]: 2^k-th ancestor of v, clamped at root 1 (f[1][0] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[N][21],d[N],a[N];
// Forward-star adjacency list (to/nex/head, tot = edges used) and the input
// scratch variables n, x, y, k, q.
int to[N*2],nex[N*2],head[N],tot,n,x,y,k,q;
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. At end of input x becomes 0 instead of the
// reader spinning forever.
void R(int &x){
	int t=0;
	int ch=getchar();   // int, not char: EOF must stay distinguishable
	while (ch!=EOF && !('0'<=ch && ch<='9')) ch=getchar();
	while ('0'<=ch && ch<='9'){
		t=t*10+(ch-'0');
		ch=getchar();
	}
	x=t;
}
void merge(lb &x,lb y){
	fo(i,0,20) {
		x.add(y.a[i]);
	}
}
// Link a new edge x -> y in front of x's forward-star list; ids start at 1.
void add(int x,int y){
	nex[tot+1]=head[x];
	to[tot+1]=y;
	head[x]=++tot;
}
// Visit x (parent y, 0 for the root call): seed g[x][0] with a[x], then give
// each child its depth and parent f[child][0] and recurse into it.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[x][0].add(a[x]);
	for (int e=head[x];e!=0;e=nex[e]){
		int child=to[e];
		if (child==y) continue;   // skip the edge back to the parent
		d[child]=d[x]+1;
		f[child][0]=x;
		dfs(child,x);
	}
}
void ask(int x,int y){
	if (d[x]<d[y]) swap(x,y);
	fd(k,20,0) {
		if (d[f[x][k]]>=d[y]) {
			merge(ans, g[x][k]);
			x=f[x][k];
		}	
	}
	if (x==y) {
		ans.add(a[x]);
		return;
	}
	fd(k,20,0) {
		if (f[x][k]^f[y][k]) {
			merge(ans, g[x][k]);
			merge(ans, g[y][k]);
			
			x=f[x][k]; y=f[y][k];
		}
	}
	ans.add(a[x]);
	ans.add(a[y]);
	ans.add(a[f[x][0]]);
}
int main()
{
//	freopen("data.txt","r",stdin);
//	freopen("ans.out","w",stdout);
	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	R(n);
	fo(i,1,n) R(a[i]);
	
	fo(i,1,n-1){
		R(x); R(y);
		add(x,y); add(y,x);
	}
	
	f[1][0]=1;
	g[1][0].add(a[1]);
	dfs(1,0);

	fo(j,1,20) fo(i,1,n) {
		f[i][j]=f[f[i][j-1]][j-1];
		
		fo(k,0,20) g[i][j].a[k]=g[i][j-1].a[k];
		
//		merge(g[i][j], g[i][j-1]);
		merge(g[i][j], g[f[i][j-1]][j-1]);
	}
	
//	return 0;
	
	R(q);
	while (q--){
		R(x); R(y); R(k);
		ans.init();
		
		ask(x,y);
		if (ans.ask(k)) A; else B;
	}
	return 0;
	
} 
 
  
 

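还是 T,再来点优化。
想起之前动物园那题:一开始写的是倍增,也 T 了。看到别的大佬说可以交换数组维度的顺序,让内层循环尽量连续访问内存,交换之后就过了。

试着改一下:把 f[N][21]、g[N][21] 改成 f[19][N]、g[19][N],预处理时外层枚举层数 j,内层枚举结点 i。
同时,原来先将 x 跳到与 y 同高,再让 x、y 一起跳,要合并约 3k 次(k 为倍增层数)。改成先求出 LCA,再让 x、y 分别跳到 LCA,只要约 2k 次。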

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
// Loop helpers: fo runs i upward from a to b, fd runs i downward from b to a
// (both bounds inclusive). `register` is dropped: C++17 removed it as a
// storage-class specifier, and clang rejects it.
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over values below 2^20 (bits 19..0 are used; a[20] stays 0).
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	// Reset to the empty basis.
	void init(){
		memset(a,0,sizeof(a));
	}
	// Insert x by eliminating from the top bit down; the first free slot
	// reached keeps the reduced x.
	inline void add(int x){
		for (int i=19;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	bool ask(int x) const {
		for (int i=19;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables in level-major layout ([level][node]): for a fixed
// level, the table-building loop over nodes walks memory contiguously.
// g[k][v]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[k][v].
// ans:     basis accumulated for the current query path.
lb g[19][N],ans;
// f[k][v]: 2^k-th ancestor of v, clamped at root 1 (f[0][1] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[19][N],d[N],a[N];
// Forward-star adjacency list (to/nex/head, tot = edges used) and the input
// scratch variables n, x, y, k, q.
int to[N*2],nex[N*2],head[N],tot,n,x,y,k,q;
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. At end of input x becomes 0 instead of the
// reader spinning forever.
void R(int &x){
	int t=0;
	int ch=getchar();   // int, not char: EOF must stay distinguishable
	while (ch!=EOF && !('0'<=ch && ch<='9')) ch=getchar();
	while ('0'<=ch && ch<='9'){
		t=t*10+(ch-'0');
		ch=getchar();
	}
	x=t;
}
inline void merge(lb &x,lb y){
	fo(i,0,19) {
		x.add(y.a[i]);
	}
}
// Prepend edge x -> y to x's forward-star list (edge ids start at 1).
inline void add(int x,int y){
	tot++;
	to[tot]=y;
	nex[tot]=head[x];
	head[x]=tot;
}
// Root the tree at x (parent y; 0 for the root call): seed g[0][x] with a[x],
// then give every child its parent f[0][child] and depth before recursing.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[0][x].add(a[x]);
	for (int e=head[x];e!=0;e=nex[e]){
		int child=to[e];
		if (child==y) continue;   // don't walk back to the parent
		f[0][child]=x;
		d[child]=d[x]+1;
		dfs(child,x);
	}
}
// Lowest common ancestor of x and y by binary lifting over levels 0..18.
int lca(int x,int y){
	int u=x,v=y;
	if (d[u]<d[v]) swap(u,v);            // u is the deeper node
	// Raise u by exactly d[u]-d[v] levels, one set bit of the gap at a time.
	for (int gap=d[u]-d[v],k=0;gap>0;gap>>=1,k++)
		if (gap&1) u=f[k][u];
	if (u==v) return u;
	// Raise both while their 2^k-th ancestors still differ.
	for (int k=18;k>=0;k--)
		if (f[k][u]!=f[k][v]){ u=f[k][u]; v=f[k][v]; }
	return f[0][u];
}
void ask(int x,int y){
	int l=lca(x,y);
	
	fd(k,18,0) {
		if (d[f[k][x]]>=d[l]) {
			merge(ans, g[k][x]);
			x=f[k][x];
		}
	}
	
	fd(k,18,0) {
		if (d[f[k][y]]>=d[l]) {
			merge(ans, g[k][y]);
			y=f[k][y];
		}
	}
	ans.add(a[l]);
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);
	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	R(n);
	fo(i,1,n) R(a[i]);
	
	fo(i,1,n-1){
		R(x); R(y);
		add(x,y); add(y,x);
	}
	
	f[0][1]=1;
	g[0][0].add(a[1]);
	dfs(1,0);
	
	fo(j,1,18) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
//	return 0;
	
	R(q);
	while (q--){
		R(x); R(y); R(k);
		ans.init();
		
		ask(x,y);
		if (ans.ask(k)) A; else B;
	}
	return 0;
	
} 
 
  
 

还可以继续优化:当 j 足够大、f[j-1][i] 已经是根时,后半段只剩根自己,就没必要再合并了,直接赋值就行(根的值在查询时会通过 a[lca] 单独加入)。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
// Loop helpers: fo runs i upward from a to b, fd runs i downward from b to a
// (both bounds inclusive). `register` is dropped: C++17 removed it as a
// storage-class specifier, and clang rejects it.
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over values below 2^20 (bits 19..0 are used; a[20] stays 0).
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	// Reset to the empty basis.
	void init(){
		memset(a,0,sizeof(a));
	}
	// Insert x by eliminating from the top bit down; the first free slot
	// reached keeps the reduced x.
	inline void add(int x){
		for (int i=19;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	bool ask(int x) const {
		for (int i=19;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables in level-major layout ([level][node]).
// g[k][v]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[k][v] (the table
//          builder may leave out the root when a jump is clamped there).
// ans:     basis accumulated for the current query path.
lb g[19][N],ans;
// f[k][v]: 2^k-th ancestor of v, clamped at root 1 (f[0][1] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[19][N],d[N],a[N];
// Forward-star adjacency list (to/nex/head, tot = edges used) and the input
// scratch variables n, x, y, k, q.
int to[N*2],nex[N*2],head[N],tot,n,x,y,k,q;
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. At end of input x becomes 0 instead of the
// reader spinning forever.
void R(int &x){
	int t=0;
	int ch=getchar();   // int, not char: EOF must stay distinguishable
	while (ch!=EOF && !('0'<=ch && ch<='9')) ch=getchar();
	while ('0'<=ch && ch<='9'){
		t=t*10+(ch-'0');
		ch=getchar();
	}
	x=t;
}
inline void merge(lb &x,lb y){
	fo(i,0,19) {
		x.add(y.a[i]);
	}
}
// Prepend edge x -> y to x's forward-star list; id 0 terminates a list.
inline void add(int x,int y){
	const int id=++tot;
	nex[id]=head[x];
	to[id]=y;
	head[x]=id;
}
// Visit x (parent y, 0 for the root call): seed g[0][x] with a[x], then record
// each child's depth and parent f[0][child] and descend into it.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[0][x].add(a[x]);
	for (int e=head[x];e;e=nex[e]){
		const int child=to[e];
		if (child!=y){
			d[child]=d[x]+1;
			f[0][child]=x;
			dfs(child,x);
		}
	}
}
// Lowest common ancestor of x and y using the level-major lifting table f.
int lca(int x,int y){
	int u=x,v=y;
	if (d[u]<d[v]) swap(u,v);
	for (int k=18;k>=0;k--){
		const int up=f[k][u];
		if (d[up]>=d[v]) u=up;           // jump while staying at or below v's depth
	}
	if (u==v) return u;                  // v was an ancestor of u
	for (int k=18;k>=0;k--){
		if (f[k][u]==f[k][v]) continue;  // would land on a common ancestor
		u=f[k][u];
		v=f[k][v];
	}
	return f[0][u];                      // u and v are children of the LCA
}
void ask(int x,int y){
	int l=lca(x,y);
	
	fd(k,18,0) {
		if (d[f[k][x]]>=d[l]) {
			merge(ans, g[k][x]);
			x=f[k][x];
		}
	}
	
	fd(k,18,0) {
		if (d[f[k][y]]>=d[l]) {
			merge(ans, g[k][y]);
			y=f[k][y];
		}
	}
	ans.add(a[l]);
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);
	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	R(n);
	fo(i,1,n) R(a[i]);
	
	fo(i,1,n-1){
		R(x); R(y);
		add(x,y); add(y,x);
	}
	
	f[0][1]=1;
	g[0][0].add(a[1]);
	dfs(1,0);
	
	fo(j,1,18) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
//	return 0;
	
	R(q);
	while (q--){
		R(x); R(y); R(k);
		ans.init();
		
		ask(x,y);
		if (ans.ask(k)) A; else B;
	}
	return 0;
	
} 
 
  

还能优化:查询时一旦 ans 已经能表示待查的值(代码里改名为 z)就直接退出。另外给线性基记录已有向量个数 t,满 20 个时 merge 也提前结束。对于答案大部分为 YES 的数据,加速明显。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
// Loop helpers: fo runs i upward from a to b, fd runs i downward from b to a
// (both bounds inclusive). `register` is dropped: C++17 removed it as a
// storage-class specifier, and clang rejects it.
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over values below 2^20 (bits 19..0 are used; a[20] stays 0).
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	int t;   // number of stored vectors; 20 means every 20-bit value is representable
	// Reset to the empty basis.
	void init(){
		t=0;
		memset(a,0,sizeof(a));
	}
	// Insert x by eliminating from the top bit down; the first free slot
	// reached keeps the reduced x and bumps the counter.
	inline void add(int x){
		for (int i=19;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; t++; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	bool ask(int x) const {
		for (int i=19;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables in level-major layout ([level][node]).
// g[k][v]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[k][v] (the table
//          builder may leave out the root when a jump is clamped there).
// ans:     basis accumulated for the current query path.
lb g[19][N],ans;
// f[k][v]: 2^k-th ancestor of v, clamped at root 1 (f[0][1] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[19][N],d[N],a[N];
// Forward-star adjacency list (to/nex/head, tot = edges used) and the input
// scratch variables; z is the value tested by the current query.
int to[N*2],nex[N*2],head[N],tot,n,x,y,z,q;
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. At end of input x becomes 0 instead of the
// reader spinning forever.
void R(int &x){
	int t=0;
	int ch=getchar();   // int, not char: EOF must stay distinguishable
	while (ch!=EOF && !('0'<=ch && ch<='9')) ch=getchar();
	while ('0'<=ch && ch<='9'){
		t=t*10+(ch-'0');
		ch=getchar();
	}
	x=t;
}
inline void merge(lb &x,lb y){
	fo(i,0,19) {
		if (x.t==20) break;
		x.add(y.a[i]);
	}
}
// Thread a new edge x -> y onto the front of x's forward-star list.
inline void add(int x,int y){
	nex[tot+1]=head[x];
	to[tot+1]=y;
	head[x]=++tot;
}
// Root the tree at x (parent y; 0 for the root call): seed g[0][x] with a[x],
// then give every child its parent and depth and recurse into it.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[0][x].add(a[x]);
	for (int e=head[x];e!=0;e=nex[e]){
		int child=to[e];
		if (child==y) continue;   // don't walk back to the parent
		f[0][child]=x;
		d[child]=d[x]+1;
		dfs(child,x);
	}
}
int lca(int x,int y){
	if (d[x]<d[y]) swap(x,y);
	fd(k,18,0) {
		if (d[f[k][x]]>=d[y]) x=f[k][x];
	}
	if (x==y) return x;
	fd(k,18,0) {
		if (f[k][x]^f[k][y]) x=f[k][x],y=f[k][y];
	}
	return f[0][x];
}
void ask(int x,int y){
	int l=lca(x,y);
	
	fd(k,18,0) {
		if (d[f[k][x]]>=d[l]) {
			merge(ans, g[k][x]);
			x=f[k][x];
			if (ans.ask(z)) return;
		}
	}
	
	fd(k,18,0) {
		if (d[f[k][y]]>=d[l]) {
			merge(ans, g[k][y]);
			y=f[k][y];
			if (ans.ask(z)) return;
		}
	}
	ans.add(a[l]);
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);
	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	R(n);
	fo(i,1,n) R(a[i]);
	
	fo(i,1,n-1){
		R(x); R(y);
		add(x,y); add(y,x);
	}
	
	f[0][1]=1;
	g[0][0].add(a[1]);
	dfs(1,0);
	
	fo(j,1,18) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
//	return 0;
	
	R(q);
	while (q--){
		R(x); R(y); R(z);
		ans.init();
		
		ask(x,y);
		if (ans.ask(z)) A; else B;
	}
	return 0;
	
} 
 
  
 

怎么办?
别忘了,我们还有终极手段:循环展开。(严格来说,下面的代码只是把 j 的循环拆成了 1~4、5~8、9~12、13~16、17~18 五段,每段循环体完全相同,执行的语句和顺序与拆分前一致。)

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<vector>
#include<set>
#include<iostream>
#include<queue>
#include<ctime>
#define A puts("YES")
#define B puts("NO")
//#define A puts("Yes")
//#define B puts("No")
// Loop helpers: fo runs i upward from a to b, fd runs i downward from b to a
// (both bounds inclusive). `register` is dropped: C++17 removed it as a
// storage-class specifier, and clang rejects it.
#define fo(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define fd(i,b,a) for (int (i)=(b);(i)>=(a);(i)--)
#define mk(x,y) make_pair((x),(y))
#define lc (o<<1)
#define rc (o<<1|1)
using namespace std;
//typedef __int128 i128;
typedef double db;
typedef long long ll;
const int N=2e5+5;
const ll inf=1ll<<60;
// Powers of two, filled in by main(). The basis below tests bits directly and
// no longer depends on this table having been initialized first.
int b[30];
// XOR linear basis over values below 2^20 (bits 19..0 are used; a[20] stays 0).
// a[i] is either 0 or a stored vector whose highest set bit is i.
struct lb{
	int a[21];
	int t;   // number of stored vectors; 20 means every 20-bit value is representable
	// Reset to the empty basis.
	void init(){
		t=0;
		memset(a,0,sizeof(a));
	}
	// Insert x by eliminating from the top bit down; the first free slot
	// reached keeps the reduced x and bumps the counter.
	inline void add(int x){
		for (int i=19;i>=0;i--){
			if (!((x>>i)&1)) continue;
			if (a[i]) x^=a[i];
			else {
				a[i]=x; t++; break;
			}
		}
	}
	// True iff x is the XOR of some subset of the inserted values.
	inline bool ask(int x) const {
		for (int i=19;i>=0;i--) {
			if ((x>>i)&1) x^=a[i];
		}
		return x==0;
	}
};
// Binary-lifting tables in level-major layout ([level][node]).
// g[k][v]: XOR basis of values on the path from v toward the root; it contains
//          at least v and its ancestors strictly below f[k][v] (the table
//          builder may leave out the root when a jump is clamped there).
// ans:     basis accumulated for the current query path.
lb g[19][N],ans;
// f[k][v]: 2^k-th ancestor of v, clamped at root 1 (f[0][1] == 1);
// d[v]: depth (root = 0); a[v]: value written on node v.
int f[19][N],d[N],a[N];
// Forward-star adjacency list (to/nex/head, tot = edges used) and the input
// scratch variables; z is the value tested by the current query.
int to[N*2],nex[N*2],head[N],tot,n,x,y,z,q;
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. At end of input x becomes 0 instead of the
// reader spinning forever.
inline void R(int &x){
	int t=0;
	int ch=getchar();   // int, not char: EOF must stay distinguishable
	while (ch!=EOF && !('0'<=ch && ch<='9')) ch=getchar();
	while ('0'<=ch && ch<='9'){
		t=t*10+(ch-'0');
		ch=getchar();
	}
	x=t;
}
inline void merge(lb &x,lb y){
	fo(i,0,19) {
		if (x.t==20) break;
		x.add(y.a[i]);
	}
}
// Prepend edge x -> y to x's forward-star list (edge ids start at 1).
inline void add(int x,int y){
	const int id=++tot;
	to[id]=y;
	nex[id]=head[x];
	head[x]=id;
}
// Root the tree at x (parent y; 0 for the root call): seed g[0][x] with a[x],
// then set each child's parent and depth and recurse into it.
// NOTE(review): recursion depth equals the tree height, up to n-1 frames.
void dfs(int x,int y){
	g[0][x].add(a[x]);
	// `register` removed from the loop variable: it is no longer part of C++17.
	for (int i=head[x];i;i=nex[i]){
		int v=to[i];
		if (v==y) continue;
		f[0][v]=x;
		d[v]=d[x]+1;
		dfs(v,x);
	}
}
// Lowest common ancestor of x and y using the level-major lifting table f.
int lca(int x,int y){
	int u=x,v=y;
	if (d[u]<d[v]) swap(u,v);
	for (int k=18;k>=0;k--){
		const int up=f[k][u];
		if (d[up]>=d[v]) u=up;           // jump while staying at or below v's depth
	}
	if (u==v) return u;                  // v was an ancestor of u
	for (int k=18;k>=0;k--){
		if (f[k][u]==f[k][v]) continue;  // would land on a common ancestor
		u=f[k][u];
		v=f[k][v];
	}
	return f[0][u];                      // u and v are children of the LCA
}
void ask(int x,int y){
	int l=lca(x,y);
	
	fd(k,18,0) {
		if (d[f[k][x]]>=d[l]) {
			merge(ans, g[k][x]);
			x=f[k][x];
			if (ans.ask(z)) return;
		}
	}
	
	fd(k,18,0) {
		if (d[f[k][y]]>=d[l]) {
			merge(ans, g[k][y]);
			y=f[k][y];
			if (ans.ask(z)) return;
		}
	}
	ans.add(a[l]);
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("ans.out","w",stdout);
	b[0]=1;
	fo(i,1,20) b[i]=b[i-1]*2;
	
	R(n);
	fo(i,1,n) R(a[i]);
	
	fo(i,1,n-1){
		R(x); R(y);
		add(x,y); add(y,x);
	}
	
	f[0][1]=1;
	g[0][0].add(a[1]);
	dfs(1,0);
	
	fo(j,1,4) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
	fo(j,5,8) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
	fo(j,9,12) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
	fo(j,13,16) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
	fo(j,17,18) fo(i,1,n) {
		f[j][i]=f[j-1][f[j-1][i]];
		
		if (f[j-1][i]==1) {
			g[j][i]=g[j-1][i];
			continue;
		}
		fo(k,0,19) g[j][i].a[k]=g[j-1][i].a[k];
		merge(g[j][i], g[j-1][f[j-1][i]]);
	}
	
	
//	return 0;
	
	R(q);
	while (q--){
		R(x); R(y); R(z);
		ans.init();
		
		ask(x,y);
		if (ans.ask(z)) A; else B;
	}
	return 0;
	
} 
 
  
 


ohhhhhh!

正解待补。