AtCoder Beginner Contest 258 G Grid Card Game

发布时间 2023-06-04 22:07:03作者: zltzlt

洛谷传送门

AtCoder 传送门

\(b_i = \sum\limits_{j = 1}^m a_{i, j}, c_j = \sum\limits_{i = 1}^n a_{i, j}\)

首先考虑这样一个事情,就是对于 \(b_i \le 0\) 的行有没有可能被选。如果选了它:

  • 如果没有选任何列,选这一行肯定不优;
  • 如果选了若干列,根据题目的要求,这若干列与这一行重叠的部分只可能是非负数。考虑看成是先选列再选行,那选这一行必然会造成负贡献,雪上加霜。

所以我们得到只选 \(b_i > 0\) 的行一定不劣。类似地,只选 \(c_j > 0\) 的列也一定不劣。

这种行列网格的题,可以考虑转化成二分图一类的问题。考虑直接硬冲一个最大费用最大流。下面设 \((u, v, x, y)\) 表示一条 \(u \to v\),容量 \(x\),费用 \(y\) 的边。行作为左部点,列作为右部点。

  • 对于 \(S\) 到左部点的边,若 \(b_i \le 0\) 直接不管,否则我们希望有流量就产生 \(b_i\) 的贡献,而不管流量多少。因此连边 \((S, i, 1, b_i), (S, i, +\infty, 0)\)
  • 对于一个匹配,可以看成是这一行跟这一列都选。那么重叠的部分要减掉,即连边 \((i, j, 1, -a_{i, j})\)
  • 对于右部点到 \(T\) 的边,类似地,连边 \((j, T, 1, c_j), (j, T, +\infty, 0)\)

最后答案就是最大费用。

但是理论复杂度是 \(O(n^5)\) 的,实际也会 T 得很惨

类似 ABC214H 地考虑,不妨换种思路,计算最少损失(初始的答案是 \(\sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j\))。考虑转化成最小割,这样建图(用 \((u, v, x)\) 表示一条 \(u \to v\),容量为 \(x\) 的边):

  • 对于 \(b_i > 0\) 的行,连边 \((S, i, b_i)\)
  • 对于 \(c_j > 0\) 的列,连边 \((j, T, c_j)\)
  • 对于 \(a_{i, j} \ge 0\),连边 \((i, j, a_{i, j})\)
  • 对于 \(a_{i, j} < 0\),连边 \((i, j, +\infty)\)

这样保证了对于任意 \(S \to i \to j \to T\) 的路径:

  • 要么割掉 \(S \to i\)\(j \to T\) 的边,表示它们不同时被选;
  • 要么割掉 \(i \to j\) 的边,表示它们同时被选,但是因为有重叠,所以产生了 \(-a_{i, j}\) 的负贡献;
  • 如果 \(a_{i, j} < 0\),那么也能保证 \(i, j\) 不同时被选,因为 \(i \to j\) 的边权是 \(+\infty\),表示它不能被割,只能割 \(S \to i\)\(j \to T\) 的边。

最后答案就是 \(\sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j - \text{mincut} = \sum\limits_{i = 1}^n b_i + \sum\limits_{j = 1}^m c_j - \text{maxflow}\)

因为这是二分图,所以时间复杂度是 \(O(n^{2.5})\)

code
// Problem: G - Grid Card Game
// Contest: AtCoder - AtCoder Beginner Contest 259
// URL: https://atcoder.jp/contests/abc259/tasks/abc259_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 110;
const int maxm = 1000100;
const ll inf = 0x3f3f3f3f3f3f3f3fLL;

ll n, m, a[maxn][maxn], head[maxm], len = 1, ntot, S, T, b[maxn], c[maxn], id[maxn][2];
struct edge {
	ll to, next, cap, flow;
} edges[maxm];

inline void add_edge(ll u, ll v, ll c, ll f) {
	edges[++len].to = v;
	edges[len].next = head[u];
	edges[len].cap = c;
	edges[len].flow = f;
	head[u] = len;
}

struct Dinic {
	ll d[maxm], cur[maxm];
	bool vis[maxm];
	
	inline void add(ll u, ll v, ll c) {
		add_edge(u, v, c, 0);
		add_edge(v, u, 0, 0);
	}
	
	bool bfs() {
		for (int i = 1; i <= ntot; ++i) {
			d[i] = -1;
			vis[i] = 0;
		}
		queue<int> q;
		q.push(S);
		d[S] = 0;
		vis[S] = 1;
		while (q.size()) {
			int u = q.front();
			q.pop();
			for (int i = head[u]; i; i = edges[i].next) {
				edge &e = edges[i];
				if (e.cap > e.flow && !vis[e.to]) {
					vis[e.to] = 1;
					d[e.to] = d[u] + 1;
					q.push(e.to);
				}
			}
		}
		return vis[T];
	}
	
	ll dfs(ll u, ll a) {
		if (u == T || !a) {
			return a;
		}
		ll flow = 0, f;
		for (ll &i = cur[u]; i; i = edges[i].next) {
			edge &e = edges[i];
			if (d[e.to] == d[u] + 1 && (f = dfs(e.to, min(a, e.cap - e.flow))) > 0) {
				e.flow += f;
				edges[i ^ 1].flow -= f;
				flow += f;
				a -= f;
				if (!a) {
					break;
				}
			}
		}
		return flow;
	}
	
	ll solve() {
		ll flow = 0;
		while (bfs()) {
			for (int i = 1; i <= ntot; ++i) {
				cur[i] = head[i];
			}
			ll t = dfs(S, inf);
			flow += t;
		}
		return flow;
	}
} solver;

void solve() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			scanf("%lld", &a[i][j]);
			b[i] += a[i][j];
			c[j] += a[i][j];
		}
	}
	ll s = 0;
	for (int i = 1; i <= n; ++i) {
		if (b[i] > 0) {
			s += b[i];
		}
	}
	for (int i = 1; i <= m; ++i) {
		if (c[i] > 0) {
			s += c[i];
		}
	}
	S = ++ntot;
	T = ++ntot;
	for (int i = 1; i <= n; ++i) {
		id[i][0] = ++ntot;
	}
	for (int i = 1; i <= m; ++i) {
		id[i][1] = ++ntot;
	}
	for (int i = 1; i <= n; ++i) {
		if (b[i] > 0) {
			solver.add(S, id[i][0], b[i]);
		}
	}
	for (int i = 1; i <= m; ++i) {
		if (c[i] > 0) {
			solver.add(id[i][1], T, c[i]);
		}
	}
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			if (a[i][j] >= 0) {
				solver.add(id[i][0], id[j][1], a[i][j]);
			} else {
				solver.add(id[i][0], id[j][1], inf);
			}
		}
	}
	printf("%lld\n", s - solver.solve());
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}