珂朵莉树——优雅的暴力

发布时间 2023-08-23 11:44:10作者: Aisaka_Taiga

珂朵莉树

引入

珂朵莉树 (Chtholly Tree),又名老司机树 (Old Driver Tree)。起源于CF896C

这种想法的本质是基于数据随机的「颜色段均摊」,而不是一种数据结构。

前置

需要了解 STL 的 set 的基本用法。比如:

  • insert(x) 当容器中没有等价元素的时候,将元素 x 插入到 set 中。
  • erase(x) 删除值为 x 的 所有 元素,返回删除元素的个数。
  • erase(pos) 删除迭代器为 pos 的元素,要求迭代器必须合法。
  • erase(first,last) 删除迭代器在 \([first,last]\) 范围内的所有元素。
  • clear() 清空 set

解决什么问题

通常是维护序列的操作,使一段区间内的东西变得一样,注意一定是数据随机

一般的情况下只要是有区间赋值操作,都可以骗分,但是数据不随机的话复杂度是不正确的。

对于 add,assign,sum 操作,用 set 实现的珂朵莉树的复杂度为 \(O(n \log \log n)\),用链表实现的复杂度为 \(O(n\log n)\)

操作

下面以模板题来介绍。

Willem, Chtholly and Seniorious - 洛谷

给定 \(n\) 个数, \(m\) 次操作 \((n,m\le 10^5)\)

支持:

  • 区间加

  • 区间赋值

  • 区间第 \(k\)

  • 区间幂次和

珂朵莉树的核心思想就是用一个结构体来存放一个值相同的子段信息,以此在区间的操作上降低复杂度,因为有区间赋值的操作,所以让他的复杂度降低了不少。

我们首先来看一下珂朵莉树如何存储信息。

结点存储

struct node
{
    int l, r;
    mutable int v;
    node(int L, int R = -1, int V = 0): l(L), r(R), v(V) {}//新建一个变量,如果只有L,R默认为-1,V默认为0
    bool operator < (const node &o) const{return l < o.l;}//重载运算符,按左端点从小到大排序
};
set<node> s;

可以看到里面有一个 mutable,他的意思是“可修改的”,因为在 set 里面的元素是不能直接修改的,我们在加上 mutable 后就可以不取出插入完成修改,而是可以直接修改里面的 \(v\) 值。

\(l,r\) 就是表示当前点存放的是区间 \([l,r]\) 的信息。

分裂区间

珂朵莉树的各种操作都是基于这个分裂区间的操作,因为我们有可能修改的区间左右端点是在一个结点当中包含的,比如要修改 \([3,4]\),但是我们只有 \([1,6]\),这个时候我们就需要把区间分成 \([1,2][3,4][5,6]\) 了。

假设我们需要分出一个以 \(x\) 为左端点的区间。

我们在实现的过程中,可以用 setlower_bound() 函数来找到第一个左端点不大于 \(x\) 的区间结点,然后特判一下:如果当前结点存的区间左端点正好是 \(x\),那么就不用分了,直接返回就好了;否则,我们找到前面一个区间,因为 \(x\) 肯定是包含在前一个区间,然后我们把他的信息 \((l,r,v)\) 都拿出来,防止等等删除后会出现错误,然后删除当前区间,再把 \((l,x- 1,v)\) 给插入,然后插入 \((x,r,v)\),这个时候我们就可以返回后面区间的迭代器了,由于 insert 后返回的是一个 pair 类型,first 是当前元素的迭代器,而 second 是一个 bool 类型的变量,表示此次操作是否成功,所以我们这样写:

#define IT set<node>::iterator
IT split(int p)
{
    IT it = s.lower_bound(node(p));//找到第一个左端点不小于p的元素
    if(it != s.end() && it -> l == p) return it;//如果不是最后一个元素,并且左端点就是p就不用分裂直接返回
    it --;//前移一个单位
    int L = it -> l, R = it -> r, V = it -> v;//取出所有的元素
    s.erase(it);//将当前区间删除掉
    s.insert(node(L, p - 1, V));//插入分裂的左半区间
    return s.insert(node(p, R, V)).first;//插入分裂的右半区间
}

注意:使用 split 函数的时候一定要先分裂 \(r + 1\),再分裂 \(l\),不然有可能会 RE。

区间加

我们上面说过 set 内的元素存放的都是一个区间,每个区间内的值都是一样的,所以我们在进行区间加的时候,我们先把左右端点的区间分裂出来,然后在直接对于我们得到的两个迭代器进行遍历,把中间所有的元素都直接加上 \(val\) 就好了。

inline void add(int l, int r, int val)
{
    IT itr = split(r + 1), itl = split(l);//分裂区间
    for(; itl != itr; itl ++) itl -> v += val;//暴力修改
    return ;
}

区间覆盖

也就是区间赋值。

我们上面进行了多次修改后可能所有的元素 \(l=r\) 了,也就是基本都是一个元素只代表一个点了,这样复杂度就成了 \(O(nm)\) 的了,但是由于有区间赋值的操作,我们可以把当前修改的区间直接变成一个元素,然后插进 set 里,这样在随机的数据下珂朵莉树的复杂度就变优了。

