P8600 连号区间数 题解

发布时间 2023-03-28 18:30:28作者: ice_dragon_grass

题目地址

题意:

在 1~N 的某个全排列中有多少个连号区间?如果一个区间中的所有数字按升序排列后是连续数列,则称其“连号”,如3,4,5

分析:

蓝桥杯 2013 省 B。原题数据很水,可\(O(n^2)\)过之。洛谷已加强时间限制,算是偏难的问题,应该被评为紫才对。

析合树的经典例题。不过以下题解不涉及析合树。

首先挖掘性质:一个区间为连号区间,等价于这个区间\([l,r]\)中,最大值mx-最小值mi=\(r-l\)

很容易想到暴力法:先枚举右端点r,再从r出发向左枚举l

一般来说,我寄希望于在枚举l时能找到某些连续规律,以一次性得到贡献值。然而,本题中,在固定r枚举l时,有贡献/无贡献的区间是间断出现的,几乎没有规律可言。

我们不难将上述的等式解耦:\(r=mx-mi+l\)。设\(v_i\)表示等号右部分,它表示的区间范围为\([i,r]\)。于是任务变为:寻找值为r的\(v_i\)的数量。我们不难针对每个r都扫一遍1~r,可以考虑到:以r为右端点的所有子区间,总是包含在以r+1为右端点的所有子区间下,这通常是我们优化一些区间统计问题的原理。

区间最大值、最小值是最难去处理的点,我们仔细分析:对于\(a_r\),它可以向左更新一段连续的\(v_i\)(分别从最大值与最小值)。比如有一段具有相同mx值的连续的\(v_i\),其原先mx值为4,我们更新为6,相当于给这段\(v_i\)加1,我们可以维护一个单调栈,去更新这样的一段段具有相同mx, mi的连续值。

然后就是查找值为r的\(v_i\)的数量,事实上并没有办法边做区间修改边做到查询指定值数量,不过还有一个性质:\(mx-mi \ge r-l\),即\(mx-mi+l \ge r\),这表明r是最小值,并且一定存在于某个\(v_i\)等于r(至少对于\([r,r]\)来说),而用线段树来维护最小值的数量并不难做到。此题的分析便到此为止了。

思路

用两个单调栈维护连续区间的最大值、连续区间的最小值。

线段树维护区间最小值的数量。

代码:

#include <bits/stdc++.h>
using namespace std;

const int maxn = 5e4 + 5;

struct Dat
{
    int val;
    int num;
    bool operator<(const Dat &other)
    {
        return val < other.val;
    }
    bool operator==(const Dat &other)
    {
        return val == other.val;
    }
};

Dat operator+(const Dat &lhs, const Dat &rhs)
{
    return {lhs.val, lhs.num + rhs.num};
}

struct Node
{
    int l;
    int r;
    int tag;
    Dat d;
} nd[maxn << 2];

void pushup(int x)
{
    if (nd[x << 1].d < nd[x << 1 | 1].d)
        nd[x].d = nd[x << 1].d;
    else if (nd[x << 1].d == nd[x << 1 | 1].d)
        nd[x].d = nd[x << 1].d + nd[x << 1 | 1].d;
    else
        nd[x].d = nd[x << 1 | 1].d;
}

void pushdown(int x)
{
    nd[x << 1].tag += nd[x].tag;
    nd[x << 1 | 1].tag += nd[x].tag;
    nd[x << 1].d.val += nd[x].tag;
    nd[x << 1 | 1].d.val += nd[x].tag;
    nd[x].tag = 0;
}

void make(int x, int l, int r)
{
    nd[x].l = l;
    nd[x].r = r;
    if (l == r)
    {
        nd[x].d = {l, 1};
        return;
    }
    int mid = (l + r) / 2;
    make(x << 1, l, mid);
    make(x << 1 | 1, mid + 1, r);
    pushup(x);
}

void updata(int x, int l, int r, int v)
{
    if (l <= nd[x].l && nd[x].r <= r)
    {
        nd[x].d.val += v;
        nd[x].tag += v;
        return;
    }
    pushdown(x);
    int mid = (nd[x].l + nd[x].r) / 2;
    if (l <= mid)
        updata(x << 1, l, r, v);
    if (r > mid)
        updata(x << 1 | 1, l, r, v);
    pushup(x);
}

Dat query(int x, int l, int r)
{
    if (l <= nd[x].l && nd[x].r <= r)
    {
        return nd[x].d;
    }
    pushdown(x);
    Dat a = {100000, 0}, b = {100000, 0}, res;
    int mid = (nd[x].l + nd[x].r) / 2;
    if (l <= mid)
        a = query(x << 1, l, r);
    if (r > mid)
        b = query(x << 1 | 1, l, r);
    if (a < b)
        res = a;
    else if (a == b)
        res = a + b;
    else
        res = b;
    pushup(x);
    return res;
}

int st1[maxn], st2[maxn];
int p1, p2;
int ar[maxn];

int main()
{
    cin.tie(0)->sync_with_stdio(false);
    int n;
    cin >> n;
    long long ans = 0;
    make(1, 1, n);
    for (int i = 1; i <= n; ++i)
    {
        cin >> ar[i];
        while (p1 && ar[st1[p1]] <= ar[i])
        {
            updata(1, st1[p1 - 1] + 1, st1[p1], ar[i] - ar[st1[p1]]);
            --p1;
        }
        st1[++p1] = i;
        while (p2 && ar[st2[p2]] >= ar[i])
        {
            updata(1, st2[p2 - 1] + 1, st2[p2], ar[st2[p2]] - ar[i]);
            --p2;
        }
        st2[++p2] = i;
        Dat r = query(1, 1, i);
        ans += r.num;
    }
    cout << ans;
}