ARC134F Flipping Coins

发布时间 2023-07-23 21:46:22作者: Ender_32k

pb 讲课没讲的题,感觉很牛逼啊!但不是牛逼在多项式,因为多项式大家应该都会。

考虑从前往后扫的过程,只要有正面就翻成反面,所以最后只有可能是当 \(p_i<i\)\(i\) 没有被翻面时才对 \(k\) 有贡献。那么考虑一条链 \(1\to 2\to \cdots \to m\),并且 \(\forall 1\le i< m,p_i=i+1\),那么此时翻了 \(1\) 就会翻 \(2\),翻了 \(3\) 就会翻 \(4\)……不难发现此时当且仅当 \(2\nmid m\) 才对 \(k\)\(1\) 的贡献。

于是考虑把排列分成若干置换环,并单独考虑每一个环。不难发现,我们在每个置换环的最小值处断环成链,最后这个环对 \(k\) 的贡献就是这条链上正面朝上的硬币个数,就是这条链长度为奇数的极长上升子段个数

把所有链按照最小值从大到小进行排序,然后排成一排变成排列 \(q\)。不难发现 \(p,q\)一一对应的,同时 \(p\) 对应的 \(k\) 就是 \(q\) 长度为奇数的极长上升子段个数。

接下来是一个很神仙的转化:把每个极长上升子段拆成 \(\mathtt{ABABAB\cdots C}\) 的形式,即分成若干长度为 \(2\)\(1\) 的小段,长度为 \(2\) 的标为 \(\mathtt{AB}\),长度为 \(1\) 的标为 \(\mathtt{C}\)\(k\) 就是长为 \(1\) 的小段即 \(\mathtt{C}\) 的个数。

发现我们所需要满足的条件如下:

  • 由于 \(\mathtt{ABABAB\cdots}\) 形如上升状子段,所以每个 \(\mathtt{A}\) 代表的数都比后面相邻的 \(\mathtt{B}\) 小。
  • \(\mathtt{B}\) 是没有限制的,因为它可以作为长度为偶数的极长上升子段末尾,也可以在后面接上一个单独的一个 \(\mathtt{C}\)
  • 每个 \(\mathtt{C}\) 都要大于后面的数,不管是 \(\mathtt{A}\) 还是 \(\mathtt{C}\),否则可以匹配成 \(\mathtt{AB}\)

所以我们以每个 \(\mathtt{B}\) 为末尾断开,把序列分成形如 \((\mathtt{CC\cdots CAB})(\mathtt{CC\cdots CAB})(\mathtt{CC\cdots C})\) 的形式。

然后考虑生成函数,列出整个排列的 EGF \(F(x)\),形如 \(\mathtt{CC\cdots CAB}\) 的 EGF \(G(x)=\sum\limits_{i=2}(i-1)w^{i-2}\dfrac{x_i}{i!}\),形如 \(\mathtt{CC\cdots C}\) 的 EGF \(H(x)=\sum\limits_{i=0}w^i\dfrac{x^i}{i!}\),然后显然有 \(F(x)=F(x)G(x)+H(x)\),即 \(F(x)=\frac{H(x)}{1-G(x)}\),求逆即可。

(为什么是 EGF 傻子都知道,不要再过来找我问了。)

当然直接推朴素的 dp 应该有一个可以分治 NTT 的形式,这个也能做,不过本质上其实就是 EGF 卷积。

然后找个板子过来就 \(O(n\log^2n)\) 或者 \(O(n\log n)\) 随便做了。

// Problem: F - Flipping Coins
// Contest: AtCoder - AtCoder Regular Contest 134
// URL: https://atcoder.jp/contests/arc134/tasks/arc134_f
// Memory Limit: 1024 MB
// Time Limit: 5000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define int long long
using namespace std;

