【模板】多项式乘法、乘法逆、除法、取模、常系数齐次线性递推

发布时间：2023-09-24 21:42:55　作者：caijianhong

以下代码必须开启 -O2 优化编译。

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// debug(...) forwards printf-style diagnostics to stderr only in builds that
// define LOCAL; in every other build it expands to a no-op expression.
#if defined(LOCAL)
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
// Shorthand for 64-bit signed arithmetic.
using LL = long long;
// Integers modulo P. The stored value `v` is always the canonical residue in
// [0, P). Assumes P < 2^31: the constructor casts P to int, and += relies on
// v + rhs.v (< 2P) fitting in an unsigned. inv() uses Fermat's little theorem
// (a^(P-2)), so it is only a true inverse when P is prime.
template <unsigned P> struct modint {
    unsigned v;
    modint() : v(0u) {}
    // Implicit conversion from any integer-like value; negative inputs are
    // mapped to their non-negative residue.
    template <class T> modint(T x) {
        x %= static_cast<int>(P);
        if (x < 0)
            v = x + P;
        else
            v = x;
    }
    modint operator+() const { return *this; }
    modint operator-() const {
        modint result;
        result -= *this;
        return result;
    }
    // Multiplicative inverse; asserts that the value is non-zero.
    modint inv() const {
        assert(v);
        return qpow(*this, P - 2);
    }
    friend int raw(const modint &self) { return self.v; }
    // Binary exponentiation a^b for a non-negative integral exponent b.
    template <class T> friend modint qpow(modint a, T b) {
        modint acc = 1;
        while (b) {
            if (b & 1) acc *= a;
            b >>= 1;
            a *= a;
        }
        return acc;
    }
    modint &operator+=(const modint &rhs) {
        v += rhs.v;
        if (v >= P) v -= P;
        return *this;
    }
    modint &operator-=(const modint &rhs) {
        v -= rhs.v;          // may wrap below zero (unsigned arithmetic) ...
        if (v >= P) v += P;  // ... in which case adding P wraps it back into [0, P)
        return *this;
    }
    modint &operator*=(const modint &rhs) {
        const unsigned long long wide = static_cast<unsigned long long>(v) * rhs.v;
        v = static_cast<unsigned>(wide % P);
        return *this;
    }
    modint &operator/=(const modint &rhs) {
        *this *= rhs.inv();
        return *this;
    }
    friend modint operator+(modint lhs, const modint &rhs) { lhs += rhs; return lhs; }
    friend modint operator-(modint lhs, const modint &rhs) { lhs -= rhs; return lhs; }
    friend modint operator*(modint lhs, const modint &rhs) { lhs *= rhs; return lhs; }
    friend modint operator/(modint lhs, const modint &rhs) { lhs /= rhs; return lhs; }
    friend bool operator==(const modint &lhs, const modint &rhs) { return lhs.v == rhs.v; }
    friend bool operator!=(const modint &lhs, const modint &rhs) { return !(lhs == rhs); }
};
using mint = modint<998244353>;
// Returns the smallest power of two strictly greater than x
// (glim(3) == 4, glim(4) == 8). For x <= 0 it returns 1. Without that guard,
// x == 0 would call __builtin_clz(0), which is undefined, and x < 0 would shift
// by 32. Precondition: x < 2^30, so that the result fits in an int.
const int glim(const int &x) {
    if (x <= 0) return 1;
    return 1 << (32 - __builtin_clz(x));
}
// Number of trailing zero bits of x; for a power of two this equals log2(x).
// Precondition: x != 0 (__builtin_ctz(0) is undefined).
const int bitctz(const int &x) {
    const unsigned bits = static_cast<unsigned>(x);
    return __builtin_ctz(bits);
}
const vector<mint> wns = []() -> vector<mint> {
    vector<mint> wns = {};
    for (int j = 1; j <= 23; j++)
        wns.push_back(qpow(mint(3), (998244353 - 1) >> j));
    return wns;
}();
void ntt(vector<mint> &a, const int &op) {
    const int n = a.size();
    for (int i = 1, r = 0; i < n; i++) {
        r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
        if (i < r) swap(a[i], a[r]);
    }
    vector<mint> w(n);
    for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
        const mint wn = wns[bitctz(k)];
        for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < k; j++) {
                const mint x = a[i + j], y = a[i + j + k] * w[j];
                a[i + j] = x + y, a[i + j + k] = x - y;
            }
        }
    }
    if (op == -1) {
        mint iz = mint(1) / n;
        for (int i = 0; i < n; i++) a[i] *= iz;
        reverse(a.begin() + 1, a.end());
    }
}
// Returns the first `lim` coefficients of the power-series inverse of `a`,
// i.e. b with (a * b) mod x^lim == 1. Uses Newton iteration
// b <- b * (2 - a * b), which doubles the number of correct coefficients per step.
// Preconditions: a is non-empty and a[0] != 0 (1 / a[0] goes through
// modint::inv, which asserts non-zero). Returns an empty vector when lim <= 0.
vector<mint> getInv(const vector<mint> &a, int lim) {
    if (lim <= 0) return {};
    vector<mint> b = {1 / a[0]};  // correct modulo x^1
    // Before the step for a given `len`, b is correct modulo x^(len/2). Stop as
    // soon as that precision already covers `lim` coefficients; this avoids the
    // extra doubling the old `len <= glim(lim)` bound performed when lim was a
    // power of two.
    for (int len = 2; len / 2 < lim; len <<= 1) {
        vector<mint> c(a.begin(), a.begin() + min(len, (int)a.size()));
        // deg(c * b * b) < 2 * len, so a length-2*len cyclic transform is exact.
        b.resize(len << 1), ntt(b, 1);
        c.resize(len << 1), ntt(c, 1);
        for (int i = 0; i < len << 1; i++)
            b[i] = b[i] * (2 - c[i] * b[i]);
        ntt(b, -1), b.resize(len);
    }
    b.resize(lim);
    return b;
}
// Returns the product of polynomials a and b (coefficients lowest degree
// first) as a vector of length a.size() + b.size() - 1.
// The product is computed with a cyclic NTT whose length is the smallest power
// of two >= that result length, which is enough to avoid wrap-around.
// Returns an empty vector when the result length is <= 0, i.e. one factor is
// empty and the other has at most one coefficient. Previously that case
// reached glim(0) (undefined) or a negative resize.
vector<mint> multiple(vector<mint> a, vector<mint> b) {
    const int rLen = (int)(a.size() + b.size()) - 1;
    if (rLen <= 0) return {};
    int len = 1;
    while (len < rLen) len <<= 1;
    a.resize(len), ntt(a, 1);
    b.resize(len), ntt(b, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    ntt(a, -1), a.resize(rLen);
    return a;
}
vector<mint> divide(vector<mint> f, vector<mint> g) {
    if (f.size() < g.size()) return {};
    int rLen = f.size() - g.size() + 1;
    reverse(f.begin(), f.end());
    reverse(g.begin(), g.end());
    f = multiple(f, getInv(g, rLen));
    f.resize(rLen), reverse(f.begin(), f.end());
    return f;
}
vector<mint> modulo(vector<mint> f, vector<mint> g) {
    int rLen = g.size() - 1;
    vector<mint> q = multiple(g, divide(f, g));
    q.resize(rLen), f.resize(rLen);
    for (int i = 0; i < rLen; i++) f[i] -= q[i];
    return f;
}
vector<mint> qpow(vector<mint> a, int b, vector<mint> m) {
    vector<mint> r = {1};
    for (; b; b >>= 1, a = modulo(multiple(a, a), m)) {
        if (b & 1) r = modulo(multiple(r, a), m);
    }
    return r;
}
int main() {
    int n, k;
    scanf("%d%d", &n, &k);
    vector<mint> m(k + 1), a(k);
    m[k] = 1;
    for (int i = k - 1, x; i >= 0; i--) scanf("%d", &x), m[i] = -x;
    for (int i = 0, x; i < k; i++) scanf("%d", &x), a[i] = x;
    vector<mint> b = qpow({0, 1}, n, m);
    mint ans = 0;
    for (int i = 0; i < k; i++) ans += b[i] * a[i];
    printf("%d\n", raw(ans));
    return 0; 
}