题解 [SDOI2016] 游戏

发布时间 2023-12-30 12:16:58作者: XSC062

可以看出来出题人很想出一道把李超和别的什么东西凑起来的题目,于是给了这么一个缝合怪。

https://www.luogu.com.cn/problem/P4069

符号有点混乱。比如箭头又可以表示路径又可以表示赋值,代入语境应该还是好理解的。


看到 \(a\times dis + b\) 就应激反应出来是李超了,看到 \(s\to t\) 又瞬间反应过来是树剖,但是树剖的 DFN 和 \(dis\) 没有直接关联,赛时想不到怎么做就跑路了。

实际上这个转化很板。因为这是条路径,我们还在树链上跳,每次跳过的一个链上的 DFN 是连续的,对应的 \(dis\) 也是连续的。

估计是打 T4 的子树问题打傻了,没想到这个。

所以我们就相当于是给一条重链上的某个连续区间加了一个斜率为 \(a\),截距为 \(b\) 加上 一坨东西 的线段。用李超维护即可。


感觉讲的不清不楚的,那就再讲讲。

我们要让李超上任意一个点 \(u\) 代表的 \(x\) 是个定值。一坨东西 维护了这条线段相对于 \(s\) 的偏移量。令 \(r\gets \text{LCA of } s \text { and } t\)\(R\) 表示整棵树的根,\(d(u,v)\gets \text {distance between } u \text { and } v\)

  1. 对于 \(s\to r\) 上的每个点 \(u\)

    \[\begin{aligned} val_u&=a\times d(s, u)+b\\ &=a\times[d(s,R)-d(u,R)]+b\\ &=-a\times d(u,R)+[a\times d(s,R)+b] \end{aligned} \]

    \(a\times d(s,R)+b\) 是和 \(u\) 无关的定值(这意味着可以在同一个询问的树剖时直接线段树),\(d(u,R)\) 是只和 \(u\) 相关的值(这意味着对于任意询问都成立)。令斜率为 \(-a\),截距为 \(a\times d(s,R)+b\),李超上任意点 \(u\) 代表 \(x\)\(d(x,R)\),维护如此一条线段即可。

  2. 对于 \(r\to t\) 上的每个点 \(v\)

    \[\begin{aligned} val_v&=a\times d(s,v)+b\\ &= a\times [d(r,v)+d(s,r)]+b\\ &=a\times [d(v,R)-d(r,R)+d(s,r)]+b\\ &=a\times d(v,R)+[-a\times d(r,R)+a\times d(s,r)+b] \end{aligned} \]

    \(-a\times d(r,R)+a\times d(s,r)+b\) 是和 \(v\) 无关,只和询问中固定的 \(s,r,a,b\) 有关的定值;\(d(v,R)\) 是只和 \(v\) 相关的值且和上一个 case 里的 \(d(u,R)\) 同构,也用李超这么维护即可。

用李超维护把原问题转化为每次向若干重链上连续区间插入线段,求最低交点的问题。

注意到区间查询,李超需要加一个 pushup。具体怎么去操作呢?在加线段的时候除编号外新增一个变量维护当前区间内最低交点;我们就 pushup 这个东西。然后查询的时候就是在散区间的时候和原来一样查,整区间就在散区间答案的基础上再和当前区间的整体最低交点比一个 min。

