主席树

发布时间 2023-10-13 21:50:35作者: lza0v0

权值线段树

思路:

现将数值离散化

每个节点存的是值在 \(l\) ~ \(r\) 之间的数的个数,用线段树维护

作用:

\(k\) 小值 或 \(k\) 大值

查某一数值的排名

查询数组排序

查前驱、后继

求逆序对

相比平衡树:码量小、简单

P1801 黑匣子

离散化:

sort(alls.begin(),alls.end());
alls.erase(unique(alls.begin(),alls.end()),alls.end());

单点修改:

void update(int l,int r,int rt,int pos)
{
	if(l==r){
		t[rt]++;
		return;
	}
	int mid=l+r>>1;
	if(pos<=mid)update(l,mid,ls,pos);
	else update(mid+1,r,rs,pos);
	t[rt]=t[ls]+t[rs];
}

查询第k小的数:

int kth(int l,int r,int rt,int k)
{
	if(l==r)return alls[l];
	int mid=l+r>>1,s1=t[ls];
	if(k<=s1)return kth(l,mid,ls,k);
	else return kth(mid+1,r,rs,k-s1);//k-s1
}

查询x的排名(比x小的有多少个)

int query(int cur,int x)
{
	if(!cur)return 0;
	if(alls[t[cur].l]>x)return 0;
	if(alls[t[cur].r]>x)return query(t[cur].ls,x)+query(t[cur].rs,x);
	return t[cur].cnt;
}

P1774 最接近神的人

求一组数逆序对的数量

用权值线段树做:

读入一个数,查找在它前面比它大的数的个数

相当于查找 \(find(x)+1\) ~ \(alls.size()-1\) 的值的个数有多少个

查询值在 \(l\) ~ \(r\) 之间的数的个数:

int query(int l,int r,int rt,int x,int y)
{
	if(x>y)return 0;
	if(l==x&&r==y){
		return t[rt];
	}
	int mid=l+r>>1;
	if(y<=mid)return query(l,mid,rt<<1,x,y);
	else if(x>mid)return query(mid+1,r,rt<<1|1,x,y);
	else return query(l,mid,rt<<1,x,mid)+query(mid+1,r,rt<<1|1,mid+1,y);
}

主席树

主席树是可持久化权值线段树

给定 \(n\) 个数字, \(m\) 个询问,每次求 \([L,R]\) 内的第 \(k\) 大值

\(n,m\le 100000\)

思路:

权值线段树只能求所有树的第 \(k\) 大值

用前缀和思想,建立 \(n\) 棵线段树,第 \(i\) 棵树维护前\(i\)个数。

查询的时候用第 \(R\) 棵树和第 \(L\) 棵树相减即可。

然而 \(n\) 棵线段树,每棵要开 \(4n\) 的空间,会 \(MLE\)

于是我们就可以用可持久化来优化。

可持久化可以在不超空间的情况下维护每次操作的历史版本

思路:

对于每新插入的一个值,只会改变其到根节点的那一段( 最多\(\log n\)个),于是每新插入一个值只需要新建那一段节点,其他直接接到上一个版本即可。

可持久化

代码:(待注释)

#include<bits/stdc++.h>
using namespace std;
const int N=200010;
int n,m,a[N];
int T[N],cnt,pre;//T[i]:第i棵树的树根
vector<int>alls;
struct SH{
	int l,r,cnt;
	int ls,rs;
}t[N<<5];
int find(int x)
{
	int l=0,r=alls.size()-1;
	while(l<r){
		int mid=l+r+1>>1;
		if(x>=alls[mid])l=mid;
		else r=mid-1;
	}
	return l;
}
void build(int l,int r,int cur)//建树
{
	t[cur].l=l,t[cur].r=r;
	if(l==r)return;
	int mid=l+r>>1;
	t[cur].ls=++cnt;
	build(l,mid,cnt);
	t[cur].rs=++cnt;
	build(mid+1,r,cnt);
}
void add(int l,int r,int cur,int pre,int pos)
{
	t[cur].l=l,t[cur].r=r;
	if(l==r){
		t[cur].cnt+=t[pre].cnt+1;
		return;
	}
	int mid=l+r>>1;
	if(pos<=mid){
		t[cur].ls=++cnt;//只添加变化的节点
		t[cur].rs=t[pre].rs;
		add(l,mid,cnt,t[pre].ls,pos);
	}
	else{
		t[cur].rs=++cnt;
		t[cur].ls=t[pre].ls;
		add(mid+1,r,cnt,t[pre].rs,pos);
	}
	t[cur].cnt=t[t[cur].ls].cnt+t[t[cur].rs].cnt;
}
int kth(int cur,int pre,int k)
{
	if(t[cur].l==t[cur].r){
		return alls[t[cur].l];
	}
	int s1=t[cur].ls,s2=t[pre].ls;
	if(t[s1].cnt-t[s2].cnt>=k)return kth(s1,s2,k);//t[s1].cnt-t[s2].cnt
	else return kth(t[cur].rs,t[pre].rs,k-t[s1].cnt+t[s2].cnt);
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
		alls.push_back(a[i]);
	}
	sort(alls.begin(),alls.end());
	alls.erase(unique(alls.begin(),alls.end()),alls.end());
	build(0,alls.size()-1,1);
	T[0]=1;
	for(int i=1;i<=n;i++){
		T[i]=++cnt;
		add(0,alls.size()-1,T[i],T[i-1],find(a[i]));
	}
	while(m--){
		int l,r,k;
		scanf("%d%d%d",&l,&r,&k);
		printf("%d\n",kth(T[r],T[l-1],k));
	}
	return 0;
}

