[CF958F3] Lightsabers (hard)

发布时间 2023-12-09 09:45:37作者: 徐子洋

题目链接

对于一种元素 \(v\),假设它在给出可重集合中出现了 \(t\) 次,那么容易把它表示成基础的生成函数形式:\(1+x+x^2+x^3+\dots+x^t\)

显然,把所有元素的生成函数卷一下就是答案。但是这样最坏情况为 \(O(nm\log n)\)的,不能通过这道题。

在思考优化方式时,容易想到启发式合并来优化这个过程。但是启发式合并本质上就是对分治的同一层区间进行无序合并,二者复杂度相同,故而我采取了比较好写的分治。

点击查看代码
#include <bits/stdc++.h>
namespace Poly{
    constexpr int N = 2.7e5 + 10;
    constexpr double Pi = acos(-1);
    int n, m, p[N];
    std::complex<double> f[N], g[N];
    void FFT(int n, std::complex<double> *c, int x = 1){
        for(int i = 0; i < n; ++i) if(i < p[i]) std::swap(c[i], c[p[i]]);
        for(int b = 2, k = 1; b <= n; b <<= 1, k <<= 1){
            std::complex<double> t(cos(2 * Pi / b), sin(2 * Pi / b) * x), w(1, 0);
            for(int i = 0; i < n; i += b, w = 1){
                for(int j = 0; j < k; ++j, w *= t){
                    c[i + j] += c[i + j + k] * w;
                    c[i + j + k] = c[i + j] - c[i + j + k] * w - c[i + j + k] * w;
                }
            }
        }
    }
    std::vector<double> Convolution(const auto &a, const auto &b){
        if(!a.size() || !b.size()) return std::vector<double>();
        n = a.size(), m = b.size();
        int l = 1 << (int)ceil(log2(n + m - 1));
        for(int i = 0; i < l; ++i){
            f[i] = i < n? a[i] : 0, g[i] = i < m? b[i] : 0;
            p[i] = (p[i >> 1] >> 1) | (i & 1? l >> 1 : 0);
        }
        n += m - 1, FFT(l, f), FFT(l, g);
        for(int i = 0; i < l; ++i) f[i] *= g[i];
        FFT(l, f, -1); std::vector<double> ret;
        for(int i = 0; i < n; ++i)
            ret.emplace_back(f[i].real() / l);
        return ret;
    }
}
#define FL(i, a, b) for(int i = (a); i <= (b); ++i)
#define FR(i, a, b) for(int i = (a); i >= (b); --i)
typedef std::vector<int> vi;
using Poly::Convolution;
constexpr int N = 2e5 + 10, P = 1009;
int n, m, k, cnt[N];
vi Solve(int l, int r){
    if(l == r) return vi(cnt[l] + 1, 1);
    int mid = l + r >> 1;
    vi t1 = Solve(l, mid), t2 = Solve(mid + 1, r), ret;
    auto t = Convolution(t1, t2);
    for(auto x: t) ret.emplace_back((long long)round(x) % P);
    vi().swap(t1), vi().swap(t2), std::vector<double>().swap(t);
    return move(ret);
}
int main(){
    scanf("%d%d%d", &n, &m, &k); int x;
    FL(i, 1, n) scanf("%d", &x), ++cnt[x];
    vi ans = Solve(1, m);
    printf("%d\n", ans[k]);
    return 0;
}