[AGC049E] Increment Decrement

发布时间 2023-06-09 16:57:44作者: Smallbasic

它改变了计数——AGC传。

先考虑只给你一个序列该如何求答案。假设执行完所有区间操作之后的序列为 \(p_i\),那么区间操作的贡献是 $ c \sum \max p_{i+1}-p_i,0$,单点操作的贡献是 \(\sum |p_i-a_i|\)

考虑 dp,设 \(f_{i,j}\) 表示 \(a_i\) 变到 \(j\) 上式的最小代价,初始 \(f_{0,j}=cj\),答案为 \(f_{n+1,0}\), 那么有转移:

\[f_{i,j}=\min_k f_{i-1,k}+c \cdot \max\{j-k,0\}+|j - a_i| \]

我们发现它是由几个凸函数相加起来的,我们可以用 slope trick 优化它(slope trick 详见 CF713C)。具体的,集合中有 \(c\) 个 0。考虑加入 \(|x-a_i|\) 之后折线的变化。初始折线是前面一段斜率为 \(-1\) 的直线,中间斜率递增,到最后斜率为 \(c+1\) 的一段直线。加入我们将原本斜率为 \(-1\) 处的直线斜率变成 \(0\),原本斜率为 \(c+1\) 的直线变为 \(c\),然后再将 \(<a_i\) 的地方斜率 \(-1\)\(>a_i\) 的地方斜率 \(+1\)。对应到集合上就是我们 \(\textbf{先加入}\) 两个 \(a_i\),再弹出最大最小值(因为 \(a_i\) 可能就是最值)。考虑本次对答案的贡献,设最小值位置为 \(p\),则这次操作使得答案的最小值增加了 \(a_i-p\)

再考虑计数,首先 \(a_i\) 的和是好计算的,我们只需要考虑 \(-p\) 的部分,容易想到统计每个数作为最小值被删除了多少次。差分一下转化为求 \(<x\) 的数被作为最小值删除了多少次。将 \(i\) 看成 \([i\ge x]\),统计 \(0\) 被删了几次,容易发现只有全是 \(1\) 的时候 \(0\) 才不会被删除。设 \(g_{i,j}\) 表示考虑到前 \(i\) 个数,集合中有 \(j\)\(1\) 的方案数,答案即为 \(nk^n - \sum g_{i,c+2}\cdot k^{n-i}\)

考虑 \(g\) 的求法,假设第 \(i\) 个数有 \(c_1\)\(1\)\(c_0\)\(0\),那么令 \(t=j-[j>0]-[j==c+2]\) 表示有 \(j\)\(1\) 时去掉最大最小还剩多少个 \(1\),那我们有:\(c_1 g_{i-1,j}\rightarrow g_{i,t+2},c_0 g_{i-1,j}\rightarrow g_{i,t}\)

#include <bits/stdc++.h>

using namespace std;

const int N = 55, mod = 1e9 + 7;

inline int read() {
	register int s = 0, f = 1; register char ch = getchar();
	while (!isdigit(ch)) f = (ch == '-' ? -1 : 1), ch = getchar();
	while (isdigit(ch)) s = (s * 10) + (ch & 15), ch = getchar();
	return s * f;
}

int n, c, k, pw[N], b[N][N], val[N][N], f[N][N];
set<int> st;

inline int calc(int x) {
	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= k; ++j)
			val[i][j] = (b[i][j] >= x);
	memset(f, 0, sizeof f); f[0][0] = 1;
	for (int i = 1; i <= n; ++i) {
		int c1 = 0;
		for (int j = 1; j <= k; ++j)
			c1 += val[i][j];
		for (int j = 0; j <= c + 2; ++j) {
			int t = j - (j > 0) - (j == c + 2);
			f[i][t + 2] += 1ll * c1 * f[i - 1][j] % mod;
			if (f[i][t + 2] >= mod) f[i][t + 2] -= mod;
			f[i][t] += 1ll * (k - c1) * f[i - 1][j] % mod;
			if (f[i][t] >= mod) f[i][t] -= mod;
		}
	}
	int res = 1ll * n * pw[n] % mod;
	for (int i = 1; i <= n; ++i) {
		res -= 1ll * f[i][c + 2] * pw[n - i] % mod;
		if (res < 0) res += mod;
	} return res;
}

int main() {
	n = read(); c = read(); k = read();
	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= k; ++j)
			st.insert(b[i][j] = read());
	pw[0] = 1; for (int i = 1; i <= n; ++i) pw[i] = 1ll * k * pw[i - 1] % mod;
	int res = 0;
	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= k; ++j) {
			res += b[i][j];
			if (res >= mod) res -= mod;
		}
	res = 1ll * res * pw[n - 1] % mod;
	for (int i : st) {
		res -= 1ll * i * ((calc(i + 1) + mod - calc(i)) % mod) % mod;
		if (res < 0) res += mod;
	} printf("%d\n", res);
	return 0;
}