ABC262Ex Max Limited Sequence 题解

发布时间 2023-05-09 18:51:26作者: 音街ウナ

题意

给定 \(m\) 个限制 \((l_i,r_i,p_i)\)\(n,k\),求满足以下条件的长度为 \(n\) 的不同序列 \(a=(a_1,a_2,\cdots,a_n)\) 的数目。

  1. \(\forall i \in[1,n],0\leq a_i\leq k\)

  2. \(\forall i \in[1,m],\max \limits_{j\in[l_i, r_i]}a_j=p_i\)

P4229,但数据更强,目测只允许 \(O(m\log m)\) 或类似复杂度通过。


考虑将条件 2 中每个限制拆分为:

  • \(\exists j\in[l_i,r_i],a_j=p_i\)
  • \(\forall j\in[l_i,r_i],a_j\le p_i\)

即任意位置 \(j\) 的取值最多不能超过覆盖它的各个限制中 \(p\) 的最小值 \(c_j=\min_{i\in[1,m],j\in[l_i,r_i]}p_i\),同时对于每个限制,至少有一个位置取到 \(p_i\);因此,取到 \(p_i\) 的位置 \(j\) 一定满足 \(p_i=a_j\)

对位置离散化后,利用线段树及标记永久化进行简单的区间取 \(\min\) 操作,则最后对于限制 \(i\)\(c_j=p_i\) 的位置 \(j\) 才有可能取值 \([0,p_i]\),而该限制覆盖的其它显然只能取到 \([0,q]\),且一定有 \(q<p\),故 \(a_j\le p_i\),即这些位置并不影响该限制的满足性。

综上所述,我们可以对 \(p\) 离散化,把覆盖后结果为 \(x\) 的部分单独取出并进行 DP,求出满足 \(p_i=x\) 的所有限制的方案数。所有 \(x\) 的方案数之积即为答案。

此处细节:

  1. 如果有 \(d\) 个位置不被任何限制覆盖,这些位置方案数为 \((k+1)^d\)
  2. 如果一个限制 \(i\) 满足 \(\forall j\in[l_i,r_i],c_j < p_i\),则总方案数为 \(0\)。线段树维护 \(c_j\) 区间最大值即可。

将所有 \(c_j=x\) 的位置及 \(l,r\) 离散化后作为一个序列 \(a'=(a'_1,a'_2,\cdots.a'_t)\) 集中考虑。注意有些位置离散化后是一个长度大于 \(1\) 的连续段,设其长度为 \(v_i\)

现在考虑如何在 \(O(t\log t)\) 时间内解决以下 DP 问题:

