洛谷 P5359 [SDOI2019] 染色

发布时间 2024-01-12 22:40:08作者: zltzlt

洛谷传送门

LOJ 传送门

dp 好题。

首先有一个显然的状态,设 \(f_{i, x, y}\) 为第 \(i\) 列上下两格的颜色分别为 \(x, y\) 的方案数。但是这样做时间复杂度至少为 \(O(nm^2)\),无法接受。

注意到全 \(0\) 列的转移是重复的。我们可以试着只在两个相邻非全 \(0\) 列转移。这样我们需要预处理全 \(0\) 列的转移系数。

我们考虑预处理转移系数时把一段全 \(0\) 列的旁边两列的数的不等关系记入状态中。一共只会有 \(7\) 种本质不同的状态,而其中 \(2\) 对状态是等价的,所以实际上只有 \(5\) 种状态:

0:    1:    2:    3:    4:
a..a  a..b  a..a  a..b  a..b
b..b  b..a  b..c  b..c  c..d
            ----  ----
            a..c  a..c
            b..b  b..a

其中相同字母代表的数相同,不同字母代表的数不同。

所以我们设 \(g_{i, j}\) 表示一段全 \(0\) 列的长度为 \(i\),相邻两列的状态为 \(j\),全 \(0\) 列填法的方案数。

注意到这个东西是可以递推的。以 \(g_{i, 4}\) 的转移为例。

若上一列的状态为 \(0\)

a..ac
b..bd

其中 ab 之前已经确定,所以转移系数为 \(1\)

若上一列的状态为 \(1\)

a..bc
b..ad

转移系数为 \(1\)

若上一列的状态为 \(2\)

a..ac
b..cd

a..cc
b..bd

其中上一列的 c 还没确定(注意两列的 c 不同),两种方案 c 都有 \(m - 3\) 种方案,所以转移系数为 \(2m - 6\)

若上一列的状态为 \(3\)

a..bc
b..cd

a..cc
b..ad

转移系数为 \(2m - 6\)

若上一列的状态为 \(4\)

a..cc
b..dd

考虑分类讨论。若上一列的 c 等于这一列的 d,那么上一列的 d\(m - 3\) 种取值;否则上一列的 cd 各有 \(m - 4\) 种取值,因此转移系数为 \((m - 4)^2\)

因此:

\[g_{i, 4} = g_{i - 1, 0} + g_{i - 1, 1} + (2m - 6) g_{i - 1, 2} + (2m - 6) g_{i - 1, 3} + (m - 3 + (m - 4)^2) g_{i - 1, 4} \]

再来考虑相邻非全 \(0\) 列的转移。设 \(f_{i, j}\) 表示考虑了前 \(i\) 列,第 \(i\) 列空着的那列填了 \(j\) 的方案数(若第 \(i\) 列上下都填了数则 \(j\) 只在 \(a_{i, 1}\) 处有值)。设上一个非全 \(0\) 的列为 \(k\)\(t = i - k - 1\) 表示中间全 \(0\) 列的数量。

这里要讨论的情况非常多。以上一个和当前非全 \(0\) 列都只有第一行填了数为例。

若两列第一行填的数相同:

a..a
?..b

此时 \(f_{i, j}\)\(j\) 代表 b 处的值。若 ? 处的值为 \(j\),那么此时的状态为 \(0\);若 ? 处的值不为 \(j\),那么此时的状态为 \(2\)。我们有:

\[f_{i, j} = f_{k, j} g_{t, 0} + \sum\limits_{p \ne j} f_{k, p} g_{t, 2} \]

若设 \(s = \sum\limits_{j = 1}^m f_{k, j}\),我们有:

\[f_{i, j} = f_{k, j} g_{t, 0} + (s - f_{k, j}) g_{t, 2} \]

若两列第一行填的数不同,讨论第 \(i\) 列第二行填的数与第 \(k\) 列第一行填的数是否相同:

a..b
?..a

a..b
?..c

a 处的值为 \(x\)b 处的值为 \(y\)。我们有:

\[f_{i, x} = f_{k, y} g_{t, 1} + (s - f_{k, y}) g_{t, 3} \]

\[\forall j \ne x, f_{i, j} = f_{k, y} g_{t, 3} + f_{k, j} g_{t, 2} + (s - f_{k, y} - f_{k, j}) g_{t, 4} \]

剩下情况的讨论见代码。

还要讨论前缀全 \(0\) 列和后缀全 \(0\) 列。前缀全 \(0\) 列在初始化 \(f\) 数组时讨论。后缀全 \(0\) 列给答案乘上一个系数即可。

需要特判全部列都是全 \(0\) 列的情况。

