C. Medium Design

发布时间 2023-10-23 17:33:05作者: onlyblues

C. Medium Design

The array $a_1, a_2, \ldots, a_m$ is initially filled with zeroes. You are given $n$ pairwise distinct segments $1 \le l_i \le r_i \le m$. You have to select an arbitrary subset of these segments (in particular, you may select an empty set). Next, you do the following:

For each $i = 1, 2, \ldots, n$, if the segment $(l_i, r_i)$ has been selected to the subset, then for each index $l_i \le j \le r_i$ you increase $a_j$ by $1$ (i. e. $a_j$ is replaced by $a_j + 1$). If the segment $(l_i, r_i)$ has not been selected, the array does not change.
Next (after processing all values of $i = 1, 2, \ldots, n$), you compute $\max(a)$ as the maximum value among all elements of $a$. Analogously, compute $\min(a)$ as the minimum value.
Finally, the cost of the selected subset of segments is declared as $\max(a) - \min(a)$.
Please, find the maximum cost among all subsets of segments.

Input

Each test contains multiple test cases. The first line contains the number of test cases $t$ ($1 \le t \le 10^4$). The description of the test cases follows.

The first line of each test case contains two integers $n$ and $m$ ($1 \le n \le 10^5$, $1 \le m \le 10^9$) — the number of segments and the length of the array.

The following $n$ lines of each test case describe the segments. The $i$-th of these lines contains two integers $l_i$ and $r_i$ ($1 \le l_i \le r_i \le m$). It is guaranteed that the segments are pairwise distinct.

It is guaranteed that the sum of $n$ over all test cases does not exceed $2 \cdot 10^5$.

Output

For each test case, output the maximum cost among all subsets of the given set of segments.

Example

input

6
1 3
2 2
3 8
2 4
3 5
4 6
6 3
1 1
1 2
1 3
2 2
2 3
3 3
7 6
2 2
1 6
1 2
5 6
1 5
4 4
3 6
6 27
6 26
5 17
2 3
20 21
1 22
12 24
4 1000000000
2 999999999
3 1000000000
123456789 987654321
9274 123456789

output

1
3
2
3
4
4

Note

In the first test case, there is only one segment available. If we do not select it, then the array will be $a = [0, 0, 0]$, and the cost of such (empty) subset of segments will be $0$. If, however, we select the only segment, the array will be $a = [0, 1, 0]$, and the cost will be $1 - 0 = 1$.

In the second test case, we can select all the segments: the array will be $a = [0, 1, 2, 3, 2, 1, 0, 0]$ in this case. The cost will be $3 - 0 = 3$.

 

