D - Tree Partition
考虑将树转换到一个序列上,钦定\(1\)为根节点,\(1\)的父亲为\(0\),在序列上,孩子向父亲连边
然后考虑设\(dp\)状态\(dp[i][j]\)表示前\(i\)个点,分成\(j\)段的方案数,那么\(dp[i][j]\)从\(dp[k][j-1]\)转移过来要满足以下条件之一:
点\(i\)的后向边\((a,b)\)满足\(a\leq i\),\(b>i\),区间\([i,j]\)的前向边\((a,b)\)满足\(a\in[i,j]\),\(b<i\)
设\(x_1\)表示倒数第二个后向边的起点,\(x_2\)表示倒数第一个后向边的起点
那么对于\(<x_1\)的\(k\)不能转
\([x_1,x_2)\)的\(k\),要满足\([k+1,i]\)没有前向边
\([x_2,i)\)的\(k\),要满足\([k+1,i]\)有且仅有一条前向边
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=2e5+5,MOD=998244353;
int n,k;
int head[N],cnt=1,fa[N],dp[N][405];
struct node{
int nxt,v;
}tree[N<<1];
void add(int u,int v){
tree[++cnt]={head[u],v},head[u]=cnt;
tree[++cnt]={head[v],u},head[v]=cnt;
}
void dfs(int u){
for(int i=head[u],v;i;i=tree[i].nxt){
if(fa[u]==(v=tree[i].v)) continue;
fa[v]=u,dfs(v);
}
}
int ad(int x,int y){
x+=y;
if(x>=MOD) x-=MOD;
if(x<0) x+=MOD;
return x;
}
int c[N],tr[N];
vector<int> sum[2][405],id[2];
set<int> pos;
int suf[N],top;
int main(){
scanf("%d%d",&n,&k),++n,add(1,2);
for(int i=1,u,v;i<n-1;++i) scanf("%d%d",&u,&v),++u,++v,add(u,v);
for(int i=1;i<=n;++i) pos.insert(i);
fa[2]=1,dfs(2);
id[0].pb(0),id[1].pb(0);
for(int i=0;i<=k;++i) sum[0][i].pb(0),sum[1][i].pb(0);
dp[1][0]=1,sum[0][0].pb(1),id[0].pb(1);
for(int i=1;i<=k;++i) sum[0][i].pb(0);
tr[1]=1;
for(int i=2,x1,x2;i<=n;++i){
if(fa[i]>i) suf[++top]=i;
if(fa[i]<i){
if(pos.size()&&(*prev(pos.end()))>=fa[i])
for(auto it=pos.lower_bound(fa[i]);it!=pos.end()&&(*it)<i;){
++c[*it];
if(c[*it]>1){
id[1].pop_back();
for(int j=0;j<=k;++j) sum[1][j].pop_back();
it=pos.erase(it);
}else ++it;
}
if(pos.size()&&(*prev(pos.end()))>=fa[i])
for(auto it=pos.lower_bound(fa[i]);it!=pos.end()&&(*it)<i;++it){
id[0].pop_back(),id[1].pb(*it);
for(int j=0;j<=k;++j) sum[0][j].pop_back(),sum[1][j].pb(ad(sum[1][j].back(),dp[*it][j]));
tr[*it]=sum[1][0].size()-1;
}
}
while(top&&fa[suf[top]]<=i) --top;
if(!top) x1=x2=1;
else{
x2=suf[top--];
while(top&&fa[suf[top]]<=i) --top;
if(top) x1=suf[top]; else x1=1;
suf[++top]=x2;
}
int l1,r1,l2,r2;
if(*(id[0].end()-1)<x1||x1>x2-1) l1=r1=0;
else l1=tr[*lower_bound(id[0].begin(),id[0].end(),x1)],r1=tr[*(upper_bound(id[0].begin(),id[0].end(),x2-1)-1)];
if(*(id[1].end()-1)<x2||x2>i-1) l2=r2=0;
else l2=tr[*lower_bound(id[1].begin(),id[1].end(),x2)],r2=tr[*(upper_bound(id[1].begin(),id[1].end(),i-1)-1)];
for(int j=1;j<=k;++j){
dp[i][j]=ad(ad(sum[0][j-1][r1],-sum[0][j-1][max(0,l1-1)]),ad(sum[1][j-1][r2],-sum[1][j-1][max(0,l2-1)]));
sum[0][j].pb(ad(sum[0][j].back(),dp[i][j]));
// if(j<=i-1) cout<<i<<" "<<j<<" "<<dp[i][j]<<"----\n"<<x1<<" "<<x2<<" "<<l1<<" "<<r1<<" "<<l2<<" "<<r2<<"\n"<<ad(sum[0][j-1][r1],-sum[0][j-1][max(0,l1-1)])<<" "<<ad(sum[1][j-1][r2],-sum[1][j-1][max(0,l2-1)])<<endl;
}
sum[0][0].pb(sum[0][0].back()),id[0].pb(i),tr[i]=sum[0][0].size()-1;
}
for(int i=1;i<=k;++i) printf("%d\n",dp[n][i]);
return 0;