给定 \(m\) 个限制 \((l'_i,r'_i)\)\(x\),求满足以下条件的序列个数:

  • \(\forall i \in[1,t],\max \limits_{j\in[l'_i, r'_i]}a'_j=x\)

不妨设 \(f_{i}\) 为分配前 \(i\) 位的方案数,\(g_{i,j}\) 为分配前 \(i\) 位,使得满足 \(a'_k=x\) 的最大的 \(k\) 等于 \(j\) 的方案数。

显然如果第 \(i\) 位并非必须取 \(x\)\(\forall j \in [1,i],g_{i,j}=x^{v_i}g_{i-1,j}\)。这个操作可以用线段树区间乘完成。

如果第 \(i\) 位取 \(x\),即 \(g_{i,i} = ((x+1)^{v_i} - x^{v_i})f_{i-1}\)

如果不取,考虑若存在部分区间最后一个覆盖的位置为 \(i\),那么设这些区间左端点的最大值为 \(h_i\),则 \([h_i,i]\) 至少需取一个 \(x\)。所以 \(\forall j \in [1,h_i-1],g_{i,j} = 0\)。线段树维护区间推平即可。

综上所述,\(f_i=g_{i,i}+x^{v_i}\sum^{i-1}_{j=h_i}g_{i,j}\)。区间和也可用线段树维护。

此处细节:

  1. 需取模,计算可能出现负数。
  2. 离散化后序列相邻两项在原序列不一定连续。

目标:\(f_t\)。答案:\((k+1)^d\prod_{x\in p}f_t\)。时间复杂度 \(O(m\log m)\),空间线性。


#include <bits/stdc++.h>
#define maxm 400005
#define inf 0x3f3f3f3f
#define ad(i, j) i = (i % mod + j % mod + mod) % mod
#define mu(i, j) i = (i * j) % mod
#define ls(p) (p << 1)
#define rs(p) ((p << 1) + 1)
using namespace std;

const int mod = 998244353;
int n, m, k, num, nx, nf, b[maxm], bx[maxm];
int c[maxm];
struct Query {
    int l, r, x;
} a[maxm];
vector<int> q[maxm], v[maxm];
inline bool cmp(int i, int j) { return a[i].r == a[j].r ? a[i].l > a[j].l : a[i].r < a[j].r; }
long long ans, f[maxm], g[maxm];
// Basic
inline int len(int x) { return (x & 1) ? (b[(x >> 1) + 1] - b[x >> 1] - 1) : 1; }
inline int qpow(int x, int p) {
    if (!p) return 1;
    long long tx = qpow(x, p >> 1);
    return (p & 1) ? (tx * tx % mod * x % mod) : (tx * tx % mod);
}
// BIT
namespace Seg {
struct SegTree {
    long long sum, mul, cov;
} t[maxm * 4];
void build(int p, int l, int r) {
    t[p] = {g[l], 1, -1};
    if (l < r) {
        int mid = (l + r) >> 1;
        build(ls(p), l, mid), build(rs(p), mid + 1, r);
        t[p].sum = (t[ls(p)].sum + t[rs(p)].sum) % mod;
    }
}
inline void spread(int p) {
    if (~t[p].cov) {
        t[ls(p)].sum = t[ls(p)].cov = t[rs(p)].sum = t[rs(p)].cov = t[p].cov;
        t[ls(p)].mul = t[rs(p)].mul = 1;
        t[p].cov = -1;
    }
    if (t[p].mul > 1) {
        mu(t[ls(p)].sum, t[p].mul), mu(t[rs(p)].sum, t[p].mul);
        mu(t[ls(p)].mul, t[p].mul), mu(t[rs(p)].mul, t[p].mul);
        t[p].mul = 1;
    }
}
void change(int p, int pl, int pr, int l, int r, int x, int tg = 0) {
    if (l > r) return;
    if (pl >= l && pr <= r) {
        if (tg) t[p].mul = 1, t[p].sum = t[p].cov = x;
        else mu(t[p].mul, x), mu(t[p].sum, x);
    } else {
        spread(p);
        int mid = (pl + pr) >> 1;
        if (l <= mid) change(ls(p), pl, mid, l, r, x, tg);
        if (r > mid) change(rs(p), mid + 1, pr, l, r, x, tg);
        t[p].sum = (t[ls(p)].sum + t[rs(p)].sum) % mod;
    }
}
long long ask(int p, int pl, int pr, int l, int r) {
    if (l > r) return 0;
    else if (pl >= l && pr <= r) return t[p].sum;
    spread(p);
    int mid = (pl + pr) >> 1;
    long long ans = 0;
    if (l <= mid) ad(ans, ask(ls(p), pl, mid, l, r));
    if (r > mid) ad(ans, ask(rs(p), mid + 1, pr, l, r));
    return ans;
}
};
namespace Mn {
int ma[maxm * 4], tg[maxm * 4];
inline void init() { memset(tg, 0x3f, sizeof(tg)); }
void cover(int p, int pl, int pr, int l, int r, int x) {
    if (pl >= l && pr <= r) tg[p] = min(tg[p], x);
    else {
        int mid = (pl + pr) >> 1;
        if (l <= mid) cover(ls(p), pl, mid, l, r, x);
        if (r > mid) cover(rs(p), mid + 1, pr, l, r, x);
    }
}
void dfs(int p, int l, int r, int x) {
    int mid = (l + r) >> 1;
    x = min(x, tg[p]);
    if (l == r) c[l] = x, ma[p] = len(l) ? x : 0;
    else dfs(ls(p), l, mid, x), dfs(rs(p), mid + 1, r, x), ma[p] = max(ma[ls(p)], ma[rs(p)]);
}
int ask(int p, int pl, int pr, int l, int r) {
    if (pl >= l && pr <= r) return ma[p];
    int mid = (pl + pr) >> 1, ans = 0;
    if (l <= mid) ans = max(ans, ask(ls(p), pl, mid, l, r));
    if (r > mid) ans = max(ans, ask(rs(p), mid + 1, pr, l, r));
    return ans;
}
};
signed main() {
    ans = 1, num = nx = 0;
    scanf("%d%d%d", &n, &k, &m), ++k;
    for (int i = 1; i <= m; ++i) {
        scanf("%d%d%d", &a[i].l, &a[i].r, &a[i].x), ++a[i].x;
        b[++num] = a[i].l, b[++num] = a[i].r, bx[++nx] = a[i].x;
    }
    sort(b + 1, b + 1 + num), num = unique(b + 1, b + 1 + num) - b - 1;
    sort(bx + 1, bx + 1 + nx), nx = unique(bx + 1, bx + 1 + nx) - bx - 1;
    b[num + 1] = n + 1, n = num * 2 + 1;
    Mn::init();
    for (int i = 1; i <= m; ++i) {
        a[i].l = (lower_bound(b + 1, b + 1 + num, a[i].l) - b) << 1;
        a[i].r = (lower_bound(b + 1, b + 1 + num, a[i].r) - b) << 1;
        a[i].x = lower_bound(bx + 1, bx + 1 + nx, a[i].x) - bx;
        q[a[i].x].push_back(i);
        Mn::cover(1, 1, n, a[i].l, a[i].r, a[i].x);
    }
    Mn::dfs(1, 1, n, inf);
    // 3 m log m
    for (int i = 1; i <= m; ++i)
        if (Mn::ask(1, 1, n, a[i].l, a[i].r) < a[i].x) {
            puts("0");
            return 0;
        }
    for (int i = 1; i <= n; ++i)
        if (c[i] < inf && len(i)) v[c[i]].push_back(i);
        else if (len(i)) mu(ans, qpow(k, len(i)));
    for (int x = 1; x <= nx; ++x) {
        nf = v[x].size();
        for (int i = 1; i <= nf; ++i) f[i] = g[i] = 0;
        sort(q[x].begin(), q[x].end(), cmp);
        Seg::build(1, 1, nf);
        int p = -1;
        f[0] = 1;
        for (int i = 1; i <= nf; ++i) {
            int ln = len(v[x][i - 1]);
            long long t = f[i - 1] * ((qpow(bx[x], ln) - qpow(bx[x] - 1, ln) + mod) % mod) % mod;
            int nxt = i < nf ? v[x][i] : inf, l = 0;
            while (p < int(q[x].size() - 1) && a[q[x][p + 1]].r < nxt) ++p, l = max(l, a[q[x][p]].l);
            if (l) {
                auto lp = lower_bound(v[x].begin(), v[x].end(), l);
                l = lp - v[x].begin() + 1;
            }
            if (!l) f[i] = f[i - 1] * qpow(bx[x], ln) % mod;
            else {
                f[i] = (t + Seg::ask(1, 1, nf, l, i - 1) * qpow(bx[x] - 1, ln) % mod) % mod;
                Seg::change(1, 1, nf, 1, l - 1, 0, 1);
            
            }
            Seg::change(1, 1, nf, i, i, t, 1);
            Seg::change(1, 1, nf, 1, i - 1, qpow(bx[x] - 1, ln));
        }
        mu(ans, f[nf]);
    }
    printf("%lld\n", ans);
    return 0;
}