AtCoder Grand Contest 049 E Increment Decrement

发布时间 2023-07-19 21:35:09作者: zltzlt

洛谷传送门

AtCoder 传送门

好题。同时考查了 slope trick 和选手的计数能力,不愧是 AGC E。


先考虑问题的第一部分。

你现在有一个初始全为 \(0\) 的序列 \(b\)。你每次可以给 \(b\) 单点 \(\pm 1\),代价为 \(1\),或区间 \(\pm 1\),代价为 \(m\)。求把 \(b\) 变成给定序列 \(a\) 的最小代价。

考虑先执行区间加减操作,设操作完的序列为 \(c\),那么之后单点加减操作的代价就是 \(\sum\limits_{i = 1}^n |c_i - a_i|\)。对 \(c\) 差分,可以发现最少的区间操作次数就是差分后数组的正数项的值的和。所以总代价就是 \(\sum\limits_{i = 1}^n |c_i - a_i| + m \sum\limits_{i = 1}^n \max(c_i - c_{i - 1}, 0)\)

考虑 dp,设 \(f_{i, j}\)\(c_i = j\) 的最小代价(显然最优时 \(c_i \ge 0\))。初始有 \(\forall i \ge 0, f_{0, i} = mi\)。转移枚举 \(c_{i - 1}\),可得:

\[f_{i, j} = |a_i - j| + \min\limits_k f_{i - 1, k} + m \max(j - k, 0) \]

看到转移式里面有计算绝对值,联想到 slope trick。发现 \(f_i\) 是凸函数,并且每段的斜率在 \([-1, m + 1]\) 之间。容易归纳,我们初始的图像是一条斜率为 \(m\) 的射线,然后我们先进行 \(f_{i, j} \gets \min\limits_k f_{i - 1, k} + m \max(j - k, 0)\) 的更新,这个更新产生的影响就是,斜率为 \(-1\) 的段被拍平(斜率变成 \(0\)),斜率为 \(m + 1\) 的段斜率变成 \(m\)。然后我们给这个图像整体加上 \(|a_i - j|\) 的分段函数,也就是把 \(\le a_i\) 的段斜率减少 \(1\)\(> a_i\) 的段斜率增加 \(1\)

考虑运用 slope trick,用 multiset 维护这个分段函数。回忆一下在 slope trick 中,我们维护的是分段函数图像变化的断点,并且一个断点代表斜率增加 \(1\)。那么在这题中,我们初始往 multiset 添加 \(m\)\(0\) 代表断点为 \(0\) 且斜率为 \(m\),然后当 \(i = 1\) 时,因为图像不存在斜率为 \(-1\)\(m + 1\) 的段,因此我们不需要进行删除操作,直接添加两个 \(a_i\) 表示 \(a_i\) 处斜率变化为 \(2\)。当 \(i > 1\) 时,我们先删除 multiset 中的最小值和最大值表示这两个断点被拍平了,不存在了,再添加两个 \(a_i\)

至于统计答案,我们在每次添加两个 \(a_i\) 后统计,此时 multiset 中的最小值就是斜率 \(-1 \to 0\) 变化的断点,因此我们把答案累加 \(a_i - p\),其中 \(p\)multiset 中的最小值(不用加绝对值是因为此时加入 \(a_i\)\(p\) 一定 \(\le a_i\))。

于是我们现在可以 \(O(n \log n)\) 求解这个问题了。


考虑问题的第二部分,即统计所有可能的 \(a_i\) 对应的答案之和。

考虑我们上面的算法流程。

初始往 multiset 中添加 \(m\)\(0\)\(i = 1\) 时,往 multiset 中添加 \(2\)\(a_i\),然后计算 \(a_i - p\),其中 \(p\)multiset 中最小值;\(i > 1\) 时,先删除 multiset 中的最小值和最大值,然后往其中添加 \(2\)\(a_i\),再计算 \(a_i - p\)

\(a_i\) 部分的贡献系数是容易统计的,就是 \(K^{n - 1}\)(选定 \(a_i\) 后其他的可以任意选,都能产生贡献)。问题还剩下统计所有 \(p\) 的和。