解题思路

  nmd 昨天这题卡了一个半小时都没写出来,后面用了个巨复杂的做法才过,代码巨长调了半天。

  先给出官方的思路。假设 $x$ 是最优解方案中最大元素的下标,因为 $x$ 是最大元素的下标,因此我们应该选择所有覆盖 $x$ 的区间 $[l,r]$($l \leq x \leq r$)。这是因为如果区间 $[l,r]$ 没覆盖到最小元素的下标,那么答案会 $+1$;如果覆盖到最小元素的下标,那么答案不会变小。同时在这种选择下还会有另外一个结论,即最小元素的下标不是 $1$ 就是 $m$。

  感觉很不显然,这里给出两种解释。首先假设有下标 $i$ 和 $j$ 满足 $i < j < x$,把所有选择的区间分成三类:$1.$ $l \leq i$、$2.$ $i < l \leq j$、$3.$ $l > j$。对于第 $1$ 类的区间,下标 $i$ 和 $j$ 都会被覆盖,而对于第 $2$ 类的区间只有下标 $j$ 会被覆盖,而第 $3$ 类的区间 $i$ 和 $j$ 都无法被覆盖,因此下标 $i$ 处被覆盖的次数一定不超过 $j$ 被覆盖的次数。即下标越小,元素的值越小。因此对于小于 $x$ 的下标,下标 $1$ 的元素值一定最小。同理可以推出对于大于 $x$ 的下标,下标 $m$ 的元素值一定最小。

  再给出另外一种解释。假设在最优方案中最小元素的下标 $y < x$,且 $y \ne 1$。下面证明可以将 $y$ 调整到 $1$ 而不影响答案。由于下标 $y$ 的元素值小于下标 $1$ 的元素值,因此必然还选择了左端点是 $1$ 且右端点严格小于 $y$ 的区间。我们把这些区间删去,很明显答案不会受到影响,因为这些区间都没有覆盖到 $x$ 和 $y$。同时下标 $1$ 的元素值不会超过下标 $y$ 的元素值(因为此时下标 $1$ 只能被左端点为 $1$,右端点至少为 $y$ 的下标覆盖,除了这些区间外 $y$ 还能被其他的区间覆盖),因此最小值的下标就变成 $1$ 了。

  所以为了得到最优解,只需选择所有不含 $1$ 或 $m$ 的区间,然后找到被覆盖次数最多的位置对应的元素,取两种情况的最大值即为答案。现在已经知道 $1$ 或 $m$ 是最小元素的下标了,如果选择包含 $1$ 或 $m$ 的区间,如果区间覆盖 $x$,那么对答案没影响,而如果不覆盖 $x$,那么答案会 $-1$,因此不选这些区间肯定不会让答案变糟糕。实现的话我这里是用离散化加前缀和。

  AC 代码如下,时间复杂度为 $O(n \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2e5 + 10;

int n, m;
int l[N], r[N];
int xs[N * 2], sz;
int s[N];

int find(int x) {
    int l = 1, r = sz;
    while (l < r) {
        int mid = l + r >> 1;
        if (xs[mid] >= x) r = mid;
        else l = mid + 1;
    }
    return l;
}

int get(int x) {
    memset(s, 0, sz + 10 << 2);
    for (int i = 1; i <= n; i++) {
        if (l[i] != x && r[i] != x) s[find(l[i])]++, s[find(r[i] + 1)]--;
    }
    int mx = 0, mn = n;
    for (int i = 1; i <= sz; i++) {
        s[i] += s[i - 1];
        mx = max(mx, s[i]);
        mn = min(mn, s[i]);
    }
    return mx - mn;
}

void solve() {
    scanf("%d %d", &n, &m);
    sz = 0;
    for (int i = 1; i <= n; i++) {
        scanf("%d %d", l + i, r + i);
        xs[++sz] = l[i], xs[++sz] = r[i] + 1;
    }
    sort(xs + 1, xs + sz + 1);
    sz = unique(xs + 1, xs + sz + 1) - xs - 1;
    printf("%d\n", max(get(1), get(m)));
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

  还是记录一下一开始想到的复杂思路吧。

  如果把 $n$ 个区间都选择而无法将 $1 \sim m$ 的每个下标覆盖,那么很明显最小值就是 $0$,因此答案就是某个下标被覆盖最多次数。

  否则 $1 \sim m$ 都能被覆盖,把所有选择方案按最小元素的下标进行分类,那么最小元素的下标就可以是 $1, 2, \ldots, m$,一共 $m$ 种情况。然后与上面所述一样,选择所有不含这个下标的区间,求最大值。而这 $m$ 个下标里面只需考虑所有区间的两个端点即可,因此最多只需考虑 $2n$ 个下标,这是因为区间内的点的选择方案可以都用某个端点的选择方案来表示。

  如果分别枚举这 $2n$ 个点然后再选择区间求前缀和,那么时间复杂度是 $O(n^2)$。这里给出我的做法。

  用线段树动态维护区间元素的最大值,修改的操作只有对某个区间的每个数都加上一个数。一开始先把所有区间选上,用线段树来维护。然后对 $2n$ 个端点从小到大排序,依次枚举。假设当前枚举到端点 $x$,把所有左端点是 $x$ 的区间都删掉,把所有之前删掉的右端小于 $x$ 的区间重新加回来(用最小堆来维护),这样就选择了所有不含下标 $x$ 的区间了,此时最大值就是 tr[1].mx

  AC 代码如下,时间复杂度为 $O(n \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef pair<int, int> PII;

const int N = 2e5 + 10, M = N * 2;

int n, m;
vector<int> p[N], q[N];
int xs[M], sz;
struct Node {
    int l, r, mx, add;
}tr[N * 4];

bool check() {    // 检查n个区间能否把1~m都覆盖 
    sort(p + 1, p + n + 1);
    int r = 0;
    for (int i = 1; i <= n; i++) {
        if (p[i][0] > r + 1) return false;
        r = max(r, p[i][1]);
    }
    return r >= m;
}

int find(int x) {
    int l = 1, r = sz;
    while (l < r) {
        int mid = l + r >> 1;
        if (xs[mid] >= x) r = mid;
        else l = mid + 1;
    }
    return l;
}

void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) {
        tr[u].mx = tr[u].add = 0;
    }
    else {
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx);
    }
}

