多项式小全家桶

发布时间:2023-08-28 17:00:43 作者:_Famiglistimo

比较安全的模板:传入的输出数组 \(g\) 即使带有初值也没有问题,且求解过程中不会修改传入的 \(f\)。

#include <bits/stdc++.h>
using namespace std;
// Capacity (in ints) of every coefficient, scratch and table array in this file.
// ntt() indexes its argument and rev[] up to the transform length, so no
// transform length may exceed this value.
constexpr int N = 1 << 17;
// Modulus for all arithmetic; ntt() uses 3 as the root of unity for it.
constexpr int mod = 998244353;

// Address marker: main() prints (&mem2 - &mem1) / 1048576. to stderr, presumably
// as a rough size in MiB of the globals declared between the two markers.
bool mem1;
// Input buffer for the fast reader: p1 is the read cursor, p2 is one past the
// last byte currently held in buf.
char buf[1 << 23], *p1 = buf, *p2 = buf;
// Buffered replacement for getchar(): when the cursor reaches the end, refill
// buf from stdin (up to 1 << 21 bytes per refill); yield EOF if nothing could be
// read. The byte is returned as unsigned char so it is always non-negative and
// can never be mistaken for EOF or fed to <cctype> functions as a negative value.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : static_cast<unsigned char>(*p1 ++))
/**
 * Reads the next decimal integer from the buffered input.
 *
 * Skips every character that is not a digit; if a '-' is seen while skipping,
 * the result is negated. Then consumes the run of digits (plus the single
 * character that terminates it).
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 *         No overflow checking is performed.
 */
int read() {
    int s = 0, w = 1;
    // Kept as int so EOF stays distinguishable from every real byte.
    int ch = getchar();
    // Stop at EOF: without this check the loop would spin forever on exhausted input.
    while(ch != EOF && !isdigit(ch)) { if(ch == '-') w = -1; ch = getchar(); }
    while(ch != EOF && isdigit(ch)) { s = s * 10 + (ch ^ 48), ch = getchar(); }
    return s * w;
}

// Single-factor case: the value is returned converted to int, without reduction.
template <typename T> int mul(T value) { return value; }
// Product of all arguments modulo mod. The tail is multiplied first, then the
// head is multiplied in using 64-bit arithmetic and reduced.
template <typename Head, typename ...Tail>
int mul(Head head, Tail ...tail) {
	const int rest = mul(tail ...);
	return 1ll * head * rest % mod;
}
// Adds b to a in place, subtracting mod when a >= mod - b.
// Presumably both operands are expected to lie in [0, mod); no caller is
// visible in this file to confirm.
void inc(int &a, int b) {
	if(a >= mod - b) {
		a = a - mod + b;
	} else {
		a = a + b;
	}
}
// Binary exponentiation: returns a^b reduced through mul(); returns 1 when b <= 0.
int ksm(int a, int b) {
	int acc = 1;
	for(; b > 0; b >>= 1) {
		// Multiply in the current power of a whenever the low bit of b is set.
		if(b & 1) acc = mul(acc, a);
		a = mul(a, a);
	}
	return acc;
}

// f: coefficients read by main(); g: result coefficients written by exp() and printed by main().
int f[N];
int g[N];
// Bit-reversal permutation consumed by ntt(); each routine rebuilds it for its
// own transform length right before transforming.
int rev[N];
// iv[i]: modular inverse of i, filled in by main() and used by inter().
int iv[N];
// Modular inverses of 2 and 3, computed as x^(mod - 2).
int inv2 = ksm(2, mod - 2);
int inv3 = ksm(3, mod - 2);

// In-place number-theoretic transform of f[0..len) modulo mod.
//   op == 1  : forward transform, built on powers of 3.
//   otherwise: built on powers of inv3; op == -1 additionally scales the
//              result by the modular inverse of len.
// rev[0..len) must already hold the bit-reversal permutation for len. The
// stage loop doubles the block size from 2 up to len, so len is expected to be
// a power of two.
void ntt(int *f, int op, int len) {
	// Reorder into bit-reversed positions; each pair is swapped exactly once.
	for(int idx = 0; idx < len; ++ idx) {
		const int partner = rev[idx];
		if(idx < partner) swap(f[idx], f[partner]);
	}

	for(int span = 2; span <= len; span <<= 1) {
		const int half = span >> 1;
		const int root = (op == 1) ? 3 : inv3;
		// Multiplicative step between consecutive twiddle factors of this stage.
		const int step = ksm(root, (mod - 1) / span);
		for(int start = 0; start < len; start += span) {
			int twiddle = 1;
			for(int k = 0; k < half; ++ k) {
				const int lo = f[start + k];
				const int hi = mul(twiddle, f[start + k + half]);

				f[start + k] = (lo + hi) % mod;
				f[start + k + half] = (lo - hi + mod) % mod;
				twiddle = mul(twiddle, step);
			}
		}
	}

	if(op == -1) {
		const int scale = ksm(len, mod - 2);
		for(int idx = 0; idx < len; ++ idx)
			f[idx] = mul(f[idx], scale);
	}
}