我们枚举 \(nK\) 个可能的 \(p\),分别计算最小值 \(< p\) 和最小值 \(\le p\) 的方案数,二者差分一下就是 \(p\) 的贡献系数。

直接做不好维护 multiset,但是如果 \(a_i \in \{0, 1\}\),我们就能维护 \(1\) 的个数来表示整个 multiset 了(非常经典的套路:任意值转 \(01\))。我们不妨让 \(a_i \gets [a_i \ge p]\),这样最小值 \(< p\) 等价于最小值 \(= 0\)

发现只有 \(1\) 的个数是 \(m + 2\) 时,multiset 中的最小值才不是 \(0\)。因此考虑一个容斥,总方案数减去最小值为 \(1\) 的方案数。总方案数显然是 \(nK^n\)(一共 \(n\) 轮,\(a\) 数组有 \(K^n\) 种产生方式),如果我们设 \(f_{i, j}\) 为进行到第 \(i\) 轮,multiset 中有 \(j\)\(1\) 的方案数,那么最小值为 \(1\) 的方案数就是 \(\sum\limits_{i = 1}^n K^{n - i} f_{i, m + 2}\)(乘上 \(K^{n - i}\) 是因为第 \(i + 1 \sim n\) 轮中 \(a_i\) 的选择都不影响第 \(i\) 轮的最小值是 \(1\))。

现在考虑 \(f_{i - 1} \to f_i\)。对于一个 \(f_{i - 1, j}\),我们先进行删除操作,即 \(j \gets j - [j > 0] - [j = m + 2]\),然后我们考虑选择 \(a_i\),设 \(t = \sum\limits_{j = 1}^K [b_{i, j} \ge p]\),也就是能选的 \(1\) 的个数,那么 \(f_{i, j} \gets (K - t) f_{i - 1, j}\)\(f_{i, j + 2} \gets t f_{i - 1, j}\)。记得特判 \(i = 1\)


至此我们终于以 \(O(n^3K)\) 的时间复杂度完成了这题。

code
// Problem: E - Increment Decrement
// Contest: AtCoder - AtCoder Grand Contest 049
// URL: https://atcoder.jp/contests/agc049/tasks/agc049_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 55;
const ll mod = 1000000007;

ll n, m, K, a[maxn][maxn], lsh[maxn * maxn], tot, pw[maxn], f[maxn][maxn];

inline void upd(ll &x, ll y) {
	x += y;
	(x >= mod) && (x -= mod);
}

inline ll calc(ll x) {
	mems(f, 0);
	ll cnt = 0;
	for (int i = 1; i <= K; ++i) {
		cnt += (a[1][i] >= x);
	}
	f[1][0] = K - cnt;
	f[1][2] = cnt;
	for (int i = 2; i <= n; ++i) {
		cnt = 0;
		for (int j = 1; j <= K; ++j) {
			cnt += (a[i][j] >= x);
		}
		for (int j = 0; j <= m + 2; ++j) {
			if (!f[i - 1][j]) {
				continue;
			}
			int nj = j - (j > 0) - (j == m + 2);
			upd(f[i][nj], f[i - 1][j] * (K - cnt) % mod);
			upd(f[i][nj + 2], f[i - 1][j] * cnt % mod);
		}
	}
	ll ans = n * pw[n] % mod;
	for (int i = 1; i <= n; ++i) {
		ans = (ans - f[i][m + 2] * pw[n - i] % mod + mod) % mod;
	}
	return ans;
}

void solve() {
	scanf("%lld%lld%lld", &n, &m, &K);
	pw[0] = 1;
	for (int i = 1; i <= n; ++i) {
		pw[i] = pw[i - 1] * K % mod;
	}
	ll ans = 0;
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= K; ++j) {
			scanf("%lld", &a[i][j]);
			lsh[++tot] = a[i][j];
			ans = (ans + a[i][j]) % mod;
		}
	}
	sort(lsh + 1, lsh + tot + 1);
	tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
	ans = ans * pw[n - 1] % mod;
	for (int i = 1; i <= tot; ++i) {
		ans = (ans - lsh[i] * (calc(lsh[i] + 1) - calc(lsh[i]) + mod) % mod + mod) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}