考虑优化。发现我们对于 \(f\) 数组只会进行全局乘,全局加,单点查,全局查,单点修改的操作。例如 \(f_{i, j} = f_{k, j} g_{t, 0} + (s - f_{k, j}) g_{t, 2}\),可以看成是先全局乘 \(g_{t, 0} - g_{t, 2}\) 再全局加 \(s \times g_{t, 2}\)。对于不符合条件的状态赋值 \(0\) 可以看成是区间乘 \(0\) 操作。

这样我们可以用线段树维护 \(f\) 数组。转移进行对应的全局乘和全局加操作即可。

时间复杂度 \(O(n \log m)\)。但是考虑到全部操作都是全局和单点的操作,所以应该可以仿照这题的方法优化到 \(O(n)\)

下面的代码用线段树实现,保留了一份 \(O(nm)\) 暴力供参考。

code
// Problem: P5359 [SDOI2019] 染色
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5359
// Memory Limit: 500 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 100100;
const ll mod = 1000000009;

ll n, m, a[maxn][2], g[maxn][5];

namespace Sub1 {
	ll f[2][maxn];
	
	void solve() {
		int k = 0, o = 0;
		for (int i = 1; i <= n; ++i) {
			if (!a[i][0] && !a[i][1]) {
				continue;
			}
			o ^= 1;
			if (!k) {
				if (i == 1) {
					if ((!a[i][0]) + (!a[i][1]) == 1) {
						for (int j = 1; j <= m; ++j) {
							f[o][j] = (j != a[i][0] && j != a[i][1]);
						}
					} else {
						f[o][a[i][1]] = 1;
					}
				} else {
					if ((!a[i][0]) + (!a[i][1]) == 1) {
						for (int j = 1; j <= m; ++j) {
							if (j == a[i][0] || j == a[i][1]) {
								continue;
							}
							f[o][j] = (g[i - 2][0] + g[i - 2][1] + g[i - 2][2] * (m - 2) % mod * 2 % mod + g[i - 2][3] * (m - 2) % mod * 2 % mod + g[i - 2][4] * (m - 2) % mod * (m - 3) % mod) % mod;
						}
					} else {
						f[o][a[i][1]] = (g[i - 2][0] + g[i - 2][1] + g[i - 2][2] * (m - 2) % mod * 2 % mod + g[i - 2][3] * (m - 2) % mod * 2 % mod + g[i - 2][4] * (m - 2) % mod * (m - 3) % mod) % mod;
					}
				}
			} else {
				int t = i - k - 1;
				ll s = 0;
				for (int j = 1; j <= m; ++j) {
					f[o][j] = 0;
					s = (s + f[o ^ 1][j]) % mod;
				}
				for (int j = 1; j <= m; ++j) {
					if (a[i][0]) {
						if (j == a[i][0] || (a[i][1] && j != a[i][1])) {
							continue;
						}
						if (a[k][0]) {
							if (a[i][0] == a[k][0]) {
								f[o][j] = (f[o ^ 1][j] * g[t][0] % mod + (s - f[o ^ 1][j] + mod) * g[t][2] % mod) % mod;
							} else if (a[k][0] == j) {
								f[o][j] = (f[o ^ 1][a[i][0]] * g[t][1] % mod + (s - f[o ^ 1][a[i][0]] + mod) * g[t][3] % mod) % mod;
							} else {
								f[o][j] = (f[o ^ 1][a[i][0]] * g[t][3] % mod + f[o ^ 1][j] * g[t][2] % mod + (s - f[o ^ 1][a[i][0]] - f[o ^ 1][j] + mod + mod) * g[t][4] % mod) % mod;
							}
						} else {
							if (a[i][0] == a[k][1]) {
								f[o][j] = (f[o ^ 1][j] * g[t][1] % mod + (s - f[o ^ 1][j] + mod) * g[t][3] % mod) % mod;
							} else if (a[k][1] == j) {
								f[o][j] = (f[o ^ 1][a[i][0]] * g[t][0] % mod + (s - f[o ^ 1][a[i][0]] + mod) * g[t][2] % mod) % mod;
							} else {
								f[o][j] = (f[o ^ 1][a[i][0]] * g[t][2] % mod + f[o ^ 1][j] * g[t][3] % mod + (s - f[o ^ 1][a[i][0]] - f[o ^ 1][j] + mod + mod) * g[t][4] % mod) % mod;
							}
						}
					} else {
						if (j == a[i][1]) {
							continue;
						}
						if (a[k][0]) {
							if (a[k][0] == j) {
								f[o][j] = (f[o ^ 1][a[i][1]] * g[t][0] % mod + (s - f[o ^ 1][a[i][1]] + mod) * g[t][2] % mod) % mod;
							} else if (a[k][0] == a[i][1]) {
								f[o][j] = (f[o ^ 1][j] * g[t][1] % mod + (s - f[o ^ 1][j] + mod) * g[t][3] % mod) % mod;
							} else {
								f[o][j] = (f[o ^ 1][j] * g[t][3] % mod + f[o ^ 1][a[i][1]] * g[t][2] % mod + (s - f[o ^ 1][j] - f[o ^ 1][a[i][1]] + mod + mod) * g[t][4] % mod) % mod;
							}
						} else {
							if (a[k][1] == j) {
								f[o][j] = (f[o ^ 1][a[i][1]] * g[t][1] % mod + (s - f[o ^ 1][a[i][1]] + mod) * g[t][3] % mod) % mod;
							} else if (a[k][1] == a[i][1]) {
								f[o][j] = (f[o ^ 1][j] * g[t][0] % mod + (s - f[o ^ 1][j] + mod) * g[t][2] % mod) % mod;
							} else {
								f[o][j] = (f[o ^ 1][j] * g[t][2] % mod + f[o ^ 1][a[i][1]] * g[t][3] % mod + (s - f[o ^ 1][j] - f[o ^ 1][a[i][1]] + mod + mod) * g[t][4] % mod) % mod;
							}
						}
					}
				}
			}
			k = i;
		}
		if (!k) {
			printf("%lld\n", (g[n - 2][0] * m % mod * (m - 1) % mod + g[n - 2][1] * m % mod * (m - 1) % mod + g[n - 2][2] * m % mod * (m - 1) % mod * (m - 2) % mod * 2 % mod + g[n - 2][3] * m % mod * (m - 1) % mod * (m - 2) % mod * 2 % mod + g[n - 2][4] * m % mod * (m - 1) % mod * (m - 2) % mod * (m - 3) % mod) % mod);
			return;
		}
		if (k == n) {
			ll ans = 0;
			for (int i = 1; i <= m; ++i) {
				ans = (ans + f[o][i]) % mod;
			}
			printf("%lld\n", ans);
			return;
		}
		int t = n - k - 1;
		ll ans = 0;
		for (int i = 1; i <= m; ++i) {
			ans = (ans + f[o][i] * (g[t][0] + g[t][1] + g[t][2] * (m - 2) % mod * 2 % mod + g[t][3] * (m - 2) % mod * 2 % mod + g[t][4] * (m - 2) % mod * (m - 3) % mod) % mod) % mod;
		}
		printf("%lld\n", ans);
	}
}