void pushdown(int u) {
    if (tr[u].add) {
        tr[u << 1].mx += tr[u].add;
        tr[u << 1].add += tr[u].add;
        tr[u << 1 | 1].mx += tr[u].add;
        tr[u << 1 | 1].add += tr[u].add;
        tr[u].add = 0;
    }
}

void modify(int u, int l, int r, int c) {
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].mx += c;
        tr[u].add += c;
    }
    else {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r, c);
        if (r >= mid + 1) modify(u << 1 | 1, l, r, c);
        tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx);
    }
}

void solve() {
    scanf("%d %d", &n, &m);
    sz = 0;
    for (int i = 1; i <= n; i++) {
        p[i].resize(2);
        scanf("%d %d", &p[i][0], &p[i][1]);
        xs[++sz] = p[i][0];
        xs[++sz] = p[i][1];
    }
    sort(xs + 1, xs + sz + 1);
    sz = unique(xs + 1, xs + sz + 1) - xs - 1;
    build(1, 1, sz);
    for (int i = 1; i <= n; i++) {    // 一开始先把所有区间选上 
        modify(1, find(p[i][0]), find(p[i][1]), 1);
    }
    if (!check()) {    //检查能否把1~m都覆盖 
        printf("%d\n", tr[1].mx);    // 不能,则答案就是最大值减0 
        return;
    }
    for (int i = 1, j = 1; i <= n; i++) {    // 2n个区间端点排序,vector的第一个元素记录区间编号,第二个元素记录是左端点还是右端点 
        q[j++] = {i, 0};
        q[j++] = {i, 1};
    }
    sort(q + 1, q + 2 * n + 1, [&](vector<int> &a, vector<int> &b) {
        return p[a[0]][a[1]] < p[b[0]][b[1]];
    });
    int ret = 0;
    priority_queue<PII, vector<PII>, greater<PII>> pq;
    for (int i = 1; i <= n << 1; i++) {
        int j = i;
        while (j <= n << 1 && p[q[j][0]][q[j][1]] == p[q[i][0]][q[i][1]]) {    // 考虑所有相同的端点 
            if (!q[j][1]) {    // 如果是左端点,则删除这个区间 
                int l = find(p[q[j][0]][0]), r = find(p[q[j][0]][1]);
                modify(1, l, r, -1);
                pq.push({r, l});    // 把删除的区间压入堆中,堆是右端点为关键字的最小堆 
            }
            j++;
        }
        ret = max(ret, tr[1].mx);    // 此时选择了所有不含该端点的区间,取最大值 
        int x = find(p[q[i][0]][q[i][1]]);    // 该端点的离散值 
        while (!pq.empty() && pq.top().first <= x) {    // 把右端点不超过x的区间重新加上 
            modify(1, pq.top().second, pq.top().first, 1);
            pq.pop();
        }
        i = j - 1;
    }
    printf("%d\n", ret);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

 

参考资料

  Codeforces Round #904 (Div. 2) Editorial:https://codeforces.com/blog/entry/121618