[gym103860D]Tree Partition

发布时间 2023-10-10 16:42:01作者: LuoyuSitfitw

D - Tree Partition

考虑将树转换到一个序列上,钦定\(1\)为根节点,\(1\)的父亲为\(0\),在序列上,孩子向父亲连边

然后考虑设\(dp\)状态\(dp[i][j]\)表示前\(i\)个点,分成\(j\)段的方案数,那么\(dp[i][j]\)\(dp[k][j-1]\)转移过来要满足以下条件之一:

\(i\)的后向边\((a,b)\)满足\(a\leq i\)\(b>i\),区间\([i,j]\)的前向边\((a,b)\)满足\(a\in[i,j]\)\(b<i\)

\(x_1\)表示倒数第二个后向边的起点,\(x_2\)表示倒数第一个后向边的起点

那么对于\(<x_1\)\(k\)不能转

\([x_1,x_2)\)\(k\),要满足\([k+1,i]\)没有前向边

\([x_2,i)\)\(k\),要满足\([k+1,i]\)有且仅有一条前向边

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=2e5+5,MOD=998244353;
int n,k;
int head[N],cnt=1,fa[N],dp[N][405];
struct node{
	int nxt,v;
}tree[N<<1];
void add(int u,int v){
	tree[++cnt]={head[u],v},head[u]=cnt;
	tree[++cnt]={head[v],u},head[v]=cnt;
}
void dfs(int u){
	for(int i=head[u],v;i;i=tree[i].nxt){
		if(fa[u]==(v=tree[i].v)) continue;
		fa[v]=u,dfs(v);
	}
}
int ad(int x,int y){
	x+=y;
	if(x>=MOD) x-=MOD;
	if(x<0) x+=MOD;
	return x;
}
int c[N],tr[N];
vector<int> sum[2][405],id[2];
set<int> pos;
int suf[N],top; 
int main(){
	scanf("%d%d",&n,&k),++n,add(1,2);
	for(int i=1,u,v;i<n-1;++i) scanf("%d%d",&u,&v),++u,++v,add(u,v);
	for(int i=1;i<=n;++i) pos.insert(i);
	fa[2]=1,dfs(2);
	id[0].pb(0),id[1].pb(0);
	for(int i=0;i<=k;++i) sum[0][i].pb(0),sum[1][i].pb(0);
	dp[1][0]=1,sum[0][0].pb(1),id[0].pb(1);
	for(int i=1;i<=k;++i) sum[0][i].pb(0);
	tr[1]=1;
	for(int i=2,x1,x2;i<=n;++i){
		if(fa[i]>i) suf[++top]=i;
		if(fa[i]<i){
			if(pos.size()&&(*prev(pos.end()))>=fa[i])
				for(auto it=pos.lower_bound(fa[i]);it!=pos.end()&&(*it)<i;){
					++c[*it];
					if(c[*it]>1){
						id[1].pop_back();
						for(int j=0;j<=k;++j) sum[1][j].pop_back();
						it=pos.erase(it);
					}else ++it;
				}
			if(pos.size()&&(*prev(pos.end()))>=fa[i])	
				for(auto it=pos.lower_bound(fa[i]);it!=pos.end()&&(*it)<i;++it){
					id[0].pop_back(),id[1].pb(*it);
					for(int j=0;j<=k;++j) sum[0][j].pop_back(),sum[1][j].pb(ad(sum[1][j].back(),dp[*it][j]));
					tr[*it]=sum[1][0].size()-1;
				}
		}
		while(top&&fa[suf[top]]<=i) --top;
		if(!top) x1=x2=1;
		else{
			x2=suf[top--];
			while(top&&fa[suf[top]]<=i) --top;
			if(top) x1=suf[top]; else x1=1; 
			suf[++top]=x2;
		}
		int l1,r1,l2,r2;
		if(*(id[0].end()-1)<x1||x1>x2-1) l1=r1=0;
		else l1=tr[*lower_bound(id[0].begin(),id[0].end(),x1)],r1=tr[*(upper_bound(id[0].begin(),id[0].end(),x2-1)-1)];
		if(*(id[1].end()-1)<x2||x2>i-1) l2=r2=0;
		else l2=tr[*lower_bound(id[1].begin(),id[1].end(),x2)],r2=tr[*(upper_bound(id[1].begin(),id[1].end(),i-1)-1)];
		for(int j=1;j<=k;++j){
			dp[i][j]=ad(ad(sum[0][j-1][r1],-sum[0][j-1][max(0,l1-1)]),ad(sum[1][j-1][r2],-sum[1][j-1][max(0,l2-1)]));
			sum[0][j].pb(ad(sum[0][j].back(),dp[i][j]));
//			if(j<=i-1) cout<<i<<" "<<j<<" "<<dp[i][j]<<"----\n"<<x1<<" "<<x2<<" "<<l1<<" "<<r1<<" "<<l2<<" "<<r2<<"\n"<<ad(sum[0][j-1][r1],-sum[0][j-1][max(0,l1-1)])<<" "<<ad(sum[1][j-1][r2],-sum[1][j-1][max(0,l2-1)])<<endl;
		}
		sum[0][0].pb(sum[0][0].back()),id[0].pb(i),tr[i]=sum[0][0].size()-1;
	}
	for(int i=1;i<=k;++i) printf("%d\n",dp[n][i]);
 
	return 0;