ll b[maxn];

namespace SGT {
	ll a[maxn << 2], add[maxn << 2], tag[maxn << 2];
	
	inline void pushup(int x) {
		a[x] = (a[x << 1] + a[x << 1 | 1]) % mod;
	}
	
	inline void pushdown(int x, int l, int r) {
		if (tag[x] != 1) {
			a[x << 1] = a[x << 1] * tag[x] % mod;
			a[x << 1 | 1] = a[x << 1 | 1] * tag[x] % mod;
			add[x << 1] = add[x << 1] * tag[x] % mod;
			add[x << 1 | 1] = add[x << 1 | 1] * tag[x] % mod;
			tag[x << 1] = tag[x << 1] * tag[x] % mod;
			tag[x << 1 | 1] = tag[x << 1 | 1] * tag[x] % mod;
			tag[x] = 1;
		}
		if (add[x]) {
			int mid = (l + r) >> 1;
			a[x << 1] = (a[x << 1] + add[x] * (mid - l + 1)) % mod;
			a[x << 1 | 1] = (a[x << 1 | 1] + add[x] * (r - mid)) % mod;
			add[x << 1] = (add[x << 1] + add[x]) % mod;
			add[x << 1 | 1] = (add[x << 1 | 1] + add[x]) % mod;
			add[x] = 0;
		}
	}
	
	inline void build(int rt, int l, int r) {
		tag[rt] = 1;
		if (l == r) {
			a[rt] = b[l];
			return;
		}
		int mid = (l + r) >> 1;
		build(rt << 1, l, mid);
		build(rt << 1 | 1, mid + 1, r);
		pushup(rt);
	}
	