namespace vbzIO {
	char ibuf[(1 << 20) + 1], *iS, *iT;
	#if ONLINE_JUDGE
	#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
	#else
	#define gh() getchar()
	#endif
	#define mt make_tuple
	#define mp make_pair
	#define fi first
	#define se second
	#define pc putchar
	#define pb push_back
	#define ins insert
	#define era erase
	#define bg begin
	#define rbg rbegin
	typedef tuple<int, int, int> tu3;
	typedef pair<int, int> pi;
	inline int rd() {
		char ch = gh();
		int x = 0;
		bool t = 0;
		while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
		while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
		return t ? ~(x - 1) : x;
	}
	inline void wr(int x) {
		if (x < 0) {
			x = ~(x - 1);
			putchar('-');
		}
		if (x > 9)
			wr(x / 10);
		putchar(x % 10 + '0');
	}
}
using namespace vbzIO;

const int N = 3e6 + 300;
const int P = 998244353;
const int g = 114514;

int n, w, tr[N];
int f[N], h[N], fac[N], inv[N], ifac[N];

int qpow(int p, int q) {
    int res = 1;
    while (q) {
        if (q & 1) res = res * p % P;
        p = p * p % P, q >>= 1;
    }
    return res;
}

void init(int lim) {
	fac[0] = ifac[0] = inv[1] = 1;
	for (int i = 1; i <= lim; i++) {
		if (i > 1) inv[i] = inv[P % i] * (P - P / i) % P;
		fac[i] = fac[i - 1] * i % P, ifac[i] = ifac[i - 1] * inv[i] % P;
	}
}

const int gi = qpow(g, P - 2);

void NTT(int *f, int lim, int op){
    for (int i = 0; i < lim; i++) 
        if (i < tr[i]) swap(f[i], f[tr[i]]);
    for (int o = 2, k = 1; k < lim; o <<= 1, k <<= 1) {
    	int tg = qpow(~op ? g : gi, (P - 1) / o);
    	for (int i = 0; i < lim; i += o) {
    		for (int j = 0, w = 1; j < k; j++, (w *= tg) %= P) {
    			int x = f[i + j];
    			int y = w * f[i + j + k] % P;
    			f[i + j] = (x + y) % P;
    			f[i + j + k] = (x - y + P) % P;
			}
		}
	}
	if (op == 1) return;
	int iv = qpow(lim, P - 2);
	for (int i = 0; i < lim; i++) 
		(f[i] *= iv) %= P;
}

void INV(int *f, int *g, int n) {
	if (n == 1) return g[0] = qpow(f[0], P - 2), void();
	INV(f, g, (n + 1) >> 1);
	static int t[N];
	int lim = 1, len = 0;
	while (lim < (n << 1)) lim <<= 1, len++;
    for (int i = 0; i < lim; i++) 
		tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (len - 1));
	for (int i = 0; i < n; i++) t[i] = f[i];
	for (int i = n; i < lim; i++) t[i] = 0;
	NTT(t, lim, 1), NTT(g, lim, 1);
	for (int i = 0; i < lim; i++) 
		g[i] = (2 - g[i] * t[i] % P + P) % P * g[i] % P;
	NTT(g, lim, -1);
	for (int i = n; i < lim; i++) g[i] = 0;
}

int C(int n, int m) {
	if (n < m || m < 0) return 0;
	return fac[n] * ifac[m] % P * ifac[n - m] % P;
}

signed main() {
    n = rd(), w = rd(), init(n);
    f[0] = 1;
    for (int i = 2; i <= n; i++) 
    	f[i] = (P - 1) * ifac[i] % P * (i - 1) % P * qpow(w, i - 2) % P;
    INV(f, h, n + 1);
    for (int i = 0; i <= n; i++) 
    	(h[i] *= fac[i]) %= P;
    int res = 0;
    for (int i = 0; i <= n; i++) 
    	(res += C(n, i) * qpow(w, i) % P * h[n - i] % P) %= P;
    wr(res);
    return 0;
}