CF838B Diverging Directions

发布时间 2023-07-21 08:35:03作者: Ender_32k

好像挺神奇的,也可能是我菜。

以下称前 \(n-1\) 条边为「树边」,因为它们组成一棵树;后 \(n-1\) 条边为「回边」,因为它们由树节点回到根。

就是对于一个询问,如果 \(v\)\(u\) 的子树内,发现无论如何答案都要包括 \(u\to v\) 的只经过树边的路径。那么只走这条路径一定是最优的,直接维护 \(u\) 到根仅经过树边的距离 \(d_{u}\),答案就是 \(d_v-d_u\)

否则必然是 \(u\) 沿着树边到子树内的点 \(t\),然后从 \(t\) 沿回边到达 \(1\) 节点,再直接沿树边走向 \(v\)。那么这条路径的长度就是 \((d_t-d_u)+a_t+d_v\)\(a_u\) 表示 \(u\)\(1\) 的回边长度。

\(d_u+d_v\) 是定值,线段树维护 \(\min\limits_{t\in \text{sub(u)}}\{a_t+d_t\}\) 即可。修改边权相当于单点改或子树加,也可以用线段树维护。

代码很好写,跑得还挺快。

#include <bits/stdc++.h>
#define int long long
using namespace std;

namespace vbzIO {
	char ibuf[(1 << 20) + 1], *iS, *iT;
	#if ONLINE_JUDGE
	#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
	#else
	#define gh() getchar()
	#endif
	#define pc putchar
	#define pi pair<int, int>
	#define tu3 tuple<int, int, int>
	#define tu4 tuple<int, int, int, int>
	#define mp make_pair
	#define mt make_tuple
	#define fi first
	#define se second
	#define pb push_back
	#define ins insert
	#define era erase
	#define clr clear
	inline int read () {
		char ch = gh();
		int x = 0;
		bool t = 0;
		while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
		while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
		return t ? ~(x - 1) : x;
	}
	inline void write(int x) {
		if (x < 0) {
			x = ~(x - 1);
			putchar('-');
		}
		if (x > 9)
			write(x / 10);
		putchar(x % 10 + '0');
	}
}
using vbzIO::read;
using vbzIO::write;

const int maxn = 4e5 + 200;
const int inf = 1e18;
struct seg { int tg, vl; } tr[maxn << 2];
int n, m, dfc, d[maxn], b[maxn], id[maxn], sz[maxn];
vector<pi> t[maxn];
pi ed[maxn];

void dfs(int u, int fa) {
    id[u] = ++dfc, sz[u] = 1;
    for (auto p : t[u]) {
        int v = p.fi, w = p.se;
        d[v] = d[u] + w - b[u] + b[v];
        dfs(v, u), sz[u] += sz[v];
    }
}

#define ls x << 1
#define rs x << 1 | 1
#define mid ((l + r) >> 1)
void pushup(int x) { tr[x].vl = min(tr[ls].vl, tr[rs].vl); }
void pushtg(int x, int c) { tr[x].vl += c, tr[x].tg += c; }
void pushdown(int x) {
    if (!tr[x].tg) return;
    pushtg(ls, tr[x].tg), pushtg(rs, tr[x].tg);
    tr[x].tg = 0;
}
void update(int l, int r, int s, int t, int c, int x) {
    if (s <= l && r <= t) return pushtg(x, c);
    pushdown(x);
    if (s <= mid) update(l, mid, s, t, c, ls);
    if (t > mid) update(mid + 1, r, s, t, c, rs);
    pushup(x);
}

int query(int l, int r, int s, int t, int x) {
    if (s <= l && r <= t) return tr[x].vl;
    pushdown(x);
    int res = inf;
    if (s <= mid) res = min(res, query(l, mid, s, t, ls));
    if (t > mid) res = min(res, query(mid + 1, r, s, t, rs));
    return res;
}

signed main() {
	n = read(), m = read();
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        u = read(), v = read(), w = read();
        t[u].pb(mp(v, w)), ed[i] = mp(v, w);
    }
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        u = read(), v = read(), w = read();
        ed[i + n - 1] = mp(u, w), b[u] = w;
    }
    dfs(1, 0);
    for (int i = 2; i <= n; i++) update(1, n, id[i], id[i], d[i], 1);
    while (m--) { 
        int op = read();
        if (op == 2) {
            int u = read(), v = read();
            int duv = (query(1, n, id[v], id[v], 1) - b[v]) - (query(1, n, id[u], id[u], 1) - b[u]);
            if (id[v] >= id[u] && id[v] <= id[u] + sz[u] - 1) write(duv), puts("");
            else write(duv + query(1, n, id[u], id[u] + sz[u] - 1, 1)), puts("");
        } else {
            int p = read(), w = read();
            if (p <= n - 1) update(1, n, id[ed[p].fi], id[ed[p].fi] + sz[ed[p].fi] - 1, w - ed[p].se, 1), ed[p].se = w;
            else update(1, n, id[ed[p].fi], id[ed[p].fi], w - b[ed[p].fi], 1), b[ed[p].fi] = w;
        }
    }
	return 0;
}