[蓝桥杯 2021 国 AB] 翻转括号序列(线段树上二分)

发布时间 2023-04-07 13:23:13作者: Fighoh

[蓝桥杯 2021 国 AB] 翻转括号序列

题目描述

给定一个长度为 \(n\) 的括号序列,要求支持两种操作:

  1. \(\left[L_{i}, R_{i}\right]\) 区间内(序列中的第 \(L_{i}\) 个字符到第 \(R_{i}\) 个字符)的括号全部翻转(左括号变成右括号,右括号变成左括号)。

  2. 求出以 \(L_{i}\) 为左端点时,最长的合法括号序列对应的 \(R_{i}\) (即找出最大的 \(R_{i}\) 使 \(\left[L_{i}, R_{i}\right]\) 是一个合法括号序列)。

输入格式

输入的第一行包含两个整数 \(n, m\),分别表示括号序列长度和操作次数。

第二行包含给定的括号序列,括号序列中只包含左括号和右括号。

接下来 \(m\) 行,每行描述一个操作。如果该行为 1 L R, 表示第一种操作,区间为 \(\left[L, R\right]\);如果该行为 2 L 表示第二种操作,左端点为 \(L\)

对于所有评测用例,\(1 \leq n \leq 10^{6}, 1 \leq m \leq 2 \times 10^{5}\)

Solution

分析一下操作二

看到这两个操作不难想到这个问题大概可以用线段树解决,不难想到我们可以把左括号看成1,右括号看成-1。这个时候的合法序列就应该是

\[legal\;sequence\iff sum(l\cdots r) = 0 \;and\; \forall_{k\in \{l\cdots r\}}sum(l\cdots k) \ge 0 \]

那如果用线段树直接去维护区间和,然后每次去查询\(l\)右侧是否有这样一个点是很困难的,所以要把它转化成前缀和形式

\[legal\;sequence\iff pre(r) = pre(l - 1) \;and\; \forall_{k\in \{l\cdots r\}}pre(k) \ge pre(l - 1) \]

所以对于每一个\(l\)我们要找的是\(l - 1\)右侧最后一个符合上述式子的位置即可。

对于操作一的实现

而对于翻转区间操作,如果翻转的区间的\(l = 1\),我们可以知道就是直接把\(l\cdots r\)全都取相反数而对于\([r + 1\cdots n]\)的区间应该全都减去某一个数

\[pre(r + 1) = pre(r) + sum(r + 1),\; pre(r + 2) = pre(r) + sum(r + 1) + sum(r + 2) \cdots \]

所以当\(pre(r) \rightarrow -pre(r)\)时,\([r + 1\cdots n]\)的区间应该全都减去原来的\(2 * pre(r)\),由于是区间操作,我们需要一个懒惰标记,记为\(lazy\_add\)

像这种区间翻转一个很套路的做法就是

\[reverse(l\cdots r) \iff reverse(1\cdots r)\; + \; reverse(1\cdots l-1) \]

在线段树上翻转区间我们只要维护区间最大值和区间最小值,每次翻转就是交换最大最小值并且取负数,由于是区间操作,我们需要一个懒惰标记,记为\(lazy\_rev\)

//假设p是我们当前操作区间的节点id
tmp1 = mx[p], tmp2 = mn[p];
mx[p] = (~tmp2) + 1;//取反加1就是取相反数
mn[p] = (~tmp1) + 1;
lazy_rev[p] ^= 1;

对于操作二的实现

很显然右端点所在的位置是可以二分的,那我们先考虑直接进行二分,我们每次去二分一个位置\(pos\)然后验证区间\([l\cdots pos]\)的最小值是否小于\(pre(l - 1)\),小于我们就去左侧区间,否则就去右侧区间,但是我们要注意如果右侧区间最小值大于\(pre(l - 1)\)其实也是没有答案的,因为我们要找的右端点\(pre(r) = pre(l - 1)\)

int now = query_po(1, 1, n, l - 1);//单点查询pre(l - 1)
int L = l, R = n;
int ans = 0;
while (L <= R) {
    int mid = L + R >> 1;
    int k = query(1, 1, n, l, mid);//查询区间(l, mid)最小值
    if (k < now || k >= now && query(1, 1, n, mid, n) > now) {//保证满足上述条件,其实可以不用第二个query
        R = mid - 1;
    } else {
        L = mid + 1;
        ans = mid;
    }
}

