AtCoder Regular Contest 110 E Shorten ABC

发布时间 2023-04-14 17:34:36作者: zltzlt

洛谷传送门

AtCoder 传送门

考虑把 \(\text{A}\) 看成 \(1\)\(\text{B}\) 看成 \(2\)\(\text{C}\) 看成 \(3\),那么一次操作相当于选择一个 \(a_i \ne a_{i+1}\)\(i\),将 \(a_i\)\(a_{i+1}\) 替换成一个数 \(a_i \oplus a_{i+1}\)

那么题目相当于把 \(a\) 划分成若干段,满足每段的异或和不为 \(0\) 且不是同一种字符或者长度为 \(1\)。将每段的异或和排成一个新数组 \(b\),对所有本质不同的 \(b\) 计数。

我们反过来观察对于一个固定的 \(b\),它有没有可能被形成。

考虑贪心,每次跳到下一个区间异或和为 \(b_i\) 的位置。并且要求最后剩下的一段异或和为 \(0\)

\(nxt_{i,j}\) 为最小的 \(k\) 满足 \(a_i \oplus a_{i+1} \oplus \cdots \oplus a_k = j\),这个是可以线性预处理的。

那么设 \(f_i\) 为将前 \(i\) 个字符划分为若干段,并且每一段都是所有异或和相同的段中右端点最靠左的段。

转移是 \(f_{nxt_{i+1,j}} \gets f_i\)

注意后面还有一段异或和为 \(0\),因此答案不仅仅是 \(f_n\)

注意特判 \(a_i\) 都相等的情况,此时前面的 \(f_i\) 传递不到 \(f_n\),答案为 \(1\)

code
// Problem: E - Shorten ABC
// Contest: AtCoder - AtCoder Regular Contest 110(Sponsored by KAJIMA CORPORATION)
// URL: https://atcoder.jp/contests/arc110/tasks/arc110_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 1000100;
const ll mod = 1000000007;

ll n, f[maxn], nxt[maxn][4], a[maxn];
char s[maxn];

void solve() {
	scanf("%lld%s", &n, s + 1);
	for (int i = 1; i <= n; ++i) {
		if (s[i] == 'A') {
			a[i] = 1;
		} else if (s[i] == 'B') {
			a[i] = 2;
		} else {
			a[i] = 3;
		}
	}
	bool flag = 1;
	for (int i = 2; i <= n; ++i) {
		flag &= (a[i] == a[1]);
	}
	if (flag) {
		puts("1");
		return;
	}
	nxt[n + 1][0] = nxt[n + 1][1] = nxt[n + 1][2] = nxt[n + 1][3] = n + 1;
	for (int i = n; i; --i) {
		for (int j = 1; j <= 3; ++j) {
			if (a[i] == j) {
				nxt[i][j] = i;
			} else {
				nxt[i][j] = nxt[i + 1][j ^ a[i]];
			}
		}
	}
	f[0] = 1;
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= 3; ++j) {
			f[nxt[i][j]] = (f[nxt[i][j]] + f[i - 1]) % mod;
		}
	}
	ll x = 0, ans = 0;
	for (int i = n; i; --i) {
		if (x == 0) {
			ans = (ans + f[i]) % mod;
		}
		x ^= a[i];
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}