【差分约束】P7624 [AHOI2021初中组] 地铁 题解

发布时间 2023-10-07 13:32:50作者: Pengzt

P7624

\(d_i\) 表示 \(1\) 号车站到 \(i\) 号车站的距离,\(len\) 表示环形地铁的总长度。

考虑题中给的条件:

\(type_i = 0\) 时,若 \(u_i < v_i\),即可表示为 \(d_{v_i} - d_{u_i} \ge L_i \iff d_{u_i} \le d_{v_i} - L_i\)。否则就是要经过 \(n\) 号点和 \(1\) 号点,则可表示为 \(d_{u_i} - d_{v_i} \le len - L_i\),即 \(d_{u_i} \le d_{v_i} + len - L_i\)

同理,当 \(type_i = 2\) 时,若 \(u_i < v_i\),有 \(d_{v_i} \le d_{u_i} + L_i\)
否则有 \(d_{u_i} \le d_{v_i} + L_i - len\)

若将 \(len\) 看作常数,容易发现这些不等式可以用差分约束解决。

发现当图中有负环时,根据 \(len\) 的系数和来判断这个环的权值的关系,显然具有单调性,使用二分解决即可。

推荐写 Bellman-Ford 判负环,代码简便且跑的较快。代码中二分的细节参考 meyi。

这里放一份 SPFA 的代码。

代码:

const int N = 510;
const ll inf = 1e12;
int n, m;
int vis[N], cnt[N], pre[N];
ll dis[N];
int head[N], cnte;
struct edge {
	int u, v, w, k, nxt;
} e[N << 1];

void adde(int u, int v, int w, int k) {
	e[++cnte].nxt = head[u];
	e[cnte].u = u;
	e[cnte].v = v;
	e[cnte].w = w;
	e[cnte].k = k;
	head[u] = cnte;
}
int check(ll mid) {
	queue<int> q;
	for (int i = 1; i <= n; i++) dis[i] = INF, vis[i] = 0;
	memset(cnt, 0, sizeof(cnt));
	memset(pre, 0, sizeof(pre));
	q.push(1);
	dis[1] = 0, vis[1] = cnt[1] = 1;
	while (!q.empty()) {
		int u = q.front(); q.pop();
		vis[u] = 0;
		for (int i = head[u], v; i != 0; i = e[i].nxt)
			if (dis[v = e[i].v] > dis[u] + e[i].w + e[i].k * mid) {
				dis[v] = dis[u] + e[i].w + e[i].k * mid;
				pre[v] = i;
				if (vis[v] == 0) {
					vis[v] = 1;
					q.push(v);
					if (++cnt[v] > n) {
						int x = u;
						for (int j = 1; j <= n; j++) x = e[pre[x]].u;
						int sk = e[pre[x]].k;
						for (int y = e[pre[x]].u; y != x; y = e[pre[y]].u) sk += e[pre[y]].k;
						return sk > 0 ? 1 : -1;
					}
				}
			}
	}
	return 0;
}
ll get_l() {
	ll l = 0, r = inf;
	while (l < r) {
		ll mid = (l + r) / 2;
		int tmp = check(mid);
		if (tmp == 0) r = mid;
		else if (tmp == 1) l = mid + 1;
		else r = mid - 1;
	}
	return l;
}
ll get_r() {
	ll l = 0, r = inf;
	while (l < r) {
		ll mid = (l + r + 1) / 2;
		int tmp = check(mid);
		if (tmp == 0) l = mid;
		else if (tmp == 1) l = mid + 1;
		else r = mid - 1;
	}
	return l;
}

int main() {
	ios
	cin >> n >> m;
	for (int i = 1; i < n; i++) adde(i + 1, i, -1, 0);
	adde(1, n, -1, 1);
	for (int i = 1, opt, u, v, w; i <= m; i++) {
		cin >> opt >> u >> v >> w;
		if (opt == 1) {
			if (u < v) adde(v, u, -w, 0);
			else adde(v, u, -w, 1);
		} else {
			if (u < v) adde(u, v, w, 0);
			else adde(u, v, w, -1);
		}
	}
	ll r = get_r();
	if (r >= inf) {
		cout << -1 << "\n";
		return 0;
	}
	ll l = get_l();
	cout << r - l + 1 << "\n";
	return 0;
}