这个做法时间复杂度是\(O(m\log^2n)\),像我写法不太好就可能被卡掉,要常数非常小才有可能通过。所以我们得考虑把这个二分的过程搬到线段树上去,线段树的很多操作本质上就是在做二分,我们可以利用这个二分的过程。

基于朴素二分的优化

先看一下我一开始错误的二分方式

int query(int p, int l, int r, int pos, int val) {//pos代表题目中的l, val代表pre(l - 1)
    if (l == r) {
        if (mn[p] == val) return l;//要和val相等
        else return 0;
    }
    push_down(p);
    int mid = l + r >> 1;
    if (pos > mid) return query(p << 1 | 1, mid + 1, r, pos, val);//pos在右侧我们直接去右儿子   (1)
    int ans = 0;
    if (mn[p << 1] < val) return query(p << 1, l, mid, pos, val);//如果左侧最小值小于val就得去左儿子  (2)
    if (mn[p << 1 | 1] <= val) ans = max(ans, query(p << 1 | 1, mid + 1, r, pos, val));//否则我们贪心先去右儿子,但右边最小值要<=val否则找不到val,为什么小于也可行,因为上一行代码已经保证了左侧最小值>=val,每一个括号只能产生1 or -1的贡献所以是连续的,所以右侧就一定会出现val (3)
    ans = max(ans, query(p << 1, l, mid, pos, val));//可能会返回0,所以再去左侧看看有没有答案
    return ans;
}

其实有很显然的错误

在(2)处,虽然我们保证了当前的\(pos \le mid\)但是\(l\)(线段树的区间的\(l\))依旧可能在\(pos\)的左侧,这就导致了mn[p << 1] < val这个语句会出错,我们要精确地找到这个\(pos\)右侧的最小值通过这一个\(query\)是不太可行的(至少我好像实现不了)。我们可以发现如果\(pos\)的右侧有一个最小值小于\(val\)那就一定是第一个小于\(val\)的左侧点(因为\(pre\)数组具有连续性且这个点是第一个小于val的点)。所以我们直接去二分\(pos\)右侧第一个小于\(val\)的点即可

int query_l(int p, int l, int r, int pos, int val) {
    if (l == r) return l;
    push_down(p);
    int mid = l + r >> 1;
    int ans = 0;
    if (mn[p << 1] < val && pos <= mid) ans =  query_l(p << 1, l, mid, pos, val);//pos在左侧区间并且最小值小于val
    if (ans) return ans;//如果已经有值直接返回就行,因为我们要找的是第一个
    if (mn[p << 1 | 1] < val) ans = query_l(p << 1 | 1, mid + 1, r, pos, val);//右侧要有<val的点才行
    return ans;
}

但是我们发现还有问题,比如这样一个序列((())),他的\(pre\)数组就全都\(\ge (pre(l - 1) = 0)\)也就是不存在小于\(pre(l - 1)\)的点,此时我们就得再去二分一次,去找到\(l\)右侧最后一个等于\(pre(l - 1)\)的点

int query_r(int p, int l, int r, int pos, int val) {
    if (l == r) return l;
    push_down(p);
    int mid = l + r >> 1;
    int ans = 0;
    if (mn[p << 1 | 1] <= val) ans = query_r(p << 1 | 1, mid + 1, r, pos, val);//由于前一次二分我们已经保证了pos右侧所有pre全都>=val所以有小于等于就是等于。还是贪心的先去右边
    if (ans) return ans;
    if (mn[p << 1] <= val && pos <= mid) ans =  query_r(p << 1, l, mid, pos, val);
    return ans;
}

然后还要在注意一下\(pos\)处是\()\)的情况。

LAZY标记!!!

最后,还有最关键的一点也是线段树区间修改最需要注意的地方也就是这种多个\(lazy\)标记的相互影响,打上\(lazy\_add\)标记不会对\(lazy\_rev\)产生影响, 但是打上\(lazy\_rev\)会对\(lazy\_add\)产生影响,由于\(lazy\_rev\)是区间取反,所以要把\(lazy\_add\)也取反。

CODE

\(O(n\log n)\)

bool lazy_rev[maxn << 2];
int mn[maxn << 2], mx[maxn << 2];
int lazy_add[maxn << 2];
string s;
int pre[maxn];
int n, q;
int tmp1, tmp2, tmp3, tmp4;

void build(int p, int l, int r) {
    if (l == r) {
        mn[p] = pre[l];
        mx[p] = pre[l];
        return ;
    }
    int mid = l + r >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
}

