P4381 [IOI2008] Island (求基环树直径)

发布时间 2023-08-09 15:34:16作者: Morning_Glory

也许更好的阅读体验

\(\mathcal{Description}\)
给一个基环树森林,求每棵树的直径的和,基环树的直径定义为,从一个点出发只能走到没走过的点(即一个环不能把所有边都选),所经过的路径边权和最大

\(\mathcal{Solution}\)
这题不是很难,只是略微麻烦,代码也好打,一遍就过了
对于每棵基环树独立搞

先找到这棵树的环,这个过程分为两步,第一步找到换上的一条边,可以利用类似网络流的存边方法给走过的边记录是否走过,给走过的点标记是否走过,当走一条没走过的边可以走到走过的点,说明这条边是环上的边,第二步从环上的点出发如果走一条没走过的边发现可以走到环上的点,说明这个点就在环上

找到环后,基环树会是这样的:

在这里插入图片描述
以下称环上的点的子树指的就是,以环上一点作为根节点,屏蔽要经过环上其它点的点所得子树(就是图中画的那样)
考虑直径的形式,第一种是没有经过环上的点,也就是在图中某个环上节点的子树中,第二种是经过环上的点
对于第一种,我们对环上每个点的子树求一边树的直径就可以了
对于第二种情况,这时直径会是,两个环上的点的子树中最长的链的长度相加再加上这两点在环上的最大距离,我们将每个环上的点的子树最长链的长度看作点权,就变成求环上两点的最大距离加上两点点权的和最大
这个过程并不难整,设环上所有边权加起来为\(sum\),用\(d_i\)表示该点逆时针走一步的边权,再设定一个起始点做一个前缀和记为\(s_i\),则两点逆时针路径即为\(s_i-s_j\),顺时针路径即为\(sum-(s_i-s_j)\),记\(mx_i\)表示该点点权(子树中最长链长度),则就是要求\(max\{d_i+d_j+dis(i,j)\}\),其中由于\(dis(i,j)\)可能是顺时针的也可能是逆时针的,干脆对顺时针和逆时针分别求一次,这样\(dis(i,j)=s_i-s_j\),就是要求\(max\{d_i+d_j+s_i-s_j\}\),这个几下前面最大的\(d_j-s_j\)即可

\(\mathcal{Code}\)

#include <cstdio>
#include <algorithm>
#include <deque>
#define ll long long
using namespace std;
const int maxn = 1e6 + 10;
const int maxm = 2e6 + 10;
int n, cnt = 1, now, tot, tim;
int head[maxn], nxt[maxm], to[maxm], w[maxm], q[maxn];
int vis[maxn], used[maxm];
ll ans, sum, mxl;
ll mx1[maxn], mx2[maxn], d[maxn];
void add (int u, int v, int val) { nxt[++cnt] = head[u], head[u] = cnt, to[cnt] = v, w[cnt] = val; }
void dfs (int u)//找环上的一条边放在now里
{
    vis[u] = tim;
    for (int e = head[u]; e; e = nxt[e])
        if (used[e] != tim) {
            used[e] = used[e ^ 1] = tim;
            if (vis[to[e]] != tim)    dfs(to[e]);
            else    now = e;
        }
}
bool dfs2 (int u)//将环上的点放进q 需要从环上的点出发
{
    vis[u] = tim;
    bool flag = false;
    for (int e = head[u]; e; e = nxt[e])
        if (used[e] != tim) {
            used[e] = used[e ^ 1] = tim;
            if (vis[to[e]] == tim)  q[++tot] = u, d[tot] = w[e], flag = true;
            else if (dfs2(to[e]))   q[++tot] = u, d[tot] = w[e], flag = true;
        }
    return flag;
}
void solve (int u)//求出该点为根的子树最长链和直径
{
    vis[u] = tim;
    for (int e = head[u]; e; e = nxt[e])
        if (vis[to[e]] != tim) {
            solve(to[e]);
            ll len = mx1[to[e]] + w[e];
            if (len > mx2[u])   mx2[u] = len;
            if (mx2[u] > mx1[u])    swap(mx1[u], mx2[u]);
        }
    mxl = max(mxl, mx1[u] + mx2[u]);
}
void dp ()
{
    ++tim;
    sum = mxl = 0;
    for (int i = 1; i <= tot; ++i)  vis[q[i]] = tim, sum += d[i];
    for (int i = 1; i <= tot; ++i)  solve(q[i]);
    for (int i = 2; i <= tot; ++i)  d[i] += d[i - 1];
    int p = 1;
    for (int i = 2; i <= tot; ++i) {
        mxl = max(mxl, mx1[q[i]] + mx1[q[p]] + d[i] - d[p]);
        if (mx1[q[i]] - d[i] > mx1[q[p]] - d[p])  p = i;
    }
    p = 1;
    for (int i = 2; i <= tot; ++i) {
        mxl = max(mxl, mx1[q[i]] + mx1[q[p]] + sum - d[i] + d[p]);
        if (mx1[q[i]] + d[i] > mx1[q[p]] + d[p])  p = i;
    }
    ans += mxl;
}
int main ()
{
    scanf("%d", &n);
    for (int i = 1, v, val; i <= n; ++i) {
        scanf("%d%d", &v, &val);
        add(i, v, val), add(v, i, val);
    }
    for (int i = 1; i <= n; ++i)
        if (!vis[i]) {
            ++tim;
            dfs(i);
            ++tim, tot = 0;
            dfs2(to[now]);
            dp();
        }
    printf("%lld\n", ans);
    return 0;
}

如有哪里讲得不是很明白或是有错误,欢迎指正
如您喜欢的话不妨点个赞收藏一下吧