GDKOI2023题目乱写

发布时间 2023-10-12 15:30:52作者: _hjy

Day 1:

T1:直接随机一个向量 \(v\),计算 \(A \times B \times v\)\(C \times v\) ,判断是否相等,时间复杂度为 \(O(n^2)\)

正确性是可以保证的,不过我不是很会证明。

#include<iostream>
using namespace std;
const int MX=3005;
const int MOD=998244353;
int A[MX][MX],B[MX][MX],C[MX][MX];
int v[MX],v1[MX],v2[MX];
int read(){
	int x=0,ch=getchar();
	while(!(ch>=48&&ch<=57))ch=getchar();
	while(ch>=48&&ch<=57)x=x*10+ch-48,ch=getchar();
	return x;
}
int main(){
	freopen("matrix.in","r",stdin);
	freopen("matrix.out","w",stdout);
	srand(132784659);
	int T=read();
	while(T--){
		int n=read();
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				A[i][j]=read();
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				B[i][j]=read();
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				C[i][j]=read();
		for(int i=1;i<=n;i++)v[i]=rand()%MOD;
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				v1[i]=(v1[i]+1ll*v[j]*B[i][j])%MOD;
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				v2[i]=(v2[i]+1ll*v1[j]*A[i][j])%MOD;
		for(int i=1;i<=n;i++)v1[i]=0;
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				v1[i]=(v1[i]+1ll*v[j]*C[i][j])%MOD;
		bool flag=false;
		for(int i=1;i<=n;i++)
			if(v1[i]!=v2[i]){
				flag=true;
				break;
			}
		if(flag){
			putchar('N');
			putchar('o');
		}
		else{
			putchar('Y');
			putchar('e');
			putchar('s');
		}
		putchar('\n');
		for(int i=1;i<=n;i++)v[i]=v1[i]=v2[i]=0;
	}
	return 0;
}

T2:

\(F_{m,n}\) 为前 \(m\) 个无限制,后 \(n\) 个错排的方案数,原答案为 \(A_{m}^{n-m} \times F_{m,n-2m}\)

有以下两个式子

\[F_{m,n}=F_{m-1,n}+F_{m-1,n+1} \]

考虑加入一个无限制的点,可以和一个错排的点交换,减少一个错排限制,或者直接任意放置。

\[F_{m,n}=(m+n-1)F_{m,n-1}+(n-2)F_{m,n-2} \]

考虑最后一个错排元素,假如加入它之前已经是一个满足要求的序列,则让它与任意元素交换,有 \(m+n-1\) 种,原来的方案数为 \(F_{m,n-1}\)

若加入前并不是满足要求的错排,那么最多只有 \(1\) 个点在自己的位置上,让新加入的点与它交换。注意在自己位置上的点一定受错排限制,有 \(n-1\) 种选择,而原来的序列为长度为 \(n-2\) 的错排,有 \(F_{m,n-2}\) 种方案。

直接使用以上公式可以得到 \(60\) 分。对于满分做法,考虑分块。对 \(m\) 进行分块,设块长为 \(B\)

对于每个块,我们可以用 \(O(n)\) 的时间算出最小的的 \(m\)\(F_{m,i}\),对于比它大的 \(m\) ,我们可以使用第一个公式发现贡献是一个组合数,预处理出组合数然后乘上贡献即可。

复杂度分析:

对于每个块,首先要进行一次 \(O(n+B)\) 的预处理,这部分复杂度为 \(O(\frac{n^2}{B}+n)\)

而计算贡献时,容易发现每个查询复杂度是 \(O(B)\) 的,这部分复杂的为 \(O(qB)\)

由于 \(n,q\) 同阶,所以当 \(B\)\(O(\sqrt{n})\) 是复杂度最优,为 \(O(n \sqrt{n} + q \sqrt{n})\)

#include<iostream>
#include<vector>
using namespace std;

#define ll long long

const int B=500,MX=2e5;
const int MOD=998244353;

ll fact[MX+B+7],inv[MX+B+7];
ll F[MX+B+7];
ll g[B+7][B+7];

struct query{
	int m,n,t;
};

