ARC159F Good Division【性质,DP,线段树】

发布时间 2023-04-19 20:42:03作者: came11ia

定义一个序列是好的当且仅当其可以通过每次删去一对相邻的不同的数把序列删空。

给定一个长度为 \(2n\) 的序列 \(a\),求有多少种划分方式使得每一段都是好的。答案对 \(998244353\) 取模。

\(n \leq 5 \times 10^5\),时限 \(\text{5.0s}\)


先考虑什么样的数列是合法的,显然必要条件是长度为偶数,且不存在绝对众数。容易证明这也是充分的:当出现最多的出现次数恰为总个数的一半时,可以直接构造,否则可以任意操作直到删空或者满足上述条件。

这样容易得到一个暴力的 \(\mathcal{O}(n^2)\) DP。然后是一个神秘的观察:考虑有多少 \(i\),满足存在以 \(i\) 为结尾的子串使得其以 \(c\) 为绝对众数。对于当前考虑的 \(c\),我们将 \(c\) 所在的位置视为 \(1\),其它位置视为 \(-1\),然后对这个序列做前缀和,记为 \(s\)。那么一个位置 \(i\) 满足上述条件当且仅当 \(\min \limits_{j < i} s_j < s_i\)。结合图像感性理解一下,对于每个 \(+1\),其至多会多使后面的一个 \(i\) 满足这个条件,因此 \(i\) 的总数是 \(\mathcal{O}(cnt_c)\) 的,并且我们可以快速地找到这些位置。

然后就可以优化刚刚的那个 DP 了,我们每次先令 \(f_i = \sum \limits_{j<i} f_j\),然后枚举所有可能的非法颜色容斥。因为每个不合法段至多对应一种绝对众数,所以直接容斥就是对的。枚举 \(c\) 之后,我们要减去的是 \(\sum \limits_{j < i,s_j < s_i} f_j\),还是结合图像分析,考虑找到 \(i\) 之前第一个 \(s_i = s_j\) 的位置 \(j\),那么上面式子的值可以直接从 \(j\) 继承 \(<j\) 的部分,这可以用 map 来维护。而 \(j\) 后面的部分要么全满足要么全不满足,利用前缀和容易 \(\mathcal{O}(1)\) 求出。

剩下的问题是如何快速找到这个 \(j\)。用线段树维护 \(s\) 数组,然后线段树二分即可。具体来说,初始时假设每个位置都是 \(-1\),每次把对应位置修改为 \(1\),计算完答案再修改回去就行了。注意到任意一个子区间的值域一定是连续的,因此线段树维护区间加,区间 \(\min\),区间 \(\max\) 就可以快速判断区间内有没有要找的值了。总时间复杂度 \(\mathcal{O}(n \log n)\)