	void update1(int rt, int l, int r, int ql, int qr, ll x) {
		if (ql > qr) {
			return;
		}
		if (ql <= l && r <= qr) {
			a[rt] = a[rt] * x % mod;
			add[rt] = add[rt] * x % mod;
			tag[rt] = tag[rt] * x % mod;
			return;
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		if (ql <= mid) {
			update1(rt << 1, l, mid, ql, qr, x);
		}
		if (qr > mid) {
			update1(rt << 1 | 1, mid + 1, r, ql, qr, x);
		}
		pushup(rt);
	}
	
	void update2(int rt, int l, int r, int ql, int qr, ll x) {
		if (ql > qr) {
			return;
		}
		if (ql <= l && r <= qr) {
			a[rt] = (a[rt] + x * (r - l + 1)) % mod;
			add[rt] = (add[rt] + x) % mod;
			return;
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		if (ql <= mid) {
			update2(rt << 1, l, mid, ql, qr, x);
		}
		if (qr > mid) {
			update2(rt << 1 | 1, mid + 1, r, ql, qr, x);
		}
		pushup(rt);
	}
	
	void modify(int rt, int l, int r, int x, ll y) {
		if (l == r) {
			a[rt] = y;
			return;
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		(x <= mid) ? modify(rt << 1, l, mid, x, y) : modify(rt << 1 | 1, mid + 1, r, x, y);
		pushup(rt);
	}
	
	ll query(int rt, int l, int r, int ql, int qr) {
		if (ql > qr) {
			return 0;
		}
		if (ql <= l && r <= qr) {
			return a[rt];
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		ll res = 0;
		if (ql <= mid) {
			res = (res + query(rt << 1, l, mid, ql, qr));
		}
		if (qr > mid) {
			res = (res + query(rt << 1 | 1, mid + 1, r, ql, qr));
		}
		return res;
	}
	
	void dfs(int rt, int l, int r) {
		if (l == r) {
			b[l] = a[rt];
			return;
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		dfs(rt << 1, l, mid);
		dfs(rt << 1 | 1, mid + 1, r);
	}
}

void solve() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i][0]);
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i][1]);
	}
	for (int i = 1; i <= n; ++i) {
		if (a[i][0] && a[i][1] && a[i][0] == a[i][1]) {
			puts("0");
			return;
		}
		if (i < n) {
			if (a[i][0] && a[i + 1][0] && a[i][0] == a[i + 1][0]) {
				puts("0");
				return;
			}
			if (a[i][1] && a[i + 1][1] && a[i][1] == a[i + 1][1]) {
				puts("0");
				return;
			}
		}
	}
	g[0][1] = g[0][3] = g[0][4] = 1;
	for (int i = 1; i <= n; ++i) {
		g[i][0] = (g[i - 1][1] + g[i - 1][3] * (m - 2) % mod * 2 % mod + g[i - 1][4] * (m - 2) % mod * (m - 3) % mod) % mod;
		g[i][1] = (g[i - 1][0] + g[i - 1][2] * (m - 2) % mod * 2 % mod + g[i - 1][4] * (m - 2) % mod * (m - 3) % mod) % mod;
		g[i][2] = (g[i - 1][1] + g[i - 1][2] * (m - 2) % mod + g[i - 1][3] * (m * 2 - 5) % mod + g[i - 1][4] * (m - 3) % mod * (m - 3) % mod) % mod;
		g[i][3] = (g[i - 1][0] + g[i - 1][2] * (m * 2 - 5) % mod + g[i - 1][3] * (m - 2) % mod + g[i - 1][4] * (m - 3) % mod * (m - 3) % mod) % mod;
		g[i][4] = (g[i - 1][0] + g[i - 1][1] + g[i - 1][2] * (m * 2 - 6) % mod + g[i - 1][3] * (m * 2 - 6) % mod + g[i - 1][4] * ((m - 3 + (m - 4) * (m - 4)) % mod) % mod) % mod;
	}
	if (max(n, m) <= 10000) {
		Sub1::solve();
		return;
	}
	int k = 0;
	for (int i = 1; i <= n; ++i) {
		if (!a[i][0] && !a[i][1]) {
			continue;
		}
		if (!k) {
			if (i == 1) {
				if ((!a[i][0]) + (!a[i][1]) == 1) {
					for (int j = 1; j <= m; ++j) {
						b[j] = (j != a[i][0] && j != a[i][1]);
					}
				} else {
					b[a[i][1]] = 1;
				}
			} else {
				if ((!a[i][0]) + (!a[i][1]) == 1) {
					for (int j = 1; j <= m; ++j) {
						if (j == a[i][0] || j == a[i][1]) {
							continue;
						}
						b[j] = (g[i - 2][0] + g[i - 2][1] + g[i - 2][2] * (m - 2) % mod * 2 % mod + g[i - 2][3] * (m - 2) % mod * 2 % mod + g[i - 2][4] * (m - 2) % mod * (m - 3) % mod) % mod;
					}
				} else {
					b[a[i][1]] = (g[i - 2][0] + g[i - 2][1] + g[i - 2][2] * (m - 2) % mod * 2 % mod + g[i - 2][3] * (m - 2) % mod * 2 % mod + g[i - 2][4] * (m - 2) % mod * (m - 3) % mod) % mod;
				}
			}
			SGT::build(1, 1, m);
		} else {
			int t = i - k - 1;
			ll s = SGT::a[1];
			if (a[i][0]) {
				if (a[k][0]) {
					if (a[i][0] == a[k][0]) {
						SGT::update1(1, 1, m, 1, m, (g[t][0] - g[t][2] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, s * g[t][2] % mod);
					} else {
						ll p = SGT::query(1, 1, m, a[i][0], a[i][0]);
						SGT::update1(1, 1, m, 1, m, (g[t][2] - g[t][4] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, (p * g[t][3] % mod + (s - p + mod) * g[t][4] % mod) % mod);
						SGT::modify(1, 1, m, a[k][0], (p * g[t][1] % mod + (s - p + mod) * g[t][3] % mod) % mod);
					}
				} else {
					if (a[i][0] == a[k][1]) {
						SGT::update1(1, 1, m, 1, m, (g[t][1] - g[t][3] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, s * g[t][3] % mod);
					} else {
						ll p = SGT::query(1, 1, m, a[i][0], a[i][0]);
						SGT::update1(1, 1, m, 1, m, (g[t][3] - g[t][4] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, (p * g[t][2] % mod + (s - p + mod) * g[t][4] % mod) % mod);
						SGT::modify(1, 1, m, a[k][1], (p * g[t][0] % mod + (s - p + mod) * g[t][2] % mod) % mod);
					}
				}
				SGT::modify(1, 1, m, a[i][0], 0);
			} else {
				if (a[k][0]) {
					if (a[k][0] == a[i][1]) {
						SGT::update1(1, 1, m, 1, m, (g[t][1] - g[t][3] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, s * g[t][3] % mod);
					} else {
						ll p = SGT::query(1, 1, m, a[i][1], a[i][1]);
						SGT::update1(1, 1, m, 1, m, (g[t][3] - g[t][4] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, (p * g[t][2] % mod + (s - p + mod) * g[t][4] % mod) % mod);
						SGT::modify(1, 1, m, a[k][0], (p * g[t][0] % mod + (s - p + mod) * g[t][2] % mod) % mod);
					}
				} else {
					if (a[k][1] == a[i][1]) {
						SGT::update1(1, 1, m, 1, m, (g[t][0] - g[t][2] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, s * g[t][2] % mod);
					} else {
						ll p = SGT::query(1, 1, m, a[i][1], a[i][1]);
						SGT::update1(1, 1, m, 1, m, (g[t][2] - g[t][4] + mod) % mod);
						SGT::update2(1, 1, m, 1, m, (p * g[t][3] % mod + (s - p + mod) * g[t][4] % mod) % mod);
						SGT::modify(1, 1, m, a[k][1], (p * g[t][1] % mod + (s - p + mod) * g[t][3] % mod) % mod);
					}
				}
				SGT::modify(1, 1, m, a[i][1], 0);
			}
			if (a[i][0] && a[i][1]) {
				SGT::update1(1, 1, m, 1, a[i][1] - 1, 0);
				SGT::update1(1, 1, m, a[i][1] + 1, m, 0);
			}
		}
		k = i;
	}
	if (!k) {
		printf("%lld\n", (g[n - 2][0] * m % mod * (m - 1) % mod + g[n - 2][1] * m % mod * (m - 1) % mod + g[n - 2][2] * m % mod * (m - 1) % mod * (m - 2) % mod * 2 % mod + g[n - 2][3] * m % mod * (m - 1) % mod * (m - 2) % mod * 2 % mod + g[n - 2][4] * m % mod * (m - 1) % mod * (m - 2) % mod * (m - 3) % mod) % mod);
		return;
	}
	SGT::dfs(1, 1, m);
	if (k == n) {
		ll ans = 0;
		for (int i = 1; i <= m; ++i) {
			ans = (ans + b[i]) % mod;
		}
		printf("%lld\n", ans);
		return;
	}
	int t = n - k - 1;
	ll ans = 0;
	for (int i = 1; i <= m; ++i) {
		ans = (ans + b[i] * (g[t][0] + g[t][1] + g[t][2] * (m - 2) % mod * 2 % mod + g[t][3] * (m - 2) % mod * 2 % mod + g[t][4] * (m - 2) % mod * (m - 3) % mod) % mod) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}