P2726 [SHOI2005] 树的双中心

发布时间 2023-11-01 17:35:03作者: Thunder_S

Description

给定一棵树 \(T=(V,E)\),其中 \(V\) 为节点集合,\(E\) 为边集合。对于 \(V\) 中的每个节点 \(v\),有一个权值函数 \(W(v)\),该函数的值均为正整数。记 \(d(u, v)\) 为节点 \(u\)\(v\) 之问的距离,表示它们之问唯一的一条路径的边数。若 \(u\)\(v\) 为同一个节点,则 \(d(u, v)=0\)。你的任务是找出两个不同的节点 \(x\)\(y\),使得以下表达式 \(S(x, y)\) 的值最小。

\[S(x,y)=\sum\limits_{v\in V}(W(v)·\min\{d(v,x),d(v,y)\}) \]

Solution

注意到,无论如何选择 \(x,y\),这 \(n\) 个点一定会被一条边分成两个点集。那么考虑枚举这条边,求出割了这条边后形成的两个树对应的带权重心。

考虑用 \(dp\) 来求带权重心。设 \(f_x\) 表示该树内所有点到 \(x\) 的距离之和。那么先跑一遍得出正确的 \(f_1\),转移:\(f_x=\sum f_y+sum_y\),其中 \(sum_y\) 表示子树 \(y\) 的点权和。此时除 \(1\) 以外的 \(f\) 记录的该子树内的答案。考虑换根,那么从 \(x\) 换到 \(y\) 就要令 \(f_y=f_x-sum_y+(sum_1-sum_y)\),其含义是 \(x\) 除儿子 \(y\) 以外的点要从 \(x\) 移到 \(y\),因此增加 \(sum_1-sum_y\),而子树 \(y\) 内的点则不再需要从 \(y\) 移到 \(x\),因此减去 \(sum_y\)

\(dp\) 的复杂度是 \(\mathcal O(n)\) 的,加上枚举一条边的总复杂度是 \(\mathcal O(n^2)\),无法通过本题。

考虑优化。观察换根的式子,你发现 \(y\)\(x\) 优当且仅当 \(2\times sum_y>sum_1\),而一旦不等式不满足就不用再往下搜了。因此考虑记录每个点的儿子中子树最大的,然后往最大儿子那边走即可。可是注意到,在删去某条边后,某些子树的大小会发生改变,此时原本的最大儿子不能再是最大的。但是,对于任意一个点来说,这种影响最多影响一个儿子。所以考虑同时记录次大儿子,当最大儿子不再是最大时,走次大儿子即可。最终复杂度 \(\mathcal O(nh)\)

Code

#include<cstdio>
#include<algorithm>
#define N 50005
#define inf 0x3f3f3f3f
using namespace std;
int n,tot,ans=inf,res1,res2,cant,sum[N],w[N],f[N],dep[N],fat[N],mx1[N],mx2[N];
struct node {int to,next,head;}a[N<<1];
void add(int x,int y)
{
    a[++tot].to=y;a[tot].next=a[x].head;a[x].head=tot;
    a[++tot].to=x;a[tot].next=a[y].head;a[y].head=tot;
}
void dfs(int x,int fa)
{
    sum[x]=w[x];fat[x]=fa;
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        if (y==fa) continue;
        dep[y]=dep[x]+1;
        dfs(y,x);
        sum[x]+=sum[y];
        f[x]+=f[y]+sum[y];
        if (sum[y]>=sum[mx1[x]]) mx2[x]=mx1[x],mx1[x]=y;
        else if (sum[y]>sum[mx2[x]]) mx2[x]=y;
    }
}
void dp(int x,int g,int all,int &res)
{
    res=min(res,g);
    int y=mx1[x];
    if (y==cant||sum[mx2[x]]>sum[mx1[x]]) y=mx2[x];
    if (!y) return;
    if (2*sum[y]>all) dp(y,g+(all-sum[y])-sum[y],all,res);
}
void cut(int x,int fa)
{
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        if (y==fa) continue; 
        cant=y;
        for (int now=x;now;now=fat[now]) sum[now]-=sum[y];
        res1=res2=inf;
        dp(y,f[y],sum[y],res1);
        dp(1,f[1]-f[y]-sum[y]*dep[y],sum[1],res2);
        ans=min(ans,res1+res2);
        for (int now=x;now;now=fat[now]) sum[now]+=sum[y];
        cut(y,x);
    }
}
int main()
{
    scanf("%d",&n);
    for (int i=1;i<n;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
    }
    for (int i=1;i<=n;++i)
        scanf("%d",&w[i]);
    dfs(1,0);
    cut(1,0);
    printf("%d\n",ans);
    return 0;
}