vector<query > querys[B+7]; 
int border[B+7];
ll ans[MX+7],conf[MX+7];

ll ksm(ll a,ll b){
	ll ret=1;
	while(b){
		if(b&1)ret=ret*a%MOD;
		a=a*a%MOD;
		b>>=1;
	}
	return ret;
}

void init(int N){
	fact[0]=inv[0]=1;
	for(int i=1;i<=N;i++)fact[i]=fact[i-1]*i%MOD;
	inv[N]=ksm(fact[N],MOD-2);
	for(int i=N-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
}

ll A(int x,int y){
	if(y>x)return 0;
	return fact[x]*inv[x-y]%MOD;
}

ll C(int x,int y){
	if(y>x)return 0;
	return fact[x]*inv[y]%MOD*inv[x-y]%MOD;
}

int main(){
	freopen("qwq.in","r",stdin);
	freopen("qwq.out","w",stdout);
	ios::sync_with_stdio(false);
	int T;
	cin>>T;
	init(MX+B);
	for(int i=1;i<=T;i++){
		int n,m;
		cin>>n>>m;
		conf[i]=A(n-m,m);
		querys[m/B].push_back((query){m,n-2*m,i});
		border[m/B]=max(border[m/B],n-2*m);
		//cout<<m/B<<endl;
	}
	for(int i=0;i<B;i++){
		if(querys[i].size()==0)continue;
		int m=i*B;
		int r=border[i]+B;
		//cout<<"i: "<<i<<" m: "<<m<<" r: "<<r<<endl;
		F[0]=fact[m];
		F[1]=fact[m]*m%MOD;
		for(int j=2;j<=r;j++)
			F[j]=((m+j-1)*F[j-1]%MOD+(j-1)*F[j-2]%MOD)%MOD;
		for(int j=0;j<querys[i].size();j++){
			int n=querys[i][j].n,km=querys[i][j].m,t=querys[i][j].t;
			//cout<<"ans "<<t<<": F("<<km-m<<","<<n<<")"<<endl;
			ans[t]=0;
			for(int k=n;k<=n+km-m;k++)
				ans[t]=(ans[t]+C(km-m,k-n)*F[k]%MOD)%MOD;
			//cout<<ans[t]<<' '<<conf[t]<<endl;
		}
	}
	for(int i=1;i<=T;i++)
		cout<<ans[i]*conf[i]%MOD<<endl;
	return 0;
}

T3:

还没写代码,所以先乱写。

如果没有边的限制就直接数位dp。

如果有就直接大力容斥。

DAY 2:

T1:

考虑中间那个点,往三个方向延伸,然后你去记录最长的三个,整个问题被转化为一个三维偏序。

直接做是 \(O(n \log^2 n)\)的,考虑你不需要知道有多少个,所以可以直接对第一维排序,第二维用线段树记录第三维的最值。

时间复杂度 \(O((n+q )\log n)\)

代码比较难写。

code:

#include<bits/stdc++.h>
using namespace std;
#define lc x<<1
#define rc x<<1|1
#define pii pair<int,int >
const int MX=2e5+7;
vector<int > g[MX];
int read(){
	int x=0,ch=getchar();
	while(!(ch>=48&&ch<=57))ch=getchar();
	while(ch>=48&&ch<=57)x=x*10+ch-48,ch=getchar();
	return x;
}
void print(int x){
	if(x>9)print(x/10);
	putchar(x%10+48);
}
void print3(int x,int y,int z){
	print(x);
	putchar(' ');
	print(y);
	putchar(' ');
	print(z);
	putchar('\n');
}
struct info{
	int x,p,ed,v;
};
vector<info > sorted[MX];
int a[MX],b[MX],c[MX];
int eda[MX],edb[MX],edc[MX];
int jump[MX][25],dep[MX];
int rra[MX],rrb[MX],rrc[MX];
int rx[MX],ry[MX],rz[MX];
int n,q;
bool cmp(info x,info y){
	return x.v>y.v;
}
void dfs1(int x,int fa){
	jump[x][0]=fa;
	dep[x]=dep[fa]+1;
	for(int i=0;i<g[x].size();i++){
		int to=g[x][i];
		if(to==fa)continue;
		dfs1(to,x);
		info tmp;
		tmp.p=to;
		tmp.x=x;
		if(sorted[to].size()!=0){
			tmp.v=sorted[to][0].v+1;
			tmp.ed=sorted[to][0].ed;
		}
		else {
			tmp.v=1;
			tmp.ed=to;
		}
		sorted[x].push_back(tmp);
	}
	sort(sorted[x].begin(),sorted[x].end(),cmp);
}
void updabc(int x){
	int tot=sorted[x].size();
	if(tot>=1){
		a[x]=sorted[x][0].v;
		eda[x]=sorted[x][0].ed;
	}
	if(tot>=2){
		b[x]=sorted[x][1].v;
		edb[x]=sorted[x][1].ed;
	}
	if(tot>=3){
		c[x]=sorted[x][2].v;
		edc[x]=sorted[x][2].ed;
	}
}
void dfs2(int x,int fa){
	for(int i=0;i<g[x].size();i++){
		int to=g[x][i];
		if(to==fa)continue;
		if(sorted[x].size()!=1){
			info tmp;
			if(sorted[x][0].p==to)tmp=sorted[x][1];
			else tmp=sorted[x][0];
			tmp.x=to;
			tmp.p=x;
			tmp.v++;
			sorted[to].push_back(tmp);
		}
		else{
			info tmp;
			tmp.x=to;
			tmp.p=x;
			tmp.v=1;
			tmp.ed=x;
			sorted[to].push_back(tmp);
		}
		sort(sorted[to].begin(),sorted[to].end(),cmp);
		updabc(to);
		dfs2(to,x);
	}
}
void init(){
	for(int i=1;i<n;i++){
		int x,y;
		x=read();
		y=read();
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs1(1,1);
	updabc(1);
	dfs2(1,1);
	for(int i=1;i<=20;i++)
		for(int j=1;j<=n;j++)
			jump[j][i]=jump[jump[j][i-1]][i-1];
}
struct points{
	int a,b,c,type,t;
}p[MX<<1];
bool cmp2(points x,points y){
	if(x.a!=y.a)return x.a>y.a;
	else return x.type<y.type;
}
struct node{
	int l,r;
	pii data;
}s[MX<<2];
void upd(int x){
	s[x].data=max(s[lc].data,s[rc].data);
}
void build(int x,int l,int r){
	s[x].l=l;
	s[x].r=r;
	if(l==r){
		s[x].data=make_pair(0,0);
		return;
	}	
	int mid=l+r>>1;
	build(lc,l,mid);
	build(rc,mid+1,r);
	upd(x);
}
void modify(int x,int k,int v,int id){
	if(s[x].l==s[x].r){
		s[x].data.second=id;
		s[x].data.first=v;
		return;
	}
	int mid=s[x].l+s[x].r>>1;
	if(k<=mid)modify(lc,k,v,id);
	else modify(rc,k,v,id);
	upd(x);
}
int query1(int x,int k){
	if(s[x].l==s[x].r)return s[x].data.first;
	int mid=s[x].l+s[x].r>>1;
	if(k<=mid)return query1(lc,k);
	else return query1(rc,k);
}
pii query2(int x,int l,int r){
	//cout<<s[x].l<<' '<<s[x].r<<' '<<l<<' '<<r<<endl;
	if(s[x].l>=l&&s[x].r<=r)return s[x].data;
	int mid=s[x].l+s[x].r>>1;
	pii ans=make_pair(0,0);
	if(l<=mid)ans=max(ans,query2(lc,l,r));
	if(r>mid)ans=max(ans,query2(rc,l,r));
	return ans;
}
int ans[MX];
int grand(int x,int k){
	int cnt=0;
	while(k){
		if(k&1)x=jump[x][cnt];
		cnt++;
		k>>=1;
	}
	return x;
}
int lca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	x=grand(x,dep[x]-dep[y]);
	if(x==y)return x;
	for(int i=20;i>=0;i--)
		if(jump[x][i]!=jump[y][i])
			x=jump[x][i],y=jump[y][i];
	return jump[x][0];
}
int link(int x,int y,int t){
	int LCA=lca(x,y);
	int d1=dep[x]-dep[LCA],d2=dep[y]-dep[LCA];
	if(LCA==x)return grand(y,d2-t);
	if(LCA==y)return grand(x,t);
	if(t<=d1)return grand(x,t);
	return grand(y,d1+d2-t);
}
int main(){
	ios::sync_with_stdio(false);
	freopen("game.in","r",stdin);
	freopen("game.out","w",stdout);
	n=read();
	init();
	//for(int i=1;i<=n;i++)cout<<a[i]<<' '<<b[i]<<' '<<c[i]<<endl;
	//cout<<"----------"<<endl;
	for(int i=1;i<=n;i++){
		p[i].a=a[i];
		p[i].b=b[i];
		p[i].c=c[i];
		p[i].type=0;
		p[i].t=i;
	}
	q=read();
	for(int i=1;i<=q;i++){
		int x,y,z,sum;
		x=read();
		y=read();
		z=read();
		rx[i]=x;
		ry[i]=y;
		rz[i]=z;
		sum=(x+y+z)/2;
		rra[i]=p[n+i].a=sum-min(x,min(y,z));
		rrc[i]=p[n+i].c=sum-max(x,max(y,z));
		rrb[i]=p[n+i].b=sum-p[n+i].a-p[n+i].c;
		p[n+i].type=1;
		p[n+i].t=i;
		//cout<<rra[i]<<' '<<rrb[i]<<' '<<rrc[i]<<endl;
	}
	//cout<<"----------"<<endl;
	sort(p+1,p+1+n+q,cmp2);
	build(1,0,n);
	for(int i=1;i<=n+q;i++){
		if(p[i].type==0){
			int mxb=query1(1,p[i].b);
			if(p[i].c>=mxb){
				//cout<<"add "<<p[i].a<<' '<<p[i].b<<' '<<p[i].c<<endl;
				modify(1,p[i].b,p[i].c,p[i].t);
			}
		}
		else{
			pii tmp=query2(1,p[i].b,n);
			//cout<<"answer "<<p[i].a<<' '<<p[i].b<<' '<<p[i].c<<' '<<tmp.first<<' '<<tmp.second<<endl;
			if(tmp.first<p[i].c)ans[p[i].t]=-1;
			else ans[p[i].t]=tmp.second;
		}
	}
	for(int i=1;i<=q;i++){
		int core=ans[i];
		//cout<<core<<endl;
		//cout<<eda[core]<<' '<<edb[core]<<' '<<edc[core]<<endl;
		if(rra[i]+rrb[i]==rx[i]){
			if(rra[i]+rrc[i]==ry[i])
				print3(link(core,eda[core],rra[i]),link(core,edb[core],rrb[i]),link(core,edc[core],rrc[i]));
			else
				print3(link(core,edb[core],rrb[i]),link(core,eda[core],rra[i]),link(core,edc[core],rrc[i]));
		}
		else if(rra[i]+rrb[i]==ry[i]){
			if(rra[i]+rrc[i]==rx[i])
				print3(link(core,eda[core],rra[i]),link(core,edc[core],rrc[i]),link(core,edb[core],rrb[i]));
			else
				print3(link(core,edb[core],rrb[i]),link(core,edc[core],rrc[i]),link(core,eda[core],rra[i]));
		}
		else{
			if(rra[i]+rrc[i]==rx[i])
				print3(link(core,edc[core],rrc[i]),link(core,eda[core],rra[i]),link(core,edb[core],rrb[i]));
			else
				print3(link(core,edc[core],rrc[i]),link(core,edb[core],rrb[i]),link(core,eda[core],rra[i]));
		}
	}
	return 0;
}

T2:不会,据说要用BM,部分分都要FWT

T3:
先考虑如何做一条链。\(O(k)\) 做法就是直接模拟题意思,考虑倍增。

\(F(j,i)\) 表示由 \(i\) 开始长度为 \(2^j\) 的答案,在合并时我们发现后面一半每个元素异或的系数不太对,会少一个 \(2^j\)

例子:

1 2 3 4

5 6 7 8

不过发现由于后一半的系数二进制位都小于 \(j\) ,所以可以直接把加法转为异或,然后就变成了\(2^j \oplus i \oplus v_i\),而只有 \(v_i\) 在第 \(j\) 位的取值会影响这个结果,所以把它的前缀和预处理出来,然后倍增即可。

考虑如何上树。对于每个点,我们可以维护一个前缀和的答案。对于每层点,给他们定一个任意的顺序,然后维护它之前所有点加起来的答案。

考虑倍增过程。\(F(j,i)\) 表示对于点 \(i\) 维护深度为 \(2^j\) 的答案。容易发现过程是类似的,不过要维护一个最右端的儿子。

时间复杂度 \(O(n \log n + q \log n)\)

#include<bits/stdc++.h>
using namespace std;

#define ll long long

const int MX=1e6+7;

vector<int > g[MX],dep[MX];
ll pre[MX],lst[MX],sz[MX];
ll son[25][MX],sum[25][MX],f[25][MX];
ll v[MX];
int n,q;

ll read(){
	int x=0,ch=getchar();
	while(!(ch>=48&&ch<=57))ch=getchar();
	while(ch>=48&&ch<=57)x=x*10+ch-48,ch=getchar();
	return x;
}

void print(ll x){
	if(x>9)print(x/10);
	putchar(x%10+48);
}

void dfs1(int x,int d){
	pre[x]=lst[d];
	lst[d]=x;
	son[0][x]=lst[d+1];
	dep[d].push_back(x);
	for(int i=0;i<g[x].size();i++){
		int to=g[x][i];
		son[0][x]=to;
		dfs1(to,d+1);
	}
}

void dfs2(int x){
	for(int i=0;i<g[x].size();i++){
		int to=g[x][i];
		dfs2(to);
	}
	int v=son[0][x];
	sz[x]+=sz[v];
	for(int t=0;t<=20;t++)sum[t][x]+=sum[t][v];
}

ll ans(int x,int k){
	if(x==0)return 0;
	k++;
	ll tx=x,tk=k,b=0,ret=0;
	for(int t=20;t>=0;t--){
		if(k>=(1<<t)){
			ret+=f[t][x];
			x=son[t][x];
			k-=(1<<t);
		}
	}
	b=x;
	x=tx;
	k=tk;
	for(int t=20;t>=0;t--){
		if(k>=(1<<t)){
			x=son[t][x];
			k-=(1<<t);
			ll c=sz[x]-sz[b],s=sum[t][x]-sum[t][b];
			ret+=c*(1<<t)-s*(1<<(t+1));
		}
	}
	return ret;
}

int main(){
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	n=read();
	for(int i=1;i<=n;i++)v[i]=read();
	for(int i=2;i<=n;i++){
		int fa=read();
		g[fa].push_back(i);
	}
	dfs1(1,0);
	for(int i=0;i<=n;i++){
		for(int j=0;j<dep[i].size();j++){
			int u=dep[i][j];
			f[0][u]=v[u];
			sz[u]=1;
			for(int t=0;t<=20;t++)
				sum[t][u]=(v[u]>>t)&1;
			if(j!=0){
				int w=dep[i][j-1];
				sz[u]+=sz[w];
				f[0][u]+=f[0][w];
				for(int t=0;t<=20;t++)sum[t][u]+=sum[t][w];
			}
		}
	}
	dfs2(1);
	for(int j=1;j<=20;j++)
		for(int i=1;i<=n;i++){
			int u=son[j-1][i];
			int v=son[j-1][u];
			son[j][i]=v;
			ll c=sz[u]-sz[v],s=sum[j-1][u]-sum[j-1][v];
			f[j][i]=f[j-1][i]+f[j-1][u]+c*(1<<(j-1))-s*(1<<j);
		}
	int q=read();
	while(q--){
		int x=read();
		int k=read();
		print( ans(x,k)-ans(pre[x],k) );
		putchar('\n');
	}
	return 0;
}

总结:VP:\(100 + 20 + 20 + 100 + 0 + 10=250\),我菜死了