void push_down(int p) {
    if (lazy_rev[p]) {
        tmp1 = mn[p << 1], tmp2 = mn[p << 1 | 1];
        tmp3 = mx[p << 1], tmp4 = mx[p << 1 | 1];
        mn[p << 1] = (~tmp3) + 1;
        mx[p << 1] = (~tmp1) + 1;
        mn[p << 1 | 1] = (~tmp4) + 1;
        mx[p << 1 | 1] = (~tmp2) + 1;
        lazy_rev[p << 1] ^= 1;
        lazy_add[p << 1] = (~lazy_add[p << 1]) + 1;
        lazy_rev[p << 1 | 1] ^= 1;
        lazy_add[p << 1 | 1] = (~lazy_add[p << 1 | 1]) + 1;
        lazy_rev[p] = 0;
    }
    if (lazy_add[p]) {
        mx[p << 1] += lazy_add[p];
        mx[p << 1 | 1] += lazy_add[p];
        mn[p << 1] += lazy_add[p];
        mn[p << 1 | 1] += lazy_add[p];
        lazy_add[p << 1] += lazy_add[p];
        lazy_add[p << 1 | 1] += lazy_add[p];
        lazy_add[p] = 0;
    }
}

void update(int p, int l, int r, int ql, int qr) {
    if (ql > qr) return ;
    if (ql <= l && r <= qr) {
        tmp1 = mx[p], tmp2 = mn[p];
        mx[p] = (~tmp2) + 1;
        mn[p] = (~tmp1) + 1;
        lazy_add[p] = (~lazy_add[p]) + 1;
        lazy_rev[p] ^= 1;
        return ;
    }
    push_down(p);
    int mid = l + r >> 1;
    if (ql <= mid) update(p << 1, l, mid, ql, qr);
    if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
}

void update(int p, int l, int r, int ql, int qr, int val) {
    if (ql > qr) return ;
    if (ql <= l && r <= qr) {
        mx[p] += val;
        mn[p] += val;
        lazy_add[p] += val;
        return ;
    }
    push_down(p);
    int mid = l + r >> 1;
    if (ql <= mid) update(p << 1, l, mid, ql, qr, val);
    if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr, val);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
}

int query_po(int p, int l, int r, int pos) {
    if (pos < 1) return 0;
    if (l == r) return mn[p];
    push_down(p);
    int mid = l + r >> 1;
    if (pos <= mid) return query_po(p << 1, l, mid, pos);
    else return query_po(p << 1 | 1, mid + 1, r, pos);
}

int query_l(int p, int l, int r, int pos, int val) {
    if (l == r) return l;
    push_down(p);
    int mid = l + r >> 1;
    int ans = 0;
    if (mn[p << 1] < val && pos <= mid) ans =  query_l(p << 1, l, mid, pos, val);
    if (ans) return ans;
    if (mn[p << 1 | 1] < val) ans = query_l(p << 1 | 1, mid + 1, r, pos, val);
    return ans;
}

int query_r(int p, int l, int r, int pos, int val) {
    if (l == r) return l;
    push_down(p);
    int mid = l + r >> 1;
    int ans = 0;
    if (mn[p << 1 | 1] <= val) ans = query_r(p << 1 | 1, mid + 1, r, pos, val);
    if (ans) return ans;
    if (mn[p << 1] <= val && pos <= mid) ans =  query_r(p << 1, l, mid, pos, val);
    return ans;
}

void solve(int cas) {
    cin >> n >> q >> s;
    for (int i = 1; i <= n; ++i) {
        pre[i] = (s[i - 1] == '(' ? 1 : -1);
        pre[i] += pre[i - 1];
    }
    build(1, 1, n);
    while (q--) {
        int op, l, r; cin >> op;
        if (op == 1) {
            cin >> l >> r;
            int delta = query_po(1, 1, n, l - 1);
            update(1, 1, n, 1, l - 1);
            update(1, 1, n, l, n, (~(delta << 1)) + 1);
            delta = query_po(1, 1, n, r);
            update(1, 1, n, 1, r);
            update(1, 1, n, r + 1, n, (~(delta << 1)) + 1);
        } else {
            cin >> l;
            int now;
            now = query_po(1, 1, n, l - 1);
            int pos = query_l(1, 1, n, l, now);
            if (pos - 1 > l) cout << pos - 1 << '\n';
            else if (pos) cout << 0 << '\n';
            else cout << query_r(1, 1, n, l, now) << '\n';
        }
    }
}

