题解 CF1651F【Tower Defense】

发布时间 2023-10-18 20:47:51作者: caijianhong

题解 CF1651F【Tower Defense】

problem

一个塔防游戏。

一共有 \(n\) 个塔按 \(1 \sim n\) 的顺序排成一列,每座塔都有魔力容量 \(c_i\) 和魔力恢复速率 \(r_i\)。对于一座塔 \(i\),每过一秒它的魔力 \(m_i\) 会变为 \(\min(m_i+r_i, c_i)\)。每座塔初始时满魔力。

一共有 \(q\) 个怪物,每个怪物有两个属性 \(t_i\)\(h_i\),表示这个怪物会在第 \(t_i\) 秒出现在第一座塔前面。当它到一座塔 \(j\) 面前时,自己的血量 \(h_i\) 会减少 \(\min(h_i,m_j)\),塔的魔力也会减少这个数。当怪物血量 \(h_i=0\) 时停止移动,否则它下一秒会移动到下一座塔。

有些怪物在经过塔 \(n\) 后血量仍未减少至 \(0\),请你求出这样的怪物最终剩下的血量总和。

  • \(1\le n,q\le2\times10^5\)
  • \(1\le c_i,r_i\le 10^9\)
  • \(1\le h_i\le10^{12}\)
  • \(0\le t_i\le 2\times10^5\)
  • \(\forall 1\le j<q,t_j<t_{j+1}\)

solution 1

小恐龙(即题目中的怪物)的路线就是:有很长的一段前缀,小恐龙直接推平塔,把塔全部清零;在一座塔前面突然被塔打死,停在这里,或者走完所有塔。另外注意到所谓“下一秒会移动到下一座塔”是假的,因为小恐龙的速度是一样的,大家都提前 \(t\) 时刻的结果是一样的,不如让大家到达 \(i\) 号塔的时间都提前 \(i\),忽略这潜在的顺序为问题。

考虑分块维护每个塔,每个块的状态只有:在之前的某个时刻被推平,在之前的某个时刻有小恐龙被打停。其中后者只会出现 \(O(q)\) 个,因此对于这样的块直接暴力。

我们现在要解决的问题,就是给你 \(B\) 个塔其中 \(B\) 是块长,求这些塔从全零开始经过 \(t\) 时刻后的总和。因为块数太少,时间范围只有 \(2\times 10^5\),考虑直接预处理这个东西。每个塔的贡献,形如:

  • \(1\sim \left\lfloor c_i/r_i\right\rfloor\) 的时刻,塔的魔力都是增加 \(r_i\)
  • 在时刻 \(\left\lfloor c_i/r_i\right\rfloor + 1\),塔的魔力增加 \(c_i\bmod r_i\)
  • 此后塔的增长停止。

考虑将第一种用差分前缀和做一次,再加入第二种,再做前缀和。这样的复杂度是 \(O(nT/B+n/B+qB)\) 其中 \(T=2\times 10^5\) 是时间范围。使得 \(B=\sqrt n\) 的时间复杂度趋于平衡,但是有空间的问题,考虑使得 \(B=1024\) 即可。

solution 2

我们的时间复杂度瓶颈在于求出“这些塔从全零开始经过 \(t\) 时刻后的总和”。考虑到塔关于时间的魔力值是分两段的函数,使用可持久化线段树维护分段函数和。具体地,若记 \(f_i(x)\) 表示第 \(i\) 个塔在 \(x\) 时刻的魔力值,则可持久化线段树上的版本 \(t\) 维护了 \(f_i(t)\) 的区间和。从版本 \(t-1\) 转移到 \(t\),只需要修改 \(\left\lfloor c_i/r_i\right\rfloor=t-1\) 的塔的函数值,是一个单点修改。那么查询就是轻而易举的。

考虑当小恐龙来临时,我们还是需要维护:哪些连续的塔是推平的,哪个塔是小恐龙死的。用一个 stack 维护所有的区间,根据这个区间是区间还是单点,决定如果算出它现在的总和。如果计算得到小恐龙在这个区间停下,使用线段树二分得到具体位置,并将区间断开。因为每个区间访问过之后要么没了要么只会重新加入,而重新加入只会发生 \(O(q)\) 次,因此总的时间复杂度是 \(O((n+q+T)\log n)\)

code

分块

// ubsan: undefined
// accoders
#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
template <int N>
struct block {
    LL c[N + 10], r[N + 10], now[N + 10], tag, flag;
    LL syu[200010];
    void build(int *_c, int *_r) {
        tag = 0, flag = 1;
        for (int i = 1; i <= N; i++) {
            c[i] = _c[i], r[i] = _r[i], now[i] = c[i];
            if (r[i])
                syu[min(c[i] / r[i], LL(2e5))] += r[i];
        }
        for (int i = 2e5; i >= 1; i--) syu[i - 1] += syu[i];
        syu[0] = 0;
        for (int i = 1; i <= N; i++) {
            if (r[i] && c[i] / r[i] + 1 <= 2e5)
                syu[c[i] / r[i] + 1] += c[i] % r[i];
        }
        syu[0] = 0;
        for (int i = 1; i <= 2e5; i++) syu[i] += syu[i - 1];
    }
    void remake() {
        if (!tag)
            return;
        if (!flag)
            memset(now, 0, sizeof now);
        for (int i = 1; i <= N; i++) now[i] = min(c[i], now[i] + tag * r[i]);
        tag = 0, flag = 1;
    }
    LL getSum() {
        if (!flag)
            return assert(tag <= 2e5), syu[tag];
        remake();
        LL res = 0;
        for (int i = 1; i <= N; i++) res += now[i];
        return res;
    }
    void clear() { tag = flag = 0; }
    void brute(LL h) {
        remake();
        for (int i = 1; i <= N; i++) {
            if (h >= now[i])
                h -= now[i], now[i] = 0;
            else {
                now[i] -= h;
                break;
            }
        }
    }
};
constexpr int B = 1 << 10;
block<B> blo[int(2e5) / B + 5];
int n, Q;
int c[1 << 18], r[1 << 18];
int main() {
    fprintf(stderr, "sizeof blo = %u\n", sizeof blo >> 20);
#ifndef NF
#endif
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d%d", &c[i], &r[i]);
    int bcnt = 0;
    for (int i = 1; i <= n; i += B) blo[++bcnt].build(c + i - 1, r + i - 1);
    scanf("%d", &Q);
    LL ans = 0, last = 0;
    for (LL h, t; Q--;) {
        scanf("%lld%lld", &t, &h);
        for (int b = 1; b <= bcnt; b++) blo[b].tag += t - last;
        last = t;
        for (int b = 1; b <= bcnt && h; b++) {
            LL s = blo[b].getSum();
            debug("blo[%d] = %lld\n", b, s);
            if (h >= s)
                blo[b].clear(), h -= s;
            else {
                blo[b].brute(h), h = 0;
                break;
            }
        }
        ans += h;
    }
    printf("%lld\n", ans);
    return 0;
}

