P9555 「CROI · R1」浣熊的阴阳鱼

发布时间 2023-08-22 19:33:51作者: One_JuRuo

思路

暴力

比赛的时候想过树链剖分,然后想不出来怎么处理区间合并,再加上树链剖分代码量比较大,我又比较懒,就随手写了个暴力拿了40pts。

思路就是暴力求得 \(u\)\(v\) 的简单路径,然后暴力枚举模拟一遍。

40pts 代码

#include<bits/stdc++.h>
using namespace std;
int n,q,a[100005],e[200005],ne[200005],h[100005],idx=1,k,x,y;
int dep[100005],fa[100005][25],z[100005],cnt;
int f[2];
void dfs1(int u,int f)
{
	for(int i=h[u];i;i=ne[i]) if(e[i]!=f) fa[e[i]][0]=u,dep[e[i]]=dep[u]+1,dfs1(e[i],u);
}
inline void add(int a,int b)
{
	e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int LCA(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=20;i>=0;--i) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
	if(x==y) return x;
	for(int i=20;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
int main()
{
	scanf("%d%d",&n,&q);
	for(int i=1;i<=n;++i) scanf("%d",&a[i]);
	for(int i=1;i<n;++i) scanf("%d%d",&x,&y),add(x,y),add(y,x);
	fa[1][0]=1,dep[1]=1,dfs1(1,0);
	for(int i=1;i<=20;++i) for(int u=1;u<=n;++u) fa[u][i]=fa[fa[u][i-1]][i-1];
	while(q--)
	{
		scanf("%d%d",&k,&x);
		if(k==1) a[x]^=1;
		else
		{
			int f[2],lca,ans,cnt,sum;
			scanf("%d",&y),f[0]=f[1]=cnt=ans=0,lca=LCA(x,y);
			while(x!=lca) z[++cnt]=x,x=fa[x][0];
			z[++cnt]=x,cnt+=dep[y]-dep[lca],sum=cnt;
			while(y!=lca) z[sum--]=y,y=fa[y][0];
			sum=0;
			for(int i=1;i<=cnt;++i)
			{
				if(f[a[z[i]]^1]) ans++,f[a[z[i]]^1]--,sum--;
				else if(sum<2) f[a[z[i]]]++,sum++;
			}
			printf("%d\n",ans);
		}
	}
	return 0;
}

正解:树链剖分

显然,暴力时间复杂度非常高,最后一个 subtask 会 TLE。

看了 MaxBlazeResFire 巨佬的题解后恍然大悟,是类似状压的思路,因为篮子里不可能放一阴一阳,因为遇到就会被吃掉,所以篮子里的状态只有 \(5\) 个。分别是:啥都没有,有一条阴,有两条阴,有一条阳,有两条阳,分别用数字 \(0,1,2,3,4\) 对应。

因为前面区间过了后,篮子里的鱼会对后续区间产生影响,所以我们自然而然地想到,需要用一个变量储存某个状态经过该区间后的状态,用 \(suf_i\) 表示状态 \(i\) 经过该区间后状态为 \(suf_i\),然后再随便用一个 \(sum_i\) 代表该区间能吃几条阴阳鱼,也就是答案贡献。

区间合并就很容易推出来了,假设我们用区间 \(a\) 和区间 \(b\) 合并为区间 \(c\)

\[c.suf_i=b.suf_{a.suf_i} \]

\[c.sum_i=a.sum_i+b.sum_a.suf_i \]

用代码表示就是,

for(long long i=0;i<=4;i++) c.sum[i]=a.sum[i]+b.sum[a.suf[i]],c.suf[i]=b.suf[a.suf[i]];

需要注意的是,因为对于同一个区间,方向不同,答案也不一样,所以需要一个对应的反转的数组。

题目要求里有两个操作,一个是单点修改,一个是区间查询。

单点修改很容易,这里就不展开了。

重点是查询,其实查询和上面的暴力思路差别不大,只不过上面是老实的挨个找,这里可以用树链剖分分成一段一段,这就是这道题用树链剖分的原因。

还不会树链剖分的,可以去这道模板题

对于点 \(u\), \(v\)。我们设它们的 LCA 为 \(l\)

整个路径就分成了两部分:\(u\to l\)\(l\to v\)

由于树的性质,点往上找很容易,往下找就很麻烦。

所以我们就都往上找,用一个 vector 存下每段。

需要注意的是,两边方向不一样,记得要颠转。

因为树链剖分是从上往下,所以到时候需要注意那些部分需要用反转的数组,哪些部分用没反转的数组。

AC 代码

#include<bits/stdc++.h>
using namespace std;
long long n,q,a,b,k,siz[100005],son[100005],dep[100005],top[100005],f[100005][17],c[100005];
long long dfn[100005],wc[100005],cnt;
vector<long long>v[100005];
struct node{long long sum[5],suf[5];}t[400005],ft[400005],s[2];
inline void init()//s是单位数组,用来直接赋值,这个可以自己按照定义推出来 
{
	s[0].sum[0]=0,s[0].sum[1]=1,s[0].sum[2]=1,s[0].sum[3]=0,s[0].sum[4]=0;
	s[0].suf[0]=3,s[0].suf[1]=0,s[0].suf[2]=1,s[0].suf[3]=4,s[0].suf[4]=4;
	s[1].sum[0]=0,s[1].sum[1]=0,s[1].sum[2]=0,s[1].sum[3]=1,s[1].sum[4]=1;
	s[1].suf[0]=1,s[1].suf[1]=2,s[1].suf[2]=2,s[1].suf[3]=0,s[1].suf[4]=3;
}
/*树链剖分部分*/
void dfs1(long long u,long long fa)
{
	siz[u]=1,dep[u]=dep[fa]+1;
	for(long long j:v[u])
		if(j!=fa)
		{
			dfs1(j,u),siz[u]+=siz[j],f[j][0]=u;
			if(siz[j]>siz[son[u]]) son[u]=j;
		}
}
void dfs2(long long u)
{
	dfn[u]=++cnt,wc[cnt]=c[u];
	if(!top[u]) top[u]=u;
	if(son[u]) top[son[u]]=top[u],dfs2(son[u]);
	for(long long j:v[u]) if(j!=son[u]&&j!=f[u][0]) dfs2(j);
}
/*线段树部分*/ 
inline node merge(node a,node b){node c;for(long long i=0;i<=4;i++) c.sum[i]=a.sum[i]+b.sum[a.suf[i]],c.suf[i]=b.suf[a.suf[i]];return c;}//区间合并 
inline void pushup(long long p){t[p]=merge(t[p<<1],t[p<<1|1]),ft[p]=merge(ft[p<<1|1],ft[p<<1]);}//注意:反向的合并是右+左 
void build(long long p,long long l,long long r)
{
	if(l==r)
	{
		t[p]=ft[p]=s[wc[l]];
		return;
	}
	long long mid=l+r>>1;
	build(p<<1,l,mid),build(p<<1|1,mid+1,r);
	pushup(p);
}
void update(long long p,long long l,long long r,long long x,long long k)//单点修改 
{
	if(l==r)
	{
		t[p]=ft[p]=s[k];
		return;
	}
	long long mid=l+r>>1;
	if(x<=mid) update(p<<1,l,mid,x,k);
	else update(p<<1|1,mid+1,r,x,k);
	pushup(p);
}
node ask(long long p,long long l,long long r,long long L,long long R)//正向区间查询 
{
	if(L<=l&&r<=R) return t[p];
	long long mid=l+r>>1;
	if(R<=mid) return ask(p<<1,l,mid,L,R);
	if(L>mid) return ask(p<<1|1,mid+1,r,L,R);
	return merge(ask(p<<1,l,mid,L,R),ask(p<<1|1,mid+1,r,L,R));
}
node revask(long long p,long long l,long long r,long long L,long long R)//反向区间查询 
{
	if(L<=l&&r<=R) return ft[p];
	long long mid=l+r>>1;
	if(R<=mid) return revask(p<<1,l,mid,L,R);
	if(L>mid) return revask(p<<1|1,mid+1,r,L,R);
	return merge(revask(p<<1|1,mid+1,r,L,R),revask(p<<1,l,mid,L,R));//区间右+左 
}
inline long long LCA(long long x,long long y)//求LCA 
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		x=f[top[x]][0];
	}
	return (dep[x]<dep[y])?x:y;
}
inline long long nson(long long rt,long long u)//求rt的那个儿子是u的祖宗 
{
	for(long long i=16;i>=0;--i) if(dep[f[u][i]]>dep[rt]) u=f[u][i];
	return (u==rt)?0:u;//记得特判特殊情况 
}
inline long long gask(long long u,long long v)
{
	long long lca=LCA(u,v),j=nson(lca,v),len;node ans;
	vector<pair<long long,long long>>l,r;//用两个vector存两边的路径 
	while(top[u]!=top[lca]) l.push_back(make_pair(-dfn[top[u]],-dfn[u])),u=f[top[u]][0];//左侧路径是深度高到低,所以需要反转,赋称负值标记 
	l.push_back(make_pair(-dfn[lca],-dfn[u]));
	if(j)
	{
		while(top[v]!=top[j]) r.push_back(make_pair(dfn[top[v]],dfn[v])),v=f[top[v]][0];//右侧路径是深度低到高,不需要反转 
		r.push_back(make_pair(dfn[j],dfn[v]));
	}
	reverse(r.begin(),r.end());//颠转右侧路径 
	for(pair<long long,long long>i:r) l.push_back(i);//把左右侧路径合并
	/*先算一端区间,让后面好写一点*/ 
	if(l[0].first<0) ans=revask(1,1,n,-l[0].first,-l[0].second);//需要反转 
	else ans=ask(1,1,n,l[0].first,l[0].second);
	len=l.size();
	for(long long i=1;i<len;++i)
		if(l[i].first<0) ans=merge(ans,revask(1,1,n,-l[i].first,-l[i].second));//需要反转 
		else ans=merge(ans,ask(1,1,n,l[i].first,l[i].second));
	return ans.sum[0];//求得是以状态0(什么鱼也没有)开始的答案 
}
int main()
{
	init();
	scanf("%lld%lld",&n,&q);
	for(long long i=1;i<=n;++i) scanf("%lld",&c[i]);
	for(long long i=1;i<n;++i) scanf("%lld%lld",&a,&b),v[a].push_back(b),v[b].push_back(a);	
	f[1][0]=1,dfs1(1,0),dfs2(1),build(1,1,n);
	for(long long j=1;j<=16;++j) for(long long i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
	for(long long i=1;i<=q;++i)
	{
		scanf("%lld%lld",&k,&a);
		if(k==1) c[a]^=1,update(1,1,n,dfn[a],c[a]);
		else scanf("%lld",&b),printf("%lld\n",gask(a,b));
	}
	return 0;
}