AtCoder Regular Contest 128 D Neq Neq

发布时间 2023-05-03 18:12:37作者: zltzlt

洛谷传送门

AtCoder 传送门

考虑把所有 \(a_i = a_{i+1}\) 的位置断开,分别计算然后把方案数乘起来。接下来的讨论假设 \(a_i \ne a_{i+1}\)

考虑一个 dp,设 \(f_i\)\([1,i]\) 最后剩下的集合的方案数。转移需要从之前所有可以被删的区间转移过来。

现在问题变成了如何判定 \((a_1, a_2, ..., a_m)\) 可以被删至只剩 \((a_1, a_m)\)

如果 \(m \le 3\) 一定可行,下面假设 \(m \ge 4\)

发现只要数组中不同数的个数 \(\ge 3\) 即可。

考虑证明。如果不同数的个数 \(\le 2\) 一定不可行,否则接下来找到 \(a_{i-1}, a_i, a_{i+1}\) 使得它们两两互不相同,如果删去 \(a_i\) 后不同数个数变成 \(2\),那么数组形式一定是 \((x,y,...,x,y,a_i,x,y,...)\)。此时删去 \(a_{i-1}\)\(a_{i+1}\) 即可。否则删去 \(a_i\)。证毕。

知道了这个限制,可以双指针 + 前缀和简单维护。

code
// Problem: D - Neq Neq
// Contest: AtCoder - Daiwa Securities Co. Ltd. Programming Contest 2021(AtCoder Regular Contest 128)
// URL: https://atcoder.jp/contests/arc128/tasks/arc128_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const ll mod = 998244353;

ll n, a[maxn], f[maxn], g[maxn], b[maxn], c[maxn], cnt, m;

inline void add(ll x) {
	cnt += (!c[x]);
	++c[x];
}

inline void del(ll x) {
	--c[x];
	cnt -= (!c[x]);
}

inline ll calc() {
	for (int i = 1, j = 0; i <= m; ++i) {
		f[i] = (i == 1);
		if (i > 1) {
			f[i] = (f[i] + f[i - 1]) % mod;
		}
		if (i > 2) {
			f[i] = (f[i] + f[i - 2]) % mod;
		}
		add(b[i]);
		while (cnt >= 3) {
			del(b[++j]);
		}
		f[i] = (f[i] + g[min(j, max(i - 3, 0))]) % mod;
		g[i] = (g[i - 1] + f[i]) % mod;
	}
	return f[m];
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
	}
	ll ans = 1;
	for (int i = 1, j = 1; i <= n; i = (++j)) {
		while (j < n && a[j] != a[j + 1]) {
			++j;
		}
		m = 0;
		for (int k = i; k <= j; ++k) {
			b[++m] = a[k];
		}
		ans = ans * calc() % mod;
		cnt = 0;
		for (int k = i; k <= j; ++k) {
			c[a[k]] = 0;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}