题解 P7468【[NOI Online 2021 提高组] 愤怒的小 N】

发布时间 2023-10-16 22:56:27作者: caijianhong

题解 P7468【[NOI Online 2021 提高组] 愤怒的小 N】

problem

首先是有一个字符串 \(S=\texttt{"0"}\),做无限次“将 \(S\) 的每一位取反接在 \(S\) 后面”的操作,形如 \(S=0110100110010110\cdots\)

另外给一个 \(k-1\) 次多项式 \(f\),求 \(\sum_{i=0}^{n-1}S_if(i).\)

\(n\leq 2^{5\times 10^5}, k\leq 500\)

solution 0

第一个观察是 \(S_i=parity(i)\)。因为每次将高位拿掉,值就反转。

考虑 dp。\(dp(i, j, 0/1)\) 表示 \([0,2^i)\) 中 \(parity=0/1\) 的数字的 \(j\) 次方和。

转移

初值为 \(dp(0, j, 0)=[j=0]\) 表示只有 \(0\) 一个数字。

\[\begin{aligned} dp(i, j, e)&=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}(l+2^{i-1})^j\\ &=dp(i-1, j, e)+\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}\sum_{t=0}^{j}\binom{j}{t}l^t(2^{i-1})^{j-t}\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}\sum_{l=0, parity(l)\neq e}^{2^{i-1}-1}l^t\\ &=dp(i-1, j, e)+\sum_{t=0}^{j}\binom{j}{t}(2^{i-1})^{j-t}dp(i-1, t, e\oplus 1)\\ \end{aligned} \]

统计答案

  • 取出 \(2^T=lowbit(n), L=n-2^T\)

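  • 答案累加 \(\displaystyle\sum_{l=0, parity(l)\neq parity(L)}^{2^T-1}f(l+L)\),即区间 \([L, L+2^T)\) 中 \(S=1\) 的位置的贡献。注意这里 \(l<2^T\),而 \(L\) 的低 \(T\) 位全为 \(0\),所以 \(l, L\) 相加不进位,\(parity(l+L)=parity(l)\oplus parity(L)\)。于是这玩意等于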

  • \[\begin{aligned} \displaystyle\sum_{l=0, parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}f_j(l+L)^j &=\displaystyle\sum_{l=0, parity(l)\neq parity(L)}^{2^T-1}\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jl^tL^{j-t}\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}\displaystyle\sum_{l=0, parity(l)\neq parity(L)}^{2^T-1}l^t\\ &=\sum_{j=0}^{k-1}\sum_{t=0}^{j}\binom{j}{t} f_jL^{j-t}dp(T, t, parity(L)\oplus 1)\\ &=\sum_{t=0}^{k-1}dp(T, t, parity(L)\oplus 1)\sum_{j=t}^{k-1}\binom{j}{t} f_jL^{j-t}\\ \end{aligned} \]

  • \(n:=L\)

  • 明显枚举了所有区间。

optimize

现在的复杂度是 \(O(k^2\log n)\)

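重量级结论是,当 \(i>j\) 时 \(dp(i, j, 0)=dp(i, j, 1)=\frac{1}{2}\sum_{l=0}^{2^i-1}l^j\)。(怎么证明呢:关键是对 \(i-1\to i\) 归纳,用二项式定理展开,考察各项系数。具体地,记 \(g(i,j)=dp(i,j,0)-dp(i,j,1)\),把 \(e=0\) 与 \(e=1\) 的两条转移相减,\(t=j\) 一项恰好抵消 \(g(i-1,j)\),得到 \(g(i,j)=-\sum_{t=0}^{j-1}\binom{j}{t}(2^{i-1})^{j-t}g(i-1,t)\)。若对所有 \(t<i-1\) 有 \(g(i-1,t)=0\),则对 \(j<i\),求和中的 \(t<j\le i-1\) 均满足 \(g(i-1,t)=0\),故 \(g(i,j)=0\);再由 \(dp(i,j,0)+dp(i,j,1)=\sum_{l=0}^{2^i-1}l^j\) 知两者各占一半。)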

换句话来说,对于 \(i>j\) 的一大段区间,我们直接求出整段区间的 \(f\) 的和,然后除以二就断定是区间的答案。这一大段区间,只算 \(i\geq k,j<k\) 的,就是 \(0\) 到 “\(n\) 的二进制表示中后面 \(k\) 位改成 \(0\)” 减一,于是可以计算。并观察到 \(f\) 的前缀和是 \(k\) 次多项式,考虑直接拉格朗日插值,\(O(k^2)-O(n+k)\) 完成这一部分。

剩下可能发生 \(i\le j\) 的区间只有 \(i<k\) 的那些(至多 \(k\) 个),对它们暴力计算,总共是 \(O(k^3)\) 的。

所以总的复杂度是 \(O(\log n+k^3)\)。就是将其中一个很大的 \(\log n\) 用结论打成 \(k\)

code


#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
// debug(...) forwards to fprintf(stderr, ...) only when compiled with LOCAL defined;
// otherwise every call expands to void(0) and is discarded.
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
// Shorthand for long long.
typedef long long LL;
// Residue class modulo P. The public field v holds the residue and is expected
// to lie in [0, P); every operation below keeps it there given valid operands.
// inv() uses Fermat's little theorem, so it is only meaningful when P is prime.
template <unsigned P>
struct modint {
    unsigned v;

    modint() : v(0u) {}

    // Implicit conversion from an integral value; negatives wrap into [0, P).
    template <class T>
    modint(T x) {
        x %= static_cast<int>(P);
        if (x < 0)
            v = x + P;
        else
            v = x;
    }

    // Binary exponentiation: returns base^e (e expected non-negative).
    template <class T> friend modint qpow(modint base, T e) {
        modint acc = 1;
        while (e) {
            if (e & 1) acc *= base;
            base *= base;
            e >>= 1;
        }
        return acc;
    }

    modint operator+() const { return *this; }
    modint operator-() const { return modint() - *this; }

    // Multiplicative inverse; asserts the value is non-zero.
    modint inv() const {
        assert(v);
        return qpow(*this, P - 2);
    }

    friend int raw(const modint &m) { return m.v; }

    modint &operator+=(const modint &o) {
        v += o.v;
        if (v >= P) v -= P;
        return *this;
    }
    modint &operator-=(const modint &o) {
        v -= o.v;  // unsigned wrap-around below zero is repaired next
        if (v >= P) v += P;
        return *this;
    }
    modint &operator*=(const modint &o) {
        v = static_cast<unsigned>(static_cast<unsigned long long>(v) * o.v % P);
        return *this;
    }
    modint &operator/=(const modint &o) {
        *this *= o.inv();
        return *this;
    }

    friend modint operator+(modint l, const modint &r) { l += r; return l; }
    friend modint operator-(modint l, const modint &r) { l -= r; return l; }
    friend modint operator*(modint l, const modint &r) { l *= r; return l; }
    friend modint operator/(modint l, const modint &r) { l /= r; return l; }
    friend bool operator==(const modint &l, const modint &r) { return l.v == r.v; }
    friend bool operator!=(const modint &l, const modint &r) { return !(l.v == r.v); }
};
using mint = modint<1000000007>;
// Multiplies two polynomials given as coefficient vectors (index = power of x)
// by schoolbook convolution; the result has a.size() + b.size() - 1 terms.
// An empty input is treated as the zero polynomial and yields an empty result,
// which also keeps the size computation below from underflowing.
vector<mint> multiple(const vector<mint> &a, const vector<mint> &b) {
    if (a.empty() || b.empty()) return {};
    vector<mint> c(a.size() + b.size() - 1);
    for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
    return c;
}
// Coefficient-wise sum of two polynomials; the result is as long as the longer input.
vector<mint> addition(const vector<mint> &a, const vector<mint> &b) {
    vector<mint> sum(a.size() < b.size() ? b.size() : a.size());
    for (size_t idx = 0; idx < sum.size(); ++idx) {
        if (idx < a.size()) sum[idx] += a[idx];
        if (idx < b.size()) sum[idx] += b[idx];
    }
    return sum;
}
// Synthetic division of the polynomial a(x) by the monic linear factor (x + b1).
// a holds coefficients in increasing order of power. Returns the quotient
// (a.size() - 1 coefficients); the remainder is discarded. An empty or constant
// polynomial yields an empty quotient, which also avoids the size underflow of
// a.size() - 1 on an empty input.
vector<mint> divide(vector<mint> a, mint b1) {
    if (a.size() <= 1) return {};
    vector<mint> res(a.size() - 1);
    for (size_t i = a.size() - 1; i >= 1; i--) {
        res[i - 1] = a[i];          // next quotient coefficient, highest power first
        a[i - 1] -= a[i] * b1;      // fold it back into the running partial remainder
    }
    return res;
}
// Lagrange basis for the nodes x = 1..k+1, filled by init():
// numes[i] = coefficients (increasing power) of prod_{j != i} (x - x_j) / (x_i - x_j).
vector<mint> numes[510];
// idenos[i] = 1 / prod_{j != i} (x_i - x_j), the inverted denominator of basis i.
mint idenos[510];
// Interpolates the polynomial that takes value b[i] at the i-th node and returns
// its coefficients in increasing order of power.
// Only a.size() is read: the node positions are baked into the precomputed
// basis numes[], so callers pass the same node list numes[] was built for.
vector<mint> lagrange(const vector<mint> &a, const vector<mint> &b) {
    assert(a.size() == b.size());
    const int m = (int) a.size();
    vector<mint> poly(m);
    for (int node = 0; node < m; ++node) {
        const vector<mint> &basis = numes[node];
        for (int deg = 0; deg < m; ++deg) poly[deg] += basis[deg] * b[node];
    }
    return poly;
}
// Evaluates the polynomial with coefficients a (increasing powers) at x via Horner's rule.
mint getValue(const vector<mint> &a, mint x) {
    mint acc = 0;
    for (auto it = a.rbegin(); it != a.rend(); ++it) acc = acc * x + *it;
    return acc;
}
// n: number of characters of the input binary string (bit length of N);
// k: number of coefficients of f, read in main().
int n, k;
// Bits of N; after main() converts and reverses them, a[i] is bit i (LSB first).
char a[1 << 19];
// f: coefficients f_0..f_{k-1} in increasing order of power.
// sumG: per-power prefix-sum polynomials (see trailing note); solve() never reads them.
// sumF: polynomial with sumF(m) = sum_{x=0}^{m-1} f(x), built in init().
vector<mint> f, sumG[510], sumF; //sumG[j](n) = sum{i=0..n-1} i^j
// dp[i][j][e] = sum of x^j over x in [0, 2^i) with popcount(x) % 2 == e (filled by DP()).
// qp2[i] = 2^i (mod p); binom[i][j] = C(i, j) for i < k (both filled by init()).
mint dp[510][510][2], qp2[1 << 19], binom[510][510];
// Modular inverse of 2, used by solve() to halve a full-range sum.
const mint inv2 = 1 / mint(2);
void init() {
    for (int i = raw(qp2[0] = 1); i <= max(k * k, n); i++) qp2[i] = qp2[i - 1] + qp2[i - 1];
    for (int i = 0; i < k; i++) {
        binom[i][0] = 1;
        for (int j = 1; j <= i; j++) binom[i][j] = binom[i - 1][j] + binom[i - 1][j - 1];
    }
    vector<mint> per = {};
    for (int i = 1; i <= k + 1; i++) per.push_back(i);
    vector<mint> ans(per.size()), product = {1};
    for (int i = 0; i < per.size(); i++) 
        product = multiple(product, {-per[i], 1});
    for (int i = 0; i < per.size(); i++) {
        numes[i] = divide(product, -per[i]);
        idenos[i] = 1;
        for (int j = 0; j < per.size(); j++)
            if (i != j) idenos[i] *= per[i] - per[j];
        idenos[i] = 1 / idenos[i];
        for (int j = 0; j < per.size(); j++) numes[i][j] *= idenos[i];
    }
    for (int j = 0; j < k; j++) {//这一段没用,,,
        vector<mint> tmp = {};
        for (int i = 1; i <= k + 1; i++) tmp.push_back(qpow(mint(i - 1), j));
        for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
        sumG[j] = lagrange(per, tmp);
    }
    { 
        vector<mint> tmp = {};
        for (int i = 1; i <= k + 1; i++) tmp.push_back(getValue(f, i - 1));
        for (int i = 1; i <= k; i++) tmp[i] += tmp[i - 1];
        sumF = lagrange(per, tmp);
    }
}
void DP() {
    for (int j = 0; j < k; j++) dp[0][j][0] = !j;
    for (int i = 1; i < min(n, k); i++) {
    //for (int i = 1; i < n; i++) {
        memcpy(dp[i], dp[i - 1], sizeof dp[i]);
        for (int j = 0; j < k; j++) {
            for (int e: {0, 1}) {
                for (int t = 0; t <= j; t++) {
                    dp[i][j][e] += dp[i - 1][t][1 - e] * binom[j][t] * qp2[(i - 1) * (j - t)];
                }
            }
        }
    }
    //forall i > j, dp[i][j][e] = sumg[j](2^i) / 2
}
// Returns sum_{x=0}^{N-1} S_x * f(x), where S_x = popcount(x) mod 2.
// [0, N) is split into aligned blocks [L, L + 2^i), one per set bit i of N, taken
// from the most significant bit down. Inside a block, l < 2^i shares no set bit
// with L, so parity(L + l) = parity(L) ^ parity(l), and f(L + l) is re-expanded
// in powers of l.
mint solve() {
    mint ans = 0, L = 0;
    bool parityL = false;  // popcount(L) mod 2

    // High bits (i >= k): each block has i > deg f, so the odd-parity positions carry
    // exactly half of the plain sum of f over the block. Over all these blocks that is
    // half of sumF(L), where L ends up as N with its low k bits cleared.
    if (n > k) {
        for (int i = n - 1; i >= k; i--) {
            if (!a[i]) continue;
            L += qp2[i];
            parityL = !parityL;
        }
        ans += getValue(sumF, L) * inv2;
    }

    // Low bits (i < k): use the exact parity-split power sums dp[i][t][*].
    for (int i = min(k, n) - 1; i >= 0; i--) {
        if (!a[i]) continue;
        for (int t = 0; t < k; t++) {
            // coef = sum_{j >= t} C(j, t) f_j L^(j-t): the coefficient of l^t in f(L + l).
            mint coef = 0, Lpow = 1;
            for (int j = t; j < k; j++) {
                coef += binom[j][t] * f[j] * Lpow;
                Lpow *= L;
            }
            ans += dp[i][t][!parityL] * coef;
        }
        L += qp2[i];
        parityL = !parityL;
    }
    return ans;
}
// Reads N in binary and k, then the k coefficients f_0..f_{k-1} of f, and prints
// sum_{x<N} S_x f(x) mod 1e9+7. Exits with status 1 on malformed input.
int main() {
    // The width (1 << 19) - 1 keeps the read inside a[], leaving room for the '\0'.
    if (scanf("%524287s%d", a, &k) != 2) return 1;
    n = strlen(a);
    for (int i = 0; i < n; i++) a[i] -= '0';
    reverse(a, a + n);  // a[i] is now bit i of N (LSB first)
    f = vector<mint>(k);
    for (int i = 0; i < k; i++) {
        // Go through the mint constructor so each coefficient is reduced into [0, p),
        // instead of writing an unchecked value straight into the residue field.
        long long coef;
        if (scanf("%lld", &coef) != 1) return 1;
        f[i] = coef;
    }
    init(), DP();
    printf("%d\n", raw(solve()));
    return 0;
}