\(O(n\log^2n)\)

bool lazy_rev[maxn << 2];
int mn[maxn << 2], mx[maxn << 2];
int lazy_add[maxn << 2];
string s;
int pre[maxn];
int n, q;

void build(int p, int l, int r) {
    if (l == r) {
        mn[p] = pre[l];
        mx[p] = pre[l];
        return ;
    }
    int mid = l + r >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
}

void push_down(int p) {
    if (lazy_rev[p]) {
        swap(mn[p << 1], mx[p << 1]);
        swap(mx[p << 1 | 1], mn[p << 1 | 1]);
        mn[p << 1] = -mn[p << 1];
        mx[p << 1] = -mx[p << 1];
        mn[p << 1 | 1] = -mn[p << 1 | 1];
        mx[p << 1 | 1] = -mx[p << 1 | 1];
        lazy_rev[p << 1] ^= 1;
        lazy_add[p << 1] *= -1;
        lazy_rev[p << 1 | 1] ^= 1;
        lazy_add[p << 1 | 1] *= -1;
        lazy_rev[p] = 0;
    }
    if (lazy_add[p]) {
        mx[p << 1] += lazy_add[p];
        mx[p << 1 | 1] += lazy_add[p];
        mn[p << 1] += lazy_add[p];
        mn[p << 1 | 1] += lazy_add[p];
        lazy_add[p << 1] += lazy_add[p];
        lazy_add[p << 1 | 1] += lazy_add[p];
        lazy_add[p] = 0;
    }
}

void update(int p, int l, int r, int ql, int qr) {
    if (r < l || ql > qr) return ;
    if (ql <= l && r <= qr) {
        swap(mx[p], mn[p]);
        mx[p] = -mx[p];
        mn[p] = -mn[p];
        lazy_add[p] *= -1;
        lazy_rev[p] ^= 1;
        return ;
    }
    push_down(p);
    int mid = l + r >> 1;
    if (ql <= mid) update(p << 1, l, mid, ql, qr);
    if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
}

void update(int p, int l, int r, int ql, int qr, int val) {
    if (r < l || ql > qr) return ;
    if (ql <= l && r <= qr) {
        mx[p] += val;
        mn[p] += val;
        lazy_add[p] += val;
        return ;
    }
    push_down(p);
    int mid = l + r >> 1;
    if (ql <= mid) update(p << 1, l, mid, ql, qr, val);
    if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr, val);
    mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
    mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
}

int query_po(int p, int l, int r, int pos) {
    if (pos < 1) return 0;
    if (l == r) return mn[p];
    push_down(p);
    int mid = l + r >> 1;
    if (pos <= mid) return query_po(p << 1, l, mid, pos);
    else return query_po(p << 1 | 1, mid + 1, r, pos);
}

int query(int p, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return mn[p];
    push_down(p);
    int mid = l + r >> 1;
    int ans = INF;
    if (ql <= mid) ans = min(ans, query(p << 1, l, mid, ql, qr));
    if (mid < qr) ans = min(ans, query(p << 1 | 1, mid + 1, r, ql, qr));
    return ans;
}

void solve(int cas) {
    cin >> n >> q >> s;
    for (int i = 1; i <= n; ++i) {
        pre[i] = (s[i - 1] == '(' ? 1 : -1);
        pre[i] += pre[i - 1];
    }
    build(1, 1, n);
    while (q--) {
        int op, l, r; cin >> op;
        if (op == 1) {
            cin >> l >> r;
            int delta = query_po(1, 1, n, l - 1);
            update(1, 1, n, 1, l - 1);
            update(1, 1, n, l, n, -2 * delta);
            delta = query_po(1, 1, n, r);
            update(1, 1, n, 1, r);
            update(1, 1, n, r + 1, n, -2 * delta);
        } else {
            cin >> l;
            int now;
            now = query_po(1, 1, n, l - 1);
            int L = l, R = n;
            int ans = l - 1;
            while (L <= R) {
                int mid = L + R >> 1;
                int k = query(1, 1, n, l, mid);
                if (k < now || k >= now && query(1, 1, n, mid, n) > now) R = mid - 1;
                else {
                    L = mid + 1;
                    ans = mid;
                }
            }
            if (ans < l) cout << 0 << '\n';
            else cout << ans << '\n';
        }
    }
}