【主席树】洛谷 P3834 可持久化线段树 2

发布时间 2023-08-25 19:00:45作者: blockche

【主席树】洛谷 P3834 可持久化线段树2

题目链接:https://www.luogu.com.cn/problem/P3834

主席树是可持久化线段树的一种,也叫做可持久化权值线段树,主要可以用来O(logn)求静态区间的第k小数。

总所周知,普通线段树每次修改会遍历logn个点,那么我们在每次修改时都把这logn个点复制一份出来再修改,生成一个历史版本,就是可持久化线段树了,这里每一个点都是动态开点,而不是提前开好的,所以每个点内要存他的左右儿子节点的编号,不再是传统线段树的 2*p 和 2*p+1。

主席树有两种写法,一个是提前分配好所有空间,一个是使用指针。

需要注意的是,指针的写法时空复杂度一般要比普通写法大一倍,在这道洛谷板题的具体表现就是:普通写法(634ms,40.79MB),指针写法(1.18s,101.34MB)。

所以比赛时还是用普通写法稳妥一点,空间开到 N<<6 就保证够了。

代码(提前分配空间)

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

constexpr int N = 2e5;
struct node {
    int l, r;
    int sum;
} tr[N << 6];
int cnt;
int add(int p, int l, int r, int x) {
    int u = ++cnt;
    tr[u] = tr[p];
    tr[u].sum++;
    if (l == r) return u;
    int m = (l + r) / 2;
    if (x <= m) {
        tr[u].l = add(tr[u].l, l, m, x);
    } else {
        tr[u].r = add(tr[u].r, m + 1, r, x);
    }
    return u;
}
int query(int p, int q, int l, int r, int k) {
    if (l == r) return l;
    int m = (l + r) / 2;
    int x = tr[tr[q].l].sum - tr[tr[p].l].sum;
    if (x >= k) {
        return query(tr[p].l, tr[q].l, l, m, k);
    } else {
        return query(tr[p].r, tr[q].r, m + 1, r, k - x);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    auto b = a;
    sort(b.begin(), b.end());
    b.erase(unique(b.begin(), b.end()), b.end());
    int tot = b.size();
    auto getid = [&](int x) {
        return lower_bound(b.begin(), b.end(), x) - b.begin();
    };
    vector<int> rt(n + 1);
    for (int i = 0; i < n; i++) {
        rt[i + 1] = add(rt[i], 0, tot - 1, getid(a[i]));
    }

    for (int i = 0; i < m; i++) {
        int l, r, k;
        cin >> l >> r >> k;
        l--, r--;

        int id = query(rt[l], rt[r + 1], 0, tot - 1, k);
        cout << b[id] << '\n';
    }

    return 0;
}

代码(指针)

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

struct node {
    node *l;
    node *r;
    int sum;
    node() : l{}, r{}, sum{} {}
};
node *add(node *p, int l, int r, int x) {
    node *n = new node();
    if (p) *n = *p;
    n->sum++;
    if (l == r) return n;
    int m = (l + r) / 2;
    if (x <= m) {
        n->l = add(n->l, l, m, x);
    } else {
        n->r = add(n->r, m + 1, r, x);
    }
    return n;
}
int query(node *p, node *q, int l, int r, int k) {
    if (l == r) return l;
    int nq = (q && q->l ? q->l->sum : 0);
    int np = (p && p->l ? p->l->sum : 0);
    int num = nq - np;
    int m = (l + r) / 2;
    if (num >= k) {
        return query(p ? p->l : nullptr, q ? q->l : nullptr, l, m, k);
    } else {
        return query(p ? p->r : nullptr, q ? q->r : nullptr, m + 1, r, k - num);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    auto b = a;
    sort(b.begin(), b.end());
    b.erase(unique(b.begin(), b.end()), b.end());
    int tot = b.size();
    auto getid = [&](int x) {
        return lower_bound(b.begin(), b.end(), x) - b.begin();
    };
    vector<node *> rt(n + 1);
    for (int i = 0; i < n; i++) {
        rt[i + 1] = add(rt[i], 0, tot - 1, getid(a[i]));
    }

    for (int i = 0; i < m; i++) {
        int l, r, k;
        cin >> l >> r >> k;
        l--, r--;

        int id = query(rt[l], rt[r + 1], 0, tot - 1, k);
        cout << b[id] << '\n';
    }

    return 0;
}