ICPC2021Kunming G Find the Maximum 题解

发布时间 2023-12-28 22:45:17作者: Martian148

Question

Find the Maximum

给出一个树,每个点有一个权值 \(b_n\),求一条树上路径 \(V\),要求 \(\frac{\sum_{u\in V (-x^2+b_u x)}}{|V|}\) 最大,其中 \(x\) 是自己选择的一个树

Solution

先转化一下 \(\frac{\sum_{u\in V (-x^2+b_u x)}}{|V|}\), 得到

\[\frac{\sum_{u\in V (-x^2+b_u x)}}{|V|}\le \frac{1}{4}(\frac{\sum_{u\in V}b_u}{|V|})^2 \]

也就是要求 \(\frac{\sum_{u\in V} b_u}{|V|}\) 的最值

可以二分枚举最后的答案 \(\frac{\sum_{u\in V} b_u}{|V|}\ge mid\) 得到

\[\frac{\sum_{u\in V} b_u}{|V|}\ge mid \Leftrightarrow \sum_{u\in V} b_u-mid\times |V| \ge 0\Leftrightarrow \sum_{u\in V} (b_u-mid) \ge 0 \]

只需要找一条路径满足 \(\sum_{u\in V} (b_u-mid) \ge 0\) 就好了

用树形 DP 去找就好了,定义 \(c_u=b_u-mid\)

这里看到清华的一种比较好的 DP 方法,定义 \(dp[x]\) 表示从 \(x\) 节点开始的一段连续的最大和,\(ans[x]\) 表示 \(x\) 子树内的最大路径和

转移时

for(int v:G[u]){
	ans[x]=max(ans[x],dp[x]+dp[v]);
	dp[x]=max(dp[x],dp[v]+c[x]);
	ans[x]=max(ans[x],ans[v]);
}

最后只需要判断 \(F[1]\) 是否大于 \(1\) 即可


考虑另外一种做法

一条长度大于或等于 \(4\) 的路径都能拆成 \(2\) 条长度大于 \(1\) 的路径,而这 \(2\) 条长度大于 \(1\) 的路径中必然满足其中 \(1\) 条的平均值不大于拆分前的平均值,所以最后选择的路径长度肯定为 \(2\)\(3\)

Code

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
const double INF=1e5,eps=1e-12;
int n;
double ans;
vector<int> a,fa,vs;
vector<vector<int> > G;
vector<double> c,F;

inline int read(){
    int ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}
    while(ch<='9'&&ch>='0')ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}

void dfs(int x,int f){
    fa[x]=f;
    for(auto v:G[x]) if(v!=f) dfs(v,x);
    vs.push_back(x);
}


bool check(double mid){
    for(auto x:vs){
        F[x]=-1e100;
        c[x]=a[x]-mid;
        for(auto v:G[x]) if(v!=fa[x]){
            F[x]=max(F[x],c[x]+c[v]);
            c[x]=max(c[x],c[v]+a[x]-mid);
            F[x]=max(F[x],F[v]);
        }
    }
    return F[1]>=0;
}

int main(){
    n=read();
    
    a.assign(n+1,0);G.assign(n+1,vector<int>());c.assign(n+1,0); F.assign(n+1,0); fa.assign(n+1,0);

    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<n;i++){
        int u,v;u=read();v=read();
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1,0);
    
    check(3);

    double L=0,R=INF;
    double ans1=-INF;
    for(int i=1;i<=60;i++){
        double mid=(R+L)/2;
        if(check(mid)) {L=mid;}
        else R=mid;
    }
    ans1=L;
    
    for(int i=1;i<=n;i++) a[i]=-a[i];
    
    L=0,R=INF;
    double ans2=-INF;
    for(int i=1;i<=60;i++){
        double mid=(R+L)/2;
        if(check(mid)) {L=mid;}
        else R=mid;
    }
    ans2=L;

    ans=max(fabs(ans1),fabs(ans2));
    // printf("%.6lf %.6lf\n",ans1,ans2);
    printf("%.4lf\n",ans*ans/4);
    // cout<<clock()<<endl;
    return 0;
}