inline void assign_val(int l, int r, int val)
{
    IT itr = split(r + 1), itl = split(l);//分裂区间
    s.erase(itl, itr);//暴力删除
    s.insert(node(l, r, val));//插入新区间
}

区间第 K 大的值

我们可以开一个 vector 然后对其进行排序,然后遍历每一个元素,来找到第 \(k\) 大的元素。

inline int ranks(int l, int r, int k)
{
    vector<pair<int, int> > vp;
    IT itr = split(r + 1), itl = split(l);//分裂区间
    for(; itl != itr; itl ++) vp.push_back(pair<int, int>(itl -> v, itl -> r - itl -> l + 1));//放入vector,值,区间内数的个数
    sort(vp.begin(), vp.end());//从小到大排序
    for(vector<pair<int, int> >::iterator it = vp.begin(); it != vp.end(); it ++)//遍历
    {
        k -= it -> second;//减去数的个数
        if(k <= 0) return it -> first;//返回值
    }
}

区间求次幂和

我们可以直接遍历,然后计算区间内的数的个数,然后直接求和。

inline int sum(int l, int r, int ex, int mod)
{
    IT itr = split(r + 1), itl = split(l);
    int res = 0;
    for(; itl != itr; itl ++) res = (res + (int)(itl -> r - itl -> l + 1) * ksm(itl -> v, ex, mod)) % mod;
    return res;
}

模板题完整代码

#include <bits/stdc++.h>

#define IT set<node>::iterator
#define int long long
#define P 1000000007
#define N 1000100

using namespace std;

int n, m, seed, vmax, a[N];
struct node
{
    int l, r;
    mutable int v;
    node(int L, int R = -1, int V = 0): l(L), r(R), v(V) {}//新建一个变量,如果只有L,R默认为-1,V默认为0
    bool operator < (const node &o) const{return l < o.l;}//重载运算符,按左端点从小到大排序
};
set<node> s;

inline int ksm(int a, int b, int mod)//取模快速幂
{
    int res = 1;
    int ans = a % mod;
    while(b)
    {
        if(b & 1) res = res * ans % mod;
        ans = ans * ans % mod;
        b >>= 1;
    }
    return res;
}

IT split(int p)
{
    IT it = s.lower_bound(node(p));//找到第一个左端点不小于p的元素
    if(it != s.end() && it -> l == p) return it;//如果不是最后一个元素,并且左端点就是p就不用分裂直接返回
    it --;//前移一个单位
    int L = it -> l, R = it -> r, V = it -> v;//取出所有的元素
    s.erase(it);//将当前区间删除掉
    s.insert(node(L, p - 1, V));//插入分裂的左半区间
    return s.insert(node(p, R, V)).first;//插入分裂的右半区间
}

inline void add(int l, int r, int val)
{
    IT itr = split(r + 1), itl = split(l);//分裂区间
    for(; itl != itr; itl ++) itl -> v += val;//暴力修改
    return ;
}

inline void assign_val(int l, int r, int val)
{
    IT itr = split(r + 1), itl = split(l);//分裂区间
    s.erase(itl, itr);//暴力删除
    s.insert(node(l, r, val));//插入新区间
}

inline int ranks(int l, int r, int k)
{
    vector<pair<int, int> > vp;
    IT itr = split(r + 1), itl = split(l);//分裂区间
    for(; itl != itr; itl ++) vp.push_back(pair<int, int>(itl -> v, itl -> r - itl -> l + 1));//放入vector,值,区间内数的个数
    sort(vp.begin(), vp.end());//从小到大排序
    for(vector<pair<int, int> >::iterator it = vp.begin(); it != vp.end(); it ++)//遍历
    {
        k -= it -> second;//减去数的个数
        if(k <= 0) return it -> first;//返回值
    }
}

inline int sum(int l, int r, int ex, int mod)
{
    IT itr = split(r + 1), itl = split(l);
    int res = 0;
    for(; itl != itr; itl ++) res = (res + (int)(itl -> r - itl -> l + 1) * ksm(itl -> v, ex, mod)) % mod;
    return res;
}

inline int rnd()
{
    int res = seed;
    seed = (seed * 7 + 13) % P;
    return res;
}

signed main()
{
    cin >> n >> m >> seed >> vmax;
    for(int i = 1; i <= n; i ++)
    {
        a[i] = (rnd() % vmax) + 1;
        s.insert(node(i, i, a[i]));
    }
    s.insert(node(n + 1, n + 1, 0));
    for(int i = 1; i <= m; i ++)
    {
        int op = (rnd() % 4) + 1;
        int l = (rnd() % n) + 1;
        int r = (rnd() % n) + 1;
        if(l > r) swap(l, r);
        int x, y;
        if(op == 3) x = (rnd() % (r - l + 1)) + 1;
        else x = (rnd() % vmax) + 1;
        if(op == 4) y = (rnd() % vmax) + 1;

        if(op == 1) add(l, r, x);
        else if(op == 2) assign_val(l, r, x);
        else if(op == 3) cout << ranks(l, r, x) << endl;
        else cout << sum(l, r, x, y) << endl;
    }
    return 0;
}