[NOI2023] 深搜

发布时间 2023-07-31 20:34:35作者: 暗蓝色的星空

和考试的时候思路差不多。

首先考虑钦定一部分关键点是合法的 ,带上容斥系数。

对于一条非树边,要求其在任何一个钦定点作为根的时候都不是横叉边。

具体而言,对于一个钦定点集合,我们建出钦定点集合的虚树,那么符合条件的非树边有如下几类:

不妨先考虑特殊性质 \(B\) ,没有横叉边的情况:

  • 完全在某条虚树边上

  • 从虚树上某个点向非虚树上的子树方向的边

  • 虚树根向上方的边

对于每个点,预处理第二类和第三类的边的个数,记为 \(down_x,up_x\)

考试的时候写的是,\(dp_{x,i}\) 表示考虑 \(x\) 子树内的点和边,选出的、覆盖当前点的上端点最浅的非树边深度是 \(i\) 的方案数,转移和 [NOI2020]命运 一模一样。

但是这样需要线段树合并,而且不太好拓展。

我们考虑设 \(dp_x\) 表示 \(x\) 为虚树上的点的方案数,对于每个子树 \(y\),如果这子树里没选点,方案数是 \(2^{down_y}\) ,否则,方案数是枚举最浅的点 \(z\),用 \(dp_y\) 乘上 \(x\to z\) 的方案数,这个可以边转移边维护。

\(DFS\) 序为下标开一颗线段树,\(dp_x\) 存在 \(dfn_x\) 位置。

更新就是一个区间乘法,区间求和。

现在考虑一般情况。

注意到横叉边只有跨过虚树根的才合法,且必须被虚树包含且最多涉及到两棵子树。

对于一个二维平面上的点 \((x,y)\) ,我们表示选择的两个子树最浅的点的 \(dfs\) 序分别是 \(x,y\) 的方案数,初始时是 \(dp_x\times dp_y\)

对于每个横叉边,相当于一个矩形内的方案数乘二,最终求所有位置的和,可以离散化后做一遍扫描线解决。

复杂度 \(O(n\log n)\)