code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef vector <int> vi;
constexpr int N = 1e6 + 5, M = N << 2, mod = 998244353;
int n, a[N], b[N];
map <int, int> t[N], g[N];
vi col[N], q[N];
int f[N], pref[N];
#define m ((l + r) >> 1)
int mx[M], mi[M], tag[M];
//bool debug = 0;
void up(int x) {
	mx[x] = max(mx[x << 1], mx[x << 1 | 1]);
	mi[x] = min(mi[x << 1], mi[x << 1 | 1]);
}
void ptag(int x, int v) { tag[x] += v, mx[x] += v, mi[x] += v; }
void down(int x) {
	if (tag[x]) ptag(x << 1, tag[x]), ptag(x << 1 | 1, tag[x]), tag[x] = 0;
}
void build(int x, int l, int r) {
	if (l == r) return mi[x] = mx[x] = -l, void();
	build(x << 1, l, m);
	build(x << 1 | 1, m + 1, r);
	up(x);
}
void add(int x, int l, int r, int ql, int qr, int v) {
	if (ql <= l && qr >= r) return ptag(x, v);
	down(x);
	if (ql <= m) add(x << 1, l, m, ql, qr, v);
	if (qr > m) add(x << 1 | 1, m + 1, r, ql, qr, v);
	up(x); 
}
int get(int x, int l, int r, int p) {
	if (l == r) return mx[x];
	down(x);
	if (p <= m) return get(x << 1, l, m, p);
	else return get(x << 1 | 1, m + 1, r, p);
}
int qry(int x, int l, int r, int ql, int qr, int v) {
//	if (debug) cout << "qry " << l << " " << r << ", " << "mx = " << mx[x] << ", " << "mi = " << mi[x] << "\n"; 
	if (ql > qr) return -1;
	if (l == r) {
		if (mi[x] <= v && mx[x] >= v) return l;
		return -1;
	}
	int ret = -1; down(x);
	if (qr <= m) return qry(x << 1, l, m, ql, qr, v); 
	if (mi[x << 1 | 1] <= v && mx[x << 1 | 1] >= v) ret = qry(x << 1 | 1, m + 1, r, ql, qr, v);
	if (ret == -1 && mi[x << 1] <= v && mx[x << 1] >= v) ret = qry(x << 1, l, m, ql, qr, v);
	return ret;
}
#undef m
signed main() {  
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    n *= 2;
    for (int i = 1; i <= n; i++) {
		cin >> b[i];
		q[b[i]].emplace_back(i);
	}
	build(1, 1, n);
	for (int _ = 1; _ <= n; _++) if (q[_].size()) {
//		cerr << "now col : " << _ << "\n";
		int pre = -1, rest = 0;
		vi t;
		q[_].emplace_back(n + 1);
		for (auto i : q[_]) {
			if (i <= n) t.emplace_back(i), col[i].emplace_back(_);
			if (pre == -1) {
				pre = i, rest++;
				continue;
			}
			int j = pre + 1;
			while (j < i && rest > 1) t.emplace_back(j), col[j].emplace_back(_), rest--, j++;
			rest -= ((i - 1) - j + 1);
			if (rest < 0) rest = 0; 
			pre = i, rest++;
		}
		q[_].pop_back();
		for (auto i : q[_]) add(1, 1, n, i, n, 2);
		for (auto j : t) {
			int val = get(1, 1, n, j);
//			if (j == 5) debug = 1;
			int i = qry(1, 1, n, 1, j - 1, val);
//			debug = 0;
			:: t[j][_] = i;
//			cerr << j << " " << val << " " << i << "\n";
			if (i == -1 && val == 0) :: t[j][_] = 0; 
		}
		for (auto i : q[_]) add(1, 1, n, i, n, -2);
	}
//	for (int i = 1; i <= n; i++) {
//		cerr << "now : " << i << "\n";
//		for (auto _ : col[i]) {
//			cerr << "col " << _ << ", ";
//			cerr << "pre " << t[i][_] << "\n";
//		}
//	}
	f[0] = pref[0] = 1;
	for (int i = 1; i <= n; i++) {
		f[i] = pref[i - 1];
		for (auto _ : col[i]) {
			int j = t[i][_];
			g[i][_] = g[j < 0 ? 0 : j][_];
			if (j < 1) {
				if (b[i] == _) {
					g[i][_] = (g[i][_] + (pref[i - 1] + mod - pref[j < 0 ? 0 : j]) % mod) % mod;
					if (j < 0) g[i][_] = (g[i][_] + 1) % mod;
				}
			} else if (j < i - 1) {
				if (b[i] == _) {
					g[i][_] = (g[i][_] + (pref[i - 1] + mod - pref[j]) % mod) % mod;
				}
			}
			f[i] = (f[i] + mod - g[i][_]) % mod;
		}
		if (i % 2) f[i] = 0;
		pref[i] = (pref[i - 1] + f[i]) % mod;
	}
	cout << f[n] << "\n";
    return 0;  
}
/*
5
2 1 3 1 1 2 3 2 2 3
*/