动态区间kth

考虑树套树,修改时维护 \(T[i]\)\(T[i+lowbit(i)]\)

#include<bits/stdc++.h>
using namespace std;
const int N=100010;
int n,m,idx;
int a[N],op[N],l1[N],r1[N],k1[N];
int rt[N];
int stk[N],tt;
int stk2[N],tt2;
vector<int>alls;
struct tree{
	int l,r,cnt;
}tr[N*800];
int lowbit(int x)
{
	return x&(-x);
}
int find(int x)
{
	int l=1,r=alls.size();
	while(l<r){
		int mid=l+r+1>>1;
		if(alls[mid-1]<=x)l=mid;
		else r=mid-1;
	}
	return l;
}
void pushup(int u)
{
	tr[u].cnt=tr[tr[u].l].cnt+tr[tr[u].r].cnt;
}
void modify(int &u,int l,int r,int pos,int val)
{
	if(!u)u=++idx;
	if(l==r){
		tr[u].cnt+=val;
		return;
	}
	int mid=l+r>>1;
	if(pos<=mid)modify(tr[u].l,l,mid,pos,val);
	else modify(tr[u].r,mid+1,r,pos,val);
	pushup(u);
}
void qev(int x,int y)
{
	tt=tt2=0;
	while(x){
		stk[++tt]=rt[x];
		x-=lowbit(x);
	}
	while(y){
		stk2[++tt2]=rt[y];
		y-=lowbit(y);
	}
}
void rap(int d)
{
	if(!d){
		for(int i=1;i<=tt;i++)stk[i]=tr[stk[i]].l;
		for(int i=1;i<=tt2;i++)stk2[i]=tr[stk2[i]].l;
	}
	else{
		for(int i=1;i<=tt;i++)stk[i]=tr[stk[i]].r;
		for(int i=1;i<=tt2;i++)stk2[i]=tr[stk2[i]].r;
	}
}
int query(int l,int r,int k)
{
	if(l==r){
		return alls[l-1];
	}
	int r1=0,r2=0;
	for(int i=1;i<=tt;i++){
		r1+=tr[tr[stk[i]].l].cnt;
	} 
	for(int i=1;i<=tt2;i++){
		r2+=tr[tr[stk2[i]].l].cnt;
	}
	int mid=l+r>>1;
	if(r2-r1>=k){
		rap(0);
		return query(l,mid,k);
	}
	else{
		rap(1);
		return query(mid+1,r,k-r2+r1);
	}
}
void output(int u,int l,int r)
{
	if(!u)return;
	if(l==r){
		for(int i=1;i<=tr[u].cnt;i++)cout<<alls[l-1]<<" ";
	}
	int mid=l+r>>1;
	output(tr[u].l,l,mid);
	output(tr[u].r,mid+1,r);
}
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>a[i];
		alls.push_back(a[i]);
	}
	for(int i=1;i<=m;i++){
		string opt;
		int l,r,k;
		cin>>opt;
		if(opt=="Q"){
			op[i]=0;
			cin>>l>>r>>k;
			l1[i]=l;r1[i]=r;k1[i]=k;
		}
		else{
			op[i]=1;
			cin>>l>>r;
			l1[i]=l,r1[i]=r;
			alls.push_back(r);
		}
	}
	sort(alls.begin(),alls.end());
	alls.erase(unique(alls.begin(),alls.end()),alls.end());
	for(int i=1;i<=n;i++){
		int p=find(a[i]),j=i;
		while(j<N){
			modify(rt[j],1,alls.size(),p,1);
			j+=lowbit(j);
		}
	}

	for(int i=1;i<=m;i++){
		if(!op[i]){
			qev(l1[i]-1,r1[i]);
			cout<<query(1,alls.size(),k1[i])<<endl;
		}
		else{
			int x=find(a[l1[i]]),y=find(r1[i]),z=l1[i];
			a[l1[i]]=r1[i];
			while(z<N){
				modify(rt[z],1,alls.size(),x,-1);
				modify(rt[z],1,alls.size(),y,1);
				z+=lowbit(z); 
			}
		}
	}
	return 0;	
}