2023牛客暑期多校练营6 A-Tree 树上背包+并查集

发布时间 2023-08-28 13:01:21作者: touchfishman

2023牛客暑期多校练营6 A-Tree 树上背包+并查集

题目链接

题意:

 给出一棵树,节点为黑色或者白色,定义整棵树的贡献为,任意白点到任意黑点所经过路径上的最大边权之和,节点i原本颜色已给出,可以花费c[i]代价翻转节点i的颜色,问最大贡献是多少。

做法:

首先我们思考怎么处理最大边权的问题,怎么去确定某条路径上的最大边权。

  • 答案是类似kruscal的处理办法,我们可以将边权排序,从小到大枚举,用并查集来维护两颗子树
  • 设u,v为当前枚举到边的两端,即以u为根的子树和以v为根的子树,因为是从小到大枚举,u,v中不会存在比当前边更大的边,那么这条边的贡献即为: 边权*(子树u中白点数量*子树v中黑点数量+子树u中黑点数量*子树v中白点数量)
  • 在计算贡献时我们只需要知道子树的大小,不需要知道树长什么样,因此直接并查集维护大小即可。

接下来就是常规的优化树上背包,定义状态dp[i][j],i表示连通块的编号,j表示有几个黑点,转移方程如下,其中tmp为临时数组

    ll val  = w*(j*(sz[v]-k)+k*(sz[u]-j));
    //j为u中黑点数,k为v中黑点数,sz表示子树大小
    tmp[j+k] = max(tmp[j+k],dp[u][j]+dp[v][k]+val);

总的来说这题还是非常秒的,但想明白并查集后就非常简单了。

代码:

#define fst std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout << std::fixed << std::setprecision(20)
#define le "\n"
#define ll long long 
#include <bits/stdc++.h>
using namespace std;
const int N=3000+50;
const int mod=998244353;
int col[N],cost[N],sz[N];
ll dp[N][N];
int fa[N];
struct node{
    int u,v;
    ll w;
    bool operator < (const node & nxt) const{
        return w < nxt.w;
    }
};

int find(int x){
    return fa[x]==x ? x: fa[x] = find(fa[x]);
}

int main() {
    int n; cin>>n;
    for(int i=1;i<=n;i++) cin>>col[i];
    for(int i=1;i<=n;i++) cin>>cost[i];
    vector<node> e;    
    for(int i=1;i<n;i++){
        int u,v; ll w;
        cin>>u>>v>>w;
        e.push_back({u,v,w});
    }
    sort(e.begin(),e.end());

    for(int i=1;i<=n;i++){
        sz[i] = 1;
        fa[i] = i;
        dp[i][col[i]] = 0;
        dp[i][col[i]^1] = -cost[i];
    }

    for(auto [u,v,w]: e){
        u = find(u), v = find(v);
        vector<ll> tmp(sz[u]+sz[v]+1,-1e18);
        for(int j=0;j<=sz[u];j++){
            for(int k=0;k<=sz[v];k++){
                ll val  = w*(j*(sz[v]-k)+k*(sz[u]-j));
                tmp[j+k] = max(tmp[j+k],dp[u][j]+dp[v][k]+val);
            }
        }
        fa[v] = u;
        sz[u] += sz[v];
        for(int j=0;j<=sz[u];j++) dp[u][j] = tmp[j];
    }

    int x = find(1);
    cout<<*max_element(dp[x],dp[x]+1+n)<<le;
    return 0;
}