树的重心

发布时间 2023-10-09 19:44:06作者: wscqwq

[CSP-S2019] 树的重心

需要了解重心的很多性质:

  1. 重心有 \(1\sim 2\) 个,满足删除重心后形成的连通块最大大小 \(\le \lfloor \dfrac{n}{2}\rfloor\),若有 \(2\) 个,\(n\) 为偶数,两个重心把树划分成了每个部分大小 \(\dfrac{n}{2}\) 的树。
  2. 以下方法可以找到一个重心:从根节点开始,往重儿子出发(多个选任意一个),最后一个满足 \(n-sz_u\le \lfloor \dfrac{n}{2}\rfloor\) 的点一定是重心。
  3. 如果按照2中方法找到重心,且有 \(2\) 个重心,必然是那个点的父节点。

然后我们考虑切掉一条边的情况,分为两部分:

  1. 第一部分,子树部分,可以用倍增的方法,\(f[u][i]\) 表示向重儿子的方向走 \(2^i\) 步,然后用倍增跳跃找到一个重心,最后验证父节点。
  2. 第二部分,除掉这棵子树的部分,可以考虑换根DP。我们只需要将 \(sz,f\) 数组换到以 \(v\) 为根即可。发现仅有 \(u\) 的信息会发生变化(注意 \(v\) 的信息不会发生变化,因为删除的是边,那么 \(v\) 的答案是第一部分,然后如果遍历到它的子节点,\(u\) 依然是它的子树,所以更改信息无用),单次修改是 \(\log\) 复杂度。修改完后同第一部分的求解方式。最后记得恢复现场。(注意在往子节点走时不需要恢复,在回溯时恢复)

第一部分可以在预处理中求解,也可以和第二部分一块儿求解。

#include<cstdio>
#include<iostream>
#include<cassert>
#include<algorithm>
#include<cstring>
using namespace std;
#define Ed for(int i=h[x];~i;i=ne[i])
#define Ls(i,l,r) for(int i=l;i<r;++i)
#define Rs(i,l,r) for(int i=l;i>r;--i)
#define Le(i,l,r) for(int i=l;i<=r;++i)
#define Re(i,l,r) for(int i=l;i>=r;--i)
#define L(i,l) for(int i=0;i<l;++i)
#define E(i,l) for(int i=1;i<=l;++i)
#define W(t) while(t--)
#define Wh while

const int N=300010,K=19,M=2*N;
int T,n,f[N][K],g[N][K],sz[N],p[N],cnt;
int h[N],e[M],ne[M],idx;//don't forget memset h!
typedef long long ll;
ll ans;
void add(int a,int b){
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int calc_ans(int x){
    int tot=sz[x],mid=tot>>1;
    Re(i, 18, 0)
        if(f[x][i]&&tot-sz[f[x][i]]<=mid)x=f[x][i];
    if(!(tot&1)&&sz[x]==mid)x+=p[x];
    return x;
}
void upd(int x,int son){
    f[x][0]=son;
    E(i, 18)f[x][i]=f[f[x][i-1]][i-1];
}
void init(int x,int fa){
    int son=0;
    sz[x]=1,p[x]=fa;
    Ed{
        int j=e[i];
        if(j==fa)continue;
        init(j,x);
        if(sz[j]>sz[son])son=j;
        sz[x]+=sz[j];
    }
    upd(x,son);
}
void dfs(int x,int fa){
    int s1=0,s2=0,bksz=sz[x];
    Ed{
        int j=e[i];
        if(sz[j]>=sz[s1])s2=s1,s1=j;
        else if(sz[j]>sz[s2])s2=j;
    }
    Ed{
        int j=e[i];
        if(j==fa)continue;
        ans+=calc_ans(j);
        p[x]=j,sz[x]=n-sz[j];
        if(j!=s1)upd(x,s1);
        else upd(x,s2);
        ans+=calc_ans(x);
        dfs(j,x);
    }
    p[x]=fa,sz[x]=bksz;
    memcpy(f[x],g[x],sizeof f[x]);
}
int main(){
    #ifndef ONLINE_JUDGE
    freopen("1.in","r",stdin);
    #endif
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>T;
    W(T){
        cin>>n;
        ans=idx=0;
        memset(h,-1,n*4+4);
        L(i, n-1){
            int a,b;
            cin>>a>>b;
            add(a,b),add(b,a);
        }
        init(1,-1);
        memcpy(g,f,sizeof g);
        dfs(1,-1);
        cout<<ans<<'\n';
    }
    return 0;
}