[LOJ3626/QOJ4889] 愚蠢的在线法官

发布时间 2023-10-19 20:09:08作者: 云浅知处

考虑这个矩阵长啥样,首先显然 \(A\) 不能重复否则答案是 \(0\)(有两行两列相同)。

\(A\) 重标号为 DFS 序的顺序,那么行列式的值不改变,因为交换 \(A_i,A_j\) 相当于同时交换两行两列。

考虑把权值 \(v\) 做树上差分,令 \(B_u=v_u-v_{fa(u)}\),那么就等价于对每个 \(i\)\(i\) 子树内的所有点形成的这个矩阵中的每个值都加上 \(B_i\)。那子树的 DFS 序是连续的,所以相当于做一个矩形加。

矩阵上做二维差分不影响行列式,考虑二维差分,那么矩形加就变成了四个单点加。

假设这个对应的 DFS 序区间是 \([l,r]\),那么相当于 M[l][l]++,M[r+1][r+1]++,M[l][r+1]--,M[r+1][l]--,发现这个类似于矩阵树的 Kirchhoff 矩阵,那相当于建一个图出来,然后对每个点连边 \([l,r+1]\),要数生成树个数。

那怎么做呢,发现由于各子树 DFS 序要么包含要么不交,然后你发现任取一个子图如果想同胚于 \(K_4\) 那最后我们发现肯定会出现相交,所以这个是广义串并联图,套广义串并联方法就行了。

时间复杂度 \(O(n\log n)\)

#include<bits/stdc++.h>

#define ll long long

using namespace std;

inline int read(){
	int x=0,f=1;char c=getchar();
	for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
	for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
	return x*f;
}

const int mod=998244353;
int ksm(int x,int y,int p=mod){
	int ans=1;
	for(int i=y;i;i>>=1,x=1ll*x*x%p)if(i&1)ans=1ll*ans*x%p;
	return ans%p;
}
int inv(int x,int p=mod){return ksm(x,p-2,p)%p;}
mt19937 rnd(time(0));
int randint(int l,int r){return rnd()%(r-l+1)+l;}
void add(int &x,int v){x+=v;if(x>=mod)x-=mod;}
void Mod(int &x){if(x>=mod)x-=mod;}
int cmod(int x){if(x>=mod)x-=mod;return x;}

const int N=5e5+5;
int n,m;
vector<int>G[N];

struct Node{
	int f,g;
	Node(int F,int G):f(F),g(G){}
	Node(){}
	bool operator<(const Node &rhs)const{
		if(f!=rhs.f)return f<rhs.f;
		return g<rhs.g;
	}
};

Node cmpr(Node wy,Node wz){return Node(1ll*wy.f*wz.f,cmod(1ll*wy.f*wz.g%mod+1ll*wy.g*wz.f%mod));}
Node twist(Node u,Node w){return Node(cmod(1ll*w.f*u.g%mod+1ll*u.f*w.g%mod),1ll*w.g*u.g%mod);}

queue<int>q;
set<pair<int,Node> >adj[N];

#define fi first
#define se second
#define mk make_pair

int dfn[N],dfc=0;
int a[N],val[N],fa[N],sz[N];
bool inq[N];

signed main(void){

#ifdef YUNQIAN
	freopen("in.in","r",stdin);
#endif

	n=read(),m=read();
	for(int i=1;i<=n;i++)val[i]=read();
	for(int i=1;i<=m;i++)a[i]=read();
	for(int i=1;i<=n-1;i++){
		int u=read(),v=read();
		G[u].emplace_back(v),G[v].emplace_back(u);		
	}
	
	function<void(int)>dfs=[&](int u){
		dfn[u]=++dfc,sz[u]=1;
		for(int v:G[u])if(v!=fa[u])fa[v]=u,dfs(v),sz[u]+=sz[v];
		add(val[u],mod-val[fa[u]]);
	};
	dfs(1);
	
	for(int i=1;i<=m;i++)a[i]=dfn[a[i]];sort(a+1,a+m+1);
	
	vector<int>fa(m+2);
	for(int i=1;i<=m+1;i++)fa[i]=i;
	function<int(int)>find=[&](int x){return x==fa[x]?x:fa[x]=find(fa[x]);};
	
	auto adde=[&](int x,int y,Node w){
		if(x==y)return ;fa[find(x)]=find(y);
		set<pair<int,Node> >::iterator t=adj[x].lower_bound(mk(y,Node(0,0)));
		Node cur=t->se;
		if(t->fi==y){
			adj[x].erase(mk(y,cur)),adj[x].insert(mk(y,twist(cur,w)));
			adj[y].erase(mk(x,cur)),adj[y].insert(mk(x,twist(cur,w)));
		}
		else adj[x].insert(mk(y,w)),adj[y].insert(mk(x,w));
	};
	
	for(int i=1;i<=n;i++){
		int l=dfn[i],r=dfn[i]+sz[i]-1;
		int L=lower_bound(a+1,a+m+1,l)-a;
		int R=upper_bound(a+1,a+m+1,r)-a-1;
		adde(L,R+1,Node(val[i],1));
	}
	
	m++;
	auto chk=[&](int x){
		if(inq[x])return ;
		if(adj[x].size()<=2)q.push(x),inq[x]=1;
	};
	
	for(int i=1;i<=m;i++)chk(i);
	for(int i=1;i<=m;i++)if(find(i)!=find(1))return puts("0"),0;
	
	int ans=1;
	while(q.size()){
		int x=q.front();q.pop();
		if(adj[x].size()==1){
			auto t=adj[x].begin();int y=t->fi;Node w=t->se;
			ans=1ll*ans*w.f%mod;
			adj[y].erase(mk(x,w)),chk(y);
		}
		else if(adj[x].size()==2){
			auto p=adj[x].begin(),q=--adj[x].end();
			int y=p->fi,z=q->fi;auto wy=p->se,wz=q->se;
			adj[y].erase(mk(x,wy)),adj[z].erase(mk(x,wz));
			Node w=Node(1ll*wy.f*wz.f%mod,cmod(1ll*wy.f*wz.g%mod+1ll*wy.g*wz.f%mod));
			auto r=adj[y].lower_bound(mk(z,Node(0,0)));
			if(r->fi!=z)adj[y].insert(mk(z,w)),adj[z].insert(mk(y,w));
			else{
				auto u=r->se;
				auto v=Node(cmod(1ll*w.f*u.g%mod+1ll*u.f*w.g%mod),1ll*w.g*u.g%mod);
				adj[y].erase(mk(z,u)),adj[y].insert(mk(z,v));
				adj[z].erase(mk(y,u)),adj[z].insert(mk(y,v));
				chk(y),chk(z);
			}
		}
	}
	
	cout<<ans<<endl;

	return 0;
}