AtCoder Regular Contest 139 E Wazir

发布时间 2023-05-26 16:05:36作者: zltzlt

洛谷传送门

AtCoder 传送门

好题。

这种题一般可以考虑,观察最优解的性质,对于性质计数。

发现如果 \(n,m\) 均为偶数,可以放满。就是类似这样:

#.#.#.
.#.#.#
#.#.#.
.#.#.#

因此答案就是 \(2\)

如果 \(n,m\) 有一个为偶数,不妨假设 \(n\) 为偶数。那么最优解形似:

#.#..
.#..#
#..#.
..#.#

可以发现答案是 \(n \times \frac{m - 1}{2}\),并且一行有且仅有两个连续格子是 .

那么对于 \(n,m\) 都是奇数的情况,答案是 \(\max(n \times \frac{m - 1}{2}, m \times \frac{n - 1}{2})\)。不妨假设 \(n \ge m\),那么答案是 \(n \times \frac{m - 1}{2}\),性质同上。

下面假设最优解是 \(n\) 行每行放 \(\frac{m - 1}{2}\) 个。

考虑将一个方阵映射到一个比较好计数的序列。设 \(a_i\) 为第 \(i\) 行,连续两个 . 的位置。那么实际上要满足:

\[\begin{cases} \forall i \in [1, n], a_i - a_{i + 1} \equiv \pm 1 \pmod{m} \\ a_1 = a_{n + 1} \end{cases} \]

考虑对于 \(b_i = a_i - a_{i + 1} = \pm 1\) 计数,这样要求 \(\sum\limits_{i=1}^n b_i = 0\)

注意因为 \(n,m\) 可能被 swap 过,所以只能保证 \(\min(n, m) \le 10^5\)

  • \(n \le 10^5\),枚举有多少个 \(a_i - a_{i + 1} = 1\),多少个 \(= -1\),用组合数算;
  • \(m \le 10^5\),考虑构造多项式 \(f(x) = (x + x^{-1})^n \pmod{x^m - 1}\),那么就是要求 \(f(x)\) 的常数项。感性理解,每次可以给指数 \(+1\)\(-1\),最后要求指数 \(= 0\)。这个可以多项式快速幂算。

注意最后答案要 \(\times m\),因为 \(a_1\) 任意。

时间复杂度 \(O(\min(n, m) \log n \log m)\)

code
// Problem: E - Wazir
// Contest: AtCoder - AtCoder Regular Contest 139
// URL: https://atcoder.jp/contests/arc139/tasks/arc139_e
// Memory Limit: 1024 MB
// Time Limit: 10000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 1000100;
const int N = 1000000;
const ll mod = 998244353, G = 3;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, fac[maxn], ifac[maxn], r[maxn];

void init() {
	fac[0] = 1;
	for (int i = 1; i <= N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[N] = qpow(fac[N], mod - 2);
	for (int i = N - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
}

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

typedef vector<ll> poly;

inline poly NTT(poly a, int op) {
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		if (i < r[i]) {
			swap(a[i], a[r[i]]);
		}
	}
	for (int k = 1; k < n; k <<= 1) {
		ll wn = qpow(op == 1 ? G : qpow(G, mod - 2), (mod - 1) / (k << 1));
		for (int i = 0; i < n; i += (k << 1)) {
			ll w = 1;
			for (int j = 0; j < k; ++j, w = w * wn % mod) {
				ll x = a[i + j], y = w * a[i + j + k] % mod;
				a[i + j] = (x + y) % mod;
				a[i + j + k] = (x - y + mod) % mod;
			}
		}
	}
	return a;
}

inline poly operator * (poly a, poly b) {
	a = NTT(a, 1);
	b = NTT(b, 1);
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * b[i] % mod;
	}
	a = NTT(a, -1);
	ll inv = qpow(n, mod - 2);
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * inv % mod;
	}
	return a;
}

inline poly qpow(poly a, ll p) {
	int n = (int)a.size();
	poly res(n);
	res[0] = 1;
	while (p) {
		if (p & 1) {
			res = res * a;
			for (int i = m; i <= m * 2; ++i) {
				res[i % m] = (res[i % m] + res[i]) % mod;
				res[i] = 0;
			}
		}
		a = a * a;
		for (int i = m; i <= m * 2; ++i) {
			a[i % m] = (a[i % m] + a[i]) % mod;
			a[i] = 0;
		}
		p >>= 1;
	}
	return res;
}

void solve() {
	scanf("%lld%lld", &n, &m);
	if (n % 2 == 0 && m % 2 == 0) {
		puts("2");
		return;
	}
	if (m % 2 == 0 || ((n & 1) && (m & 1) && n < m)) {
		swap(n, m);
	}
	ll ans = 0;
	if (n <= 100000) {
		for (int i = 0; i <= n; ++i) {
			if ((i - (n - i)) % m == 0) {
				ans = (ans + C(n, i)) % mod;
			}
		}
	} else {
		int k = 0;
		while ((1 << k) <= m * 2) {
			++k;
		}
		for (int i = 1; i < (1 << k); ++i) {
			r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
		}
		poly A(1 << k);
		A[1] = A[m - 1] = 1;
		poly res = qpow(A, n);
		ans = res[0];
	}
	printf("%lld\n", ans * (m % mod) % mod);
}

int main() {
	init();
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}