// Scratch buffers for inv(); they are zeroed again before inv() returns.
int f_in[N], g_in[N];
// Writes into g[0..n) the first n coefficients of the multiplicative inverse of
// the power series f, using the iteration g <- g * (2 - f * g) with doubling
// precision.
//   f : input coefficients; read up to the smallest power of two >= n, never written.
//   g : output; whatever it held on entry is overwritten. Entries are written
//       (and left zero from index n on) up to twice that power of two.
//   n : number of coefficients wanted.
// f[0] is inverted via ksm(f[0], mod - 2), so it is presumably required to be
// non-zero modulo mod. Overwrites rev[].
void inv(int *f, int *g, int n) {
	g[0] = ksm(f[0], mod - 2);

	int len = 1;
	while(len < (n << 1)) {
		const int lim = len << 1;
		// Load the low len coefficients; the upper halves of the scratch
		// buffers have not been touched and are still zero.
		for(int i = 0; i < len; ++ i) {
			f_in[i] = f[i];
			g_in[i] = g[i];
		}

		// Bit-reversal table for a transform of length lim.
		for(int i = 0; i < lim; ++ i) {
			rev[i] = rev[i >> 1] >> 1;
			if(i & 1) rev[i] |= len;
		}
		ntt(f_in, 1, lim);
		ntt(g_in, 1, lim);
		for(int i = 0; i < lim; ++ i) {
			const int fg = mul(f_in[i], g_in[i]);
			const long long corr = (2ll - fg + mod) % mod;
			f_in[i] = mul(g_in[i], corr);
		}
		ntt(f_in, -1, lim);

		// Keep the low len coefficients and clear the rest of the window.
		for(int i = 0; i < lim; ++ i)
			g[i] = (i < len) ? f_in[i] : 0;

		len = lim;
	}

	for(int i = n; i < len; ++ i) g[i] = 0;
	for(int i = 0; i < len; ++ i) {
		f_in[i] = 0;
		g_in[i] = 0;
	}
}

// Scratch buffers for sqrt(); they are zeroed again before sqrt() returns.
int f_sq[N], g_sq[N];
// Writes into g[0..n) the first n coefficients of a square root of the power
// series f, using the iteration g <- (g + f / g) / 2 with doubling precision.
//   f : input coefficients; read up to the smallest power of two >= n, never written.
//   g : output; whatever it held on entry is overwritten.
//   n : number of coefficients wanted.
// g[0] is seeded with 1, so the result is presumably only meaningful when
// f[0] == 1. Calls inv() and overwrites rev[].
void sqrt(int *f, int *g, int n) {
	g[0] = 1;

	int len = 1;
	while(len < (n << 1)) {
		const int lim = len << 1;
		for(int i = 0; i < len; ++ i) f_sq[i] = f[i];
		// g_sq <- 1 / g (first len coefficients); this clobbers rev[].
		inv(g, g_sq, len);

		// Rebuild the bit-reversal table for a transform of length lim.
		for(int i = 0; i < lim; ++ i) {
			rev[i] = rev[i >> 1] >> 1;
			if(i & 1) rev[i] |= len;
		}
		ntt(f_sq, 1, lim);
		ntt(g_sq, 1, lim);
		for(int i = 0; i < lim; ++ i)
			f_sq[i] = mul(f_sq[i], g_sq[i]);
		ntt(f_sq, -1, lim);

		// Average g with f / g on the low len coefficients; clear the rest.
		for(int i = 0; i < len; ++ i) {
			const int sum = (f_sq[i] + g[i]) % mod;
			g[i] = mul(sum, inv2);
		}
		for(int i = len; i < lim; ++ i) g[i] = 0;

		len = lim;
	}

	for(int i = n; i < len; ++ i) g[i] = 0;
	for(int i = 0; i < len; ++ i) {
		f_sq[i] = 0;
		g_sq[i] = 0;
	}
}

// Formal derivative: writes into g[0..n) the derivative of the polynomial whose
// coefficients are f[0..n), i.e. g[i] = (i + 1) * f[i + 1] for i < n - 1 and
// g[n - 1] = 0.
//   f : input coefficients; only f[1..n) is read.
//   g : output; may be the same array as f (indices are processed in ascending
//       order, so every f[i + 1] is read before it could be overwritten).
//   n : number of coefficients; nothing is done when n <= 0.
void deriv(int *f, int *g, int n) {
	// Guard the g[n - 1] store below against an out-of-bounds write.
	if(n <= 0) return;
	// Stop at n - 1 so that f[n], which lies outside the input, is never read.
	for(int i = 0; i + 1 < n; ++ i)
		g[i] = mul(i + 1, f[i + 1]);
	g[n - 1] = 0;
}
void inter(int *f, int *g, int n) {
	for(int i = n - 2; i >= 0; -- i)
		g[i + 1] = mul(f[i], iv[i + 1]);
	g[0] = 0;
}