#define int long long
namespace XSC062 {
using namespace fastIO;
const int maxk = 25;
const int maxn = 1e6 + 5;
const int maxm = 1e6 + 5;
const int inf = 123456789123456789;
//#define DEBUG

#ifdef DEBUG
#define Z(x) x
#else
#define Z(x)
#endif
#define lt (p << 1)
#define rt (lt | 1)
struct __ { int k, b; };
struct _ { int l, r, u, d; };
struct ____ {
	int v, w;
	____() {}
	____(int v1, int w1) {
		v = v1, w = w1; 
	}
};
__ a[maxm];
_ t[maxn << 2];
int tot, si, now;
int f[maxn][maxk];
int dis[maxn], dep[maxn];
int dfn[maxn], tab[maxn];
std::vector<____> g[maxn];
int n, m, ty, x, y, w, k, b;
int siz[maxn], top[maxn], son[maxn];
int max(int x, int y) {
	return x > y ? x : y;
}
int min(int x, int y) {
	return x < y ? x : y;
}
void swap(int &x, int &y) {
	x ^= y ^= x ^= y;
	return;
}
void DFS1(int x) {
	siz[x] = 1;
	for (auto i : g[x]) {
		if (i.v == f[x][0]) continue;
		f[i.v][0] = x;
		for (int j = 1; j <= si; ++j)
			f[i.v][j] = f[f[i.v][j - 1]][j - 1];
		dep[i.v] = dep[x] + 1;
		dis[i.v] = dis[x] + i.w;
		DFS1(i.v);
		if (siz[i.v] > siz[son[x]])
			son[x] = i.v;
		siz[x] += siz[i.v];
	}
	return;
}
void DFS2(int x, int t) {
	top[x] = t;
	dfn[x] = ++now, tab[now] = x;
	if (son[x]) DFS2(son[x], t);
	for (auto i : g[x]) {
		if (i.v == f[x][0] || i.v == son[x])
			continue;
		DFS2(i.v, i.v);
	}
	return;
}
void bld(int p, int l, int r) {
	t[p].l = l, t[p].r = r, t[p].d = inf;
	if (l == r) return;
	int mid = (l + r) >> 1;
	bld(lt, l, mid), bld(rt, mid + 1, r);
	return;
}
int getv(int id, int x) {
	if (!id) return inf;
	Z(printf("get (%lld, %lld) = %lld\n",
		id, x, dis[tab[x]] * a[id].k + a[id].b));
	return dis[tab[x]] * a[id].k + a[id].b;
}
void pushup(int p) {
	if (t[p].l == t[p].r) return;
	t[p].d = min(t[p].d, min(t[lt].d, t[rt].d));
	Z(printf("[%lld, %lld]: pushup to %lld\n",
		t[p].l, t[p].r, t[p].d));
	return;
}
void chg(int p, int id) {
	t[p].u = id;
	Z(int tmp = t[p].d);
	t[p].d = min(getv(id, t[p].l),
				 getv(id, t[p].r));
	Z(printf("[%lld, %lld]: %lld -> %lld\n",
				t[p].l, t[p].r, tmp, t[p].d)); 
	return;
}
void upd(int p, int id) {
	if (!t[p].u) {
		chg(p, id), pushup(p);
		return;
	}
	int mid = (t[p].l + t[p].r) >> 1;
	int v1 = getv(t[p].u, mid),
		v2 = getv(id, mid);
	if (v2 < v1) swap(t[p].u, id);
	v1 = getv(t[p].u, t[p].l);
	v2 = getv(id, t[p].l);
	if (v2 < v1) upd(lt, id);
	v1 = getv(t[p].u, t[p].r);
	v2 = getv(id, t[p].r);
	if (v2 < v1) upd(rt, id);
	chg(p, t[p].u);
	pushup(p);
	return;
}
void add(int p, int l, int r, int id) {
	if (l <= t[p].l && t[p].r <= r) {
		upd(p, id);
		return;
	}
	int mid = (t[p].l + t[p].r) >> 1;
	if (l <= mid) add(lt, l, r, id);
	if (r > mid) add(rt, l, r, id);
	pushup(p);
	return;
}
int ask(int p, int l, int r) {
	l = max(l, t[p].l);
	r = min(r, t[p].r);
	int res = min(getv(t[p].u, l),
				  getv(t[p].u, r));
	Z(printf("[%lld, %lld]: res = %lld\n",
		t[p].l, t[p].r, res));
	if (l <= t[p].l && t[p].r <= r)
		return min(res, t[p].d);
	int mid = (t[p].l + t[p].r) >> 1;
	if (l <= mid) res = min(res, ask(lt, l, r));
	if (r > mid) res = min(res, ask(rt, l, r));
	Z(printf("[%lld, %lld]: res = %lld\n",
		t[p].l, t[p].r, res));
	return res;
}
void add(int x, int y, int w) {
	g[x].push_back(____(y, w));
	return;
}
void ins(int l, int r, int k, int b) {
	a[++tot].k = k, a[tot].b = b;
	add(1, l, r, tot);
	return;
}
int getLCA(int x, int y) {
	if (dep[x] < dep[y]) swap(x, y);
	for (int i = si; ~i; --i) {
		if (dep[f[x][i]] >= dep[y])
			x = f[x][i];
	}
	if (x == y) return x;
	for (int i = si; ~i; --i) {
		if (f[x][i] != f[y][i])
			x = f[x][i], y = f[y][i];
	}
	return f[x][0];
}
void inst(int s, int t, int k, int b) {
	int r = getLCA(s, t), u = s;
	while (top[u] != top[r]) {
		ins(dfn[top[u]], dfn[u],
					-k, k * dis[s] + b);
		u = f[top[u]][0];
	}
	ins(dfn[r], dfn[u], -k, k * dis[s] + b);
	u = t;
	int d = dis[s] - dis[r];
	while (top[u] != top[r]) {
		ins(dfn[top[u]], dfn[u],
				k, -k * dis[r] + k * d + b);
		u = f[top[u]][0];
	}
	ins(dfn[r], dfn[u],
			k, -k * dis[r] + k * d + b);
	return;
}
int qry(int x, int y) {
	int res = inf;
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		Z(printf("ask %lld -> %lld: %lld\n",
			x, top[x],
			ask(1, dfn[top[x]], dfn[x])));
		res = min(res,
			ask(1, dfn[top[x]], dfn[x]));
		x = f[top[x]][0];
	}
	if (dep[x] < dep[y]) swap(x, y);
	Z(printf("ask %lld -> %lld: %lld\n",
		x, y, ask(1, dfn[y], dfn[x])));
	res = min(res, ask(1, dfn[y], dfn[x]));
	return res;
}
int main() {
//	freopen("game1.in", "r", stdin);
	read(n), read(m);
	si = log(n) / log(2.0);
	for (int i = 1; i < n; ++i) {
		read(x), read(y), read(w);
		add(x, y, w), add(y, x, w);
	}
	bld(1, 1, n);
	dep[1] = 1, DFS1(1), DFS2(1, -1);
	Z(for (int i = 1; i <= n; ++i)
		printf("dfn[%lld] = %lld\n", i, dfn[i]));
	while (m--) {
		read(ty), read(x), read(y);
		if (ty == 1) {
			read(k), read(b);
			inst(x, y, k, b);
		}
		else print(qry(x, y), '\n');
	}
	return 0;
}
} // namespace XSC062
#undef int