CF1815E Bosco and Particle

发布时间 2023-06-24 17:44:02作者: Gemini7X

有个粒子初始在 \(0\) 位置,\(1\cdots n\) 位置分别为有一个对撞器,如果在 \(0\) 位置则向右,如果在 \(n + 1\) 位置则向左。每个对撞器有一个 \(01\) 串,初始所有对撞器的指针都在开头,当粒子走到 \(i\) 位置时,对撞器所指的值为 \(0\) 则不改变方向,否则反向,指针指向下一个位置,如果在串的末尾则指向开头。求最小的周期长度 \(c\) 满足任意 \(t\) 时间和 \(t + c\) 时间粒子在同一位置。

\(1\le n \le 10^6\)\(\sum |s_i|\le 10^6\)


注意到对于一个位置,无论在右边转了多久,回到这里后和直接从右边回来是一样。左边同理。所以我们只用考虑 \(0,1,2\) 这三个位置。

还有一个显然的事实是每个粒子只用保留它的最小整周期,然后你就可以跑一个暴力,求出一个周期从左边进入 \(a_i\) 次,往右边走出去 \(b_i\) 次。显然这个过程是对称的。

\(f_i\) 代表最后过程中 \(i\to {i+1}\) 的次数,由于左边和右边会右七七八八的破事,所以 \(i\) 这个位置可能进进出出多个周期,所以应该 \(\frac{f_i}{f_{i+1}}=\frac{a_i}{b_i}\)。使用主元法,用 \(f_0\) 表示出所有 \(f_i=f_0\prod_{j=1}^{i}\frac{b_j}{a_j}\)。我们需要构造这个 \(f_0\) 使得每一个 \(f_i\) 是整数,并且 \(b_i\mid f_i\)。根据每个质因子考虑,随时把不够的部分补进 \(f_0\) 里。不妨设 \(b_n\not=0\),统计这一部分即可。

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5, mod = 998244353;
int qmod(int x) { return x >= mod ? x - mod : x; }
int ksm(int a, int b = mod - 2) 
{
	int ret = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod) if (b & 1) ret = 1ll * ret * a % mod;
	return ret;
}
template <typename T>
void read(T &x)
{
	T sgn = 1;
	char ch = getchar();
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
	x *= sgn;
}
int n, prime[maxn], cnt, mn[maxn];
bool vis[maxn];
char s[maxn];
int nxt[maxn], a[maxn], b[maxn];
int mx[maxn], num[maxn];
void sieve(int mx)
{
	for (int i = 2; i <= mx; i++)
	{
		if (!vis[i]) prime[++cnt] = i, mn[i] = i;
		for (int j = 1; j <= cnt && prime[j] * i <= mx; j++)
		{
			vis[i * prime[j]] = 1;
			mn[i * prime[j]] = prime[j];
			if (i % prime[j] == 0) break;
		}
	}
}
int main()
{
	read(n); sieve(1000000);
	for (int _ = 1; _ <= n; _++)
	{
		scanf("%s", s + 1);
		int len = strlen(s + 1);
		for (int i = 2, j = 0; i <= len; i++)
		{
			while (j && s[i] != s[j + 1]) j = nxt[j];
			if (s[i] == s[j + 1]) j++;
			nxt[i] = j;
		}
		int per = len % (len - nxt[len]) == 0 ? len - nxt[len] : len;
		int cur = 0, dir = 1, pos = 1;
		do
		{
			cur += dir;
			if (cur == 1)
			{
				dir == 1 ? a[_]++ : b[_]++;
				if (s[pos] == '1') dir = -dir;
				pos = pos % per + 1;
			}
			else dir = cur == 0 ? 1 : -1;
		} while (cur != 0 || pos != 1);
// 		if (n == 50)
// 			printf("! %d %d\n", a[_], b[_]);
	}
	for (int i = 1; i <= n; i++)
	{
		int now = a[i];
		while (now > 1)
		{
			int p = mn[now];
			while (now % p == 0) now /= p, num[p]--;
			mx[p] += max(0, -num[p]);
			num[p] = max(num[p], 0);
		}
		if (!b[i]) break;
		now = b[i];
		while (now > 1)
		{
			int p = mn[now];
			while (now % p == 0) now /= p, num[p]++;
			mx[p] = max(mx[p], -num[p]);
		}
	}
	int f0 = 1;
	for (int i = 1; i <= cnt; i++) f0 = 1ll * f0 * ksm(prime[i], mx[prime[i]]) % mod;
	int ans = f0;
	for (int i = 1; i <= n; i++)
	{
		f0 = 1ll * f0 * b[i] % mod * ksm(a[i]) % mod;
		ans = qmod(ans + f0);
	} printf("%d\n", qmod(ans + ans));
	return 0;
}