// Scratch buffers for ln(); they are zeroed again before ln() returns.
int f_ln[N], g_ln[N];
// Writes into g[0..n) the first n coefficients of the logarithm of the power
// series f, computed as the integral of f' / f.
//   f : input coefficients, never written.
//   g : output; whatever it held on entry is overwritten.
//   n : number of coefficients wanted; nothing is done when n <= 0.
// The integral's constant term is 0, so presumably f[0] == 1 is expected.
// Calls deriv(), inv() and inter(); relies on iv[] being filled; overwrites rev[].
void ln(int *f, int *g, int n) {
	// deriv() stores to index n - 1, so a non-positive n must not reach it.
	if(n <= 0) return;

	deriv(f, f_ln, n);
	inv(f, g_ln, n);

	// Smallest power of two that holds the product of two length-n polynomials.
	int lim = 1;
	while(lim <= 2 * n - 2) lim <<= 1;

	// Bit-reversal table. The high bit is lim / 2, which is 0 when lim == 1;
	// multiplying by it avoids the negative shift count that a
	// "<< (bit - 1)" formulation hits for n == 1.
	const int half = lim >> 1;
	for(int i = 0; i < lim; ++ i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * half);

	ntt(f_ln, 1, lim), ntt(g_ln, 1, lim);
	for(int i = 0; i < lim; ++ i)
		f_ln[i] = mul(f_ln[i], g_ln[i]);
	ntt(f_ln, -1, lim);

	for(int i = 0; i < n; ++ i) g[i] = f_ln[i];
	for(int i = 0; i < lim; ++ i) f_ln[i] = g_ln[i] = 0;

	// Integrate f' / f in place.
	inter(g, g, n);
}

// Scratch buffers for exp(); they are zeroed again before exp() returns.
int f_ex[N], g_ex[N];
// Writes into g[0..n) the first n coefficients of the exponential of the power
// series f, using the iteration g <- g * (1 - ln(g) + f) with doubling precision.
//   f : input coefficients; read up to the smallest power of two >= n, never written.
//   g : output; whatever it held on entry is overwritten. Entries up to twice
//       that power of two are written.
//   n : number of coefficients wanted.
// g[0] is seeded with 1, so the result is presumably only meaningful when
// f[0] == 0. Calls ln() (hence relies on iv[]) and overwrites rev[].
void exp(int *f, int *g, int n) {
	g[0] = 1;

	int len = 1;
	while(len < (n << 1)) {
		const int lim = len << 1;
		// g_ex <- ln(g) on the first len coefficients; this clobbers rev[].
		ln(g, g_ex, len);
		for(int i = 0; i < len; ++ i) {
			f_ex[i] = g[i];
			g_ex[i] = (f[i] - g_ex[i] + mod) % mod;
		}
		// Add the constant 1 of the factor (1 - ln(g) + f).
		g_ex[0] += 1;

		// Rebuild the bit-reversal table for a transform of length lim.
		for(int i = 0; i < lim; ++ i) {
			rev[i] = rev[i >> 1] >> 1;
			if(i & 1) rev[i] |= len;
		}
		ntt(f_ex, 1, lim);
		ntt(g_ex, 1, lim);
		for(int i = 0; i < lim; ++ i)
			f_ex[i] = mul(f_ex[i], g_ex[i]);
		ntt(f_ex, -1, lim);

		// Keep the low len coefficients and clear the rest of the window.
		for(int i = 0; i < lim; ++ i)
			g[i] = (i < len) ? f_ex[i] : 0;

		len = lim;
	}

	for(int i = n; i < len; ++ i) g[i] = 0;
	for(int i = 0; i < len; ++ i) {
		f_ex[i] = 0;
		g_ex[i] = 0;
	}
}

// n: number of coefficients read from input. m is not used in this file.
int n, m;

// Second address marker; see the size report at the top of main().
bool mem2;
// Reads n and n coefficients, computes the first n coefficients of exp(f) and
// prints them separated by spaces. Returns 1 without computing anything when n
// does not fit the fixed-size arrays.
signed main() {
	// Distance between the two marker objects in MiB, as a rough size report for
	// the globals. The addresses are subtracted as integers because subtracting
	// pointers to unrelated objects is undefined behaviour.
	cerr << (reinterpret_cast<intptr_t>(&mem2) - reinterpret_cast<intptr_t>(&mem1)) / 1048576. << endl;

	// Linear-time table of modular inverses: iv[i] = -(mod / i) * iv[mod % i].
	iv[0] = iv[1] = 1;
	for(int i = 2; i < N; ++ i)
		iv[i] = mul(mod - mod / i, iv[mod % i]);

	n = read();
	// exp() runs transforms of length 2 * (smallest power of two >= n) on arrays
	// of N ints, so anything above N / 2 would index past their end.
	if(n < 0 || n > N / 2) {
		cerr << "n out of range: expected 0 <= n <= " << N / 2 << ", got " << n << endl;
		return 1;
	}
	for(int i = 0; i < n; ++ i) f[i] = read();
	exp(f, g, n);
	for(int i = 0; i < n; ++ i) printf("%d ", g[i]);
	return 0;
}