#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+7;
struct edge 
{
	int y,next;
}e[2*N];
int flink[N],t=0;
void add(int x,int y)
{
	e[++t].y=y;
	e[t].next=flink[x];
	flink[x]=t;
}
int n,m,K;
bool key[N];
int dep[N],st[N],ed[N],tim=0;
int jump[N][21];
int fup[N],fdown[N],tree[N],pw[N];
const int mod = 1e9+7;
vector<int> Up[N],Down[N],G[N],A[N]; 
int Jump(int x,int y)
{
	for(int k=19;k>=0;k--)
	if(jump[x][k]&&dep[jump[x][k]]>dep[y])x=jump[x][k];
	return x;
}
int lca(int x,int y)
{
	if(dep[x]<dep[y])swap(x,y);
	for(int k=19;k>=0;k--)if(dep[jump[x][k]]>=dep[y])x=jump[x][k];
	if(x==y)return x;
	for(int k=19;k>=0;k--)if(jump[x][k]!=jump[y][k]) 
	{
		x=jump[x][k];
		y=jump[y][k];
	}
	return jump[x][0];
}
void dfs(int x,int pre)
{
	jump[x][0]=pre;
	for(int k=1;jump[x][k-1];k++)jump[x][k]=jump[jump[x][k-1]][k-1]; 
	dep[x]=dep[pre]+1;
	st[x]=++tim;
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		dfs(y,x);
	}
	ed[x]=tim;
}
inline bool Anc(int x,int y){return st[x]<=st[y]&&st[y]<=ed[x];}
int U[N],V[N];
void dfs2(int x,int pre)
{
	tree[x]=Down[x].size();
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		dfs2(y,x);
		tree[x]+=tree[y];
		fdown[x]+=fdown[y];
	}
}
int dis[N];
inline int plu(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int dec(int a,int b){return a-b<0?a-b+mod:a-b;}
struct segment 
{
	int sum[N*4],tag[N*4];
	void build(int k,int l,int r)
	{
		sum[k]=0;tag[k]=1;
		if(l==r)
		{
			sum[k]=dis[l];
			return;
		}
		int mid=(l+r)>>1;
		build(k<<1,l,mid);
		build(k<<1|1,mid+1,r);
		sum[k]=plu(sum[k<<1],sum[k<<1|1]);
	}
	void pushup(int k)
	{
		sum[k]=plu(sum[k<<1],sum[k<<1|1]);
	}
	void pushtag(int k,int v)
	{
		sum[k]=1ll*sum[k]*v%mod;
		tag[k]=1ll*tag[k]*v%mod;
	}
	void pushdown(int k)
	{
		if(tag[k]^1)
		{
			pushtag(k<<1,tag[k]);
			pushtag(k<<1|1,tag[k]);
			tag[k]=1;
		}
	}
	void mul(int k,int l,int r,int L,int R,int v)
	{
		if(L<=l&&r<=R) 
		{
			pushtag(k,v);
			return;
		}
		pushdown(k);
		int mid=(l+r)>>1;
		if(L<=mid)mul(k<<1,l,mid,L,R,v);
		if(R>mid) mul(k<<1|1,mid+1,r,L,R,v);
		pushup(k);
	}
	int query(int k,int l,int r,int L,int R)
	{
		if(L>R)return 0;
		if(L<=l&&r<=R) return sum[k];
		pushdown(k);
		int mid=(l+r)>>1;
		int res=0;
		if(L<=mid)res=plu(res,query(k<<1,l,mid,L,R));
		if(R>mid) res=plu(res,query(k<<1|1,mid+1,r,L,R));
		return res;
	}
	void upd(int k,int l,int r,int x,int v)
	{
		if(l==r) 
		{
			sum[k]=v;
			return;
		}
		pushdown(k);
		int mid=(l+r)>>1;
		if(x<=mid)upd(k<<1,l,mid,x,v);
		else upd(k<<1|1,mid+1,r,x,v);
		pushup(k);
	}
}T,S;
int dp[N][4],f[4],ans=0;
int ipw[N];
const int inv2=(mod+1)/2;
int dct[N*5],tot=0;
struct Info 
{
	int x,l,r,v;
}seq[N*5];
bool cmp(Info A,Info B)
{
	return A.x<B.x;
}
int cnt=0;
int get(int x)
{
	return lower_bound(dct+1,dct+tot+1,x)-dct; 
}
int siz[N];
int calc(int x)
{
	if(A[x].empty())return dp[x][2];
	tot=0;
	dct[++tot]=st[x];
	dct[++tot]=ed[x]+1;
	for(int i:A[x])
	{
		int u=U[i],v=V[i];
		dct[++tot]=st[u];dct[++tot]=ed[u]+1;
		dct[++tot]=st[v];dct[++tot]=ed[v]+1;
	}
	sort(dct+1,dct+tot+1);
	tot=unique(dct+1,dct+tot+1)-dct-1;
	for(int i=1;i<=tot;i++)dis[i]=T.query(1,1,n,dct[i],dct[i+1]-1);
	S.build(1,1,tot);
	cnt=0;
	seq[++cnt]=(Info){st[x],st[x],ed[x]+1,1};
	seq[++cnt]=(Info){ed[x]+1,st[x],ed[x]+1,1};
	for(int i:A[x])
	{
		int u=U[i],v=V[i];
		seq[++cnt]=(Info){st[u],st[v],ed[v]+1,2};
		seq[++cnt]=(Info){ed[u]+1,st[v],ed[v]+1,inv2};
		seq[++cnt]=(Info){st[v],st[u],ed[u]+1,2};
		seq[++cnt]=(Info){ed[v]+1,st[u],ed[u]+1,inv2};
	}
	sort(seq+1,seq+cnt+1,cmp);
	int res=0;
	for(int i=1;i<=cnt;i++)
	{
		if(i>1)res=(res+1ll*T.query(1,1,n,seq[i-1].x,seq[i].x-1)*S.sum[1]%mod)%mod;
		S.mul(1,1,tot,get(seq[i].l),get(seq[i].r)-1,seq[i].v);
	}
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==jump[x][0])continue;
		int V=T.query(1,1,n,st[y],ed[y]); 
		res=(res-1ll*V*V%mod+mod)%mod;
	}
	res=1ll*res*inv2%mod*ipw[tree[x]]%mod;
	return res;
}
void dfs3(int x,int pre)
{
	fup[x]+=siz[x]-Down[x].size();
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		fup[y]=fup[x]+tree[x]-fdown[y];
		dfs3(y,x);
	}
	for(auto u:Down[x])T.mul(1,1,n,st[u],ed[u],2);
	dp[x][0]=1;
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		for(int c=0;c<=3;c++)f[c]=dp[x][c],dp[x][c]=0;
		for(int c=0;c<=3;c++)
		{
			dp[x][c]=plu(dp[x][c],1ll*f[c]*pw[fdown[y]]%mod);
			dp[x][min(3,c+1)]=plu(dp[x][min(3,c+1)],1ll*f[c]*T.query(1,1,n,st[y],ed[y])%mod);
		}
	}
	for(int i=flink[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(y==pre)continue;
		T.mul(1,1,n,st[y],ed[y],pw[tree[x]-fdown[y]]);
	}
	int F=0;
	if(key[x])
	{
		for(int c=0;c<=3;c++)
		F=(F+mod-dp[x][c])%mod; 
	}
	F=(F+dp[x][3])%mod;
	ans=(ans+1ll*(F+calc(x))%mod*pw[fup[x]]%mod)%mod;
	F=(F+dp[x][2])%mod;
	T.upd(1,1,n,st[x],F);
}
int main()
{
	int tp;
	cin>>tp;
	cin>>n>>m>>K;
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d %d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs(1,0);
	for(int i=1;i<=m;i++)
	{
		int x,y;
		scanf("%d %d",&x,&y);
		if(st[x]>st[y])swap(x,y);
		if(Anc(x,y))
		{	
			Up[y].push_back(x);
			Down[x].push_back(y);
			fdown[Jump(y,x)]++;
		}
		else A[lca(x,y)].push_back(i);
		siz[x]++;siz[y]++;
		U[i]=x;V[i]=y;
	}
	for(int i=1;i<=K;i++)
	{
		int x;
		scanf("%d",&x);
		key[x]=1;
	}
	pw[0]=1;ipw[0]=1;
	for(int i=1;i<=m;i++)
	{
		pw[i]=1ll*pw[i-1]*2%mod;
		ipw[i]=1ll*ipw[i-1]*inv2%mod;
	}
	dfs2(1,0);
	T.build(1,1,n);
	dfs3(1,0);
	cout<<(mod-ans)%mod;
	return 0;
}