可持久化线段树的代码还没调出来。但是贴在这里。

#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
struct func {
    LL k, b;
    func() : func(0, 0) {}
    func(LL k, LL b) : k(k), b(b) {}
    func operator+(const func &rhs) const { 
        return func(k + rhs.k, b + rhs.b); 
    }
    LL operator()(LL x) const { return k * x + b; }
};
template <int N>
struct segtree {
    func ans[N << 5];
    int ch[N << 5][2], tot;
    segtree() : tot(0) { ch[0][0] = ch[0][1] = 0; }
    int newnode(int q = 0) {
        int p = ++tot;
        memcpy(ch[p], ch[q], sizeof ch[0]);
        ans[p] = ans[q];
        return p;
    }
    void maintain(int p) {
        ans[p] = ans[ch[p][0]] + ans[ch[p][1]];
    }
    int modify(int x, const func &f, int q, int l, int r) {
        int p = newnode(q);
        if (l == r) return ans[p] = f, p;
        int mid = (l + r) >> 1;
        if (x <= mid)
            ch[p][0] = modify(x, f, ch[q][0], l, mid);
        else
            ch[p][1] = modify(x, f, ch[q][1], mid + 1, r);
        maintain(p);
        return p;
    }
    func query(int L, int R, int p, int l, int r) {
        if (!p || (L <= l && r <= R)) return ans[p];
        int mid = (l + r) >> 1;
        func ret = func(0, 0);
        if (L <= mid)
            ret = ret + query(L, R, ch[p][0], l, mid);
        if (mid < R)
            ret = ret + query(L, R, ch[p][1], mid + 1, r);
        return ret;
    }
    int binary(int L, int R, int x, LL &h, int p, int l, int r) {
        if (L <= l && r <= R) {
            if (!p) return -1;
            if (h >= ans[p](x)) return h -= ans[p](x), -1;
            if (l == r) return l;
        }
        int mid = (l + r) >> 1, res;
        if (L <= mid && (res = binary(L, R, x, h, ch[p][0], l, mid)) != -1)
            return res;
        if (mid < R && (res = binary(L, R, x, h, ch[p][1], mid + 1, r)) != -1)
            return res;
        return -1;
    }
};
struct range {
    int l, r, lst;
};
int n, c[1 << 18], r[1 << 18], root[1 << 18], now[1 << 18], Q;
segtree<1 << 18> T;
vector<int> buc[1 << 18];
range stk[1 << 18];
int top;
void buildT() {
    for (int i = 1; i <= n; i++)
        root[0] = T.modify(i, func(r[i], 0), root[0], 1, n);
    for (int i = 1; i <= n; i++)
        if (c[i] / r[i] + 1 <= 2e5) buc[c[i] / r[i] + 1].push_back(i);
    for (int t = 1; t <= 2e5; t++) {
        root[t] = root[t - 1];
        for (int i: buc[t]) 
            root[t] = T.modify(i, func(0, c[i]), root[t], 1, n);
    }
}
int last[1 << 18];
LL solve(LL h, int t) {
    while (top && h) {
        int L = stk[top].l, R = stk[top].r, lst = stk[top].lst, pos;
        --top;
        if (lst == -1) {
            now[L] = min(LL(c[L]), now[L] + 1ll * (t - last[L]) * r[L]);
            last[L] = t;
            if (h >= now[L]) h -= now[L], pos = -1;
            else pos = L;
        } else {
            pos = T.binary(L, R, t - lst, h, root[t - lst], 1, n);
        }
        if (pos == -1) continue;
        if (pos < R) stk[++top] = {pos + 1, R, lst};
        stk[++top] = {pos, pos, -1};
        last[pos] = t;
        if (lst == -1) {
            now[pos] -= h;
        } else {
            now[pos] = T.query(pos, pos, root[t - lst], 1, n)(t - lst) - h;
        }
        if (pos > 1) stk[++top] = {1, pos - 1, t};
        debug(">>> h = %lld\n", 0ll);
        return 0;
    }
    if (!top) stk[++top] = {1, n, t};
    debug(">>> h = %lld\n", h);
    return h;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d%d", &c[i], &r[i]), now[i] = c[i];
    buildT();
    for (int i = n; i >= 1; i--) stk[++top] = {i, i, -1}, last[i] = 0;
    LL ans = 0;
    scanf("%d", &Q);
    for (LL i = 1, h, t; i <= Q; i++) scanf("%lld%lld", &t, &h), ans += solve(h, t);
    printf("%lld\n", ans);
    return 0;
}