权值线段树、树状数组应用

发布时间 2023-04-02 21:22:50作者: asdfo123

权值线段树类应用

最近因为练蓝桥杯,总算搞明白这些东西了(高中三年没搞明白233),放在一起总结一下

树状数组逆序对

经过处理,离散化 和上面的正好是反过来 上面是位置不变 按照大小排序,第一个最大的在位置\(4\) ,而下面的排序更好理解,第\(i\)个数代表原来第 \(i\) 个数在原数组中大小排序后的位置。

在这里插入图片描述

因为要求逆序对,所以将大的放在下面树状数组的叶子节点,如果后面再有结点且是前面放的大的的根节点 则说明按大小排序后面放进去的小,且在后面 符合逆序对。

如果有重复 则这样考虑:有逆序对 顺序对 相等对。相等的话,按照离散化,让前面出现的在是后面出现的的父(祖先)结点,则找逆序对时就不会找重复的了。所以就是让相等话就按照从大到小的顺序排序 后面的就不会计到前面相等的了
like this

在这里插入图片描述

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N =  2e6+10;
int tr[N];
int a[N];
int s[N];
int n;
int lowbit(int x)
{
    return x & -x;
}

void update(int x, int c)  // 位置x加c
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

int query(int x)  // 返回前x个数的和
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}
bool cmp(int x,int y)
{
    if(a[x]==a[y]) return x>y;
    return a[x]>a[y];
}
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++) scanf("%d",&a[i]),s[i] = i;
    sort(s+1,s+1+n,cmp);
    for(int i=1;i<=n;i++) a[s[i]]=i;	//按照大小顺序在弄回去
    long long ans=0;
    for(int i=1;i<=n;i++) cout<<a[i]<<" ";cout<<"\n";
    for(int i=1;i<=n;i++)
    {
        update(a[i],1);
        ans+=query(a[i]-1);
      //  cout<<ans<<" ";
    }
    cout<<ans;   
}

最长上升子序列

朴素,O(n^2)

for(int i = 1;i <= n;i++)
{
    for(int j = 1;j < n;j++)
    {
        dp[i] = max(dp[i],dp[j]+1);
    }
}

\(F_i\)代表以 \(i\) 为长度的子序列的末尾最小值

int n;
	cin>>n;
	for(int i=1;i<=n;i++)
	{
		cin>>a[i];
		f[i]=0x7fffffff;
	}
	f[1]=a[1];
	int len=1;
	for(int i=2;i<=n;i++)
	{
		if(a[i]>f[len])f[++len]=a[i];
		else 
		{
			int k = lower_bound(f+1,f+len+1) - f;
			f[k]=min(a[i],f[k]);//更新最小末尾 
     	}
    }
    cout<<len;

权值线段树

见下面例题

权值树状数组

#include <bits/stdc++.h>
using namespace std;
int read()
{
	int x = 0,f = 1;char ch = getchar();
	while(ch < '0'||ch > '9'){if(ch == '-') f = -1;ch = getchar();}
	while(ch >= '0'&&ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}
	return x*f;
}
const int maxn = 1e5+10;
int a[maxn];
int tr[maxn];
int sz;
int xs[maxn];//存a[i]这些数
int lb(int x)
{
	return x&(-x);
}
void add(int x,int c)
{
	for(int i = x;i <= sz;i += lb(i))  tr[i] = max(tr[i],c);
}
int query(int x)
{
	int ret = 0;
	for(int i = x;i;i -= lb(i)) ret = max(ret,tr[i]);
	return ret;
}
signed main()
{
	int n = read();
	for(int i = 1;i <= n;i++) a[i] = read(),xs[++sz] = a[i];
	sort(xs+1,xs+sz+1);
	sz = unique(xs+1,xs+sz+1)-xs-1;
	for(int i = 1;i <= n;i++)
	{
		int x = lower_bound(xs+1,xs+sz+1,a[i]) - xs;
		f[i] = query(x-1)+1;//在数字xs[x]前面的f[j]最大值+1
		add(x,f[i]);
	}
	cout<<*max_element(f+1,f+n+1);
}

蓝桥杯,最长不下降子序列

img

分析:看到这种对其中连续k个数进行修改的我们就应该想到答案是由三部分组成

因为求的是最长不下降子序列,那么我们可以找到一个最合适的断点 \(i\),使得答案是由区间

\[[1,i],[i+1,i+k],[i+k+1,n] \]

三部分组成,其中区间 \([i+1,i+k]\) 里面的数是可以任意变化的,那么我们只要在区间 \([1,i]\) 和区间 \([i+k+1,n]\) 中找到一个最长不下降子序列 \(b_1,b_2,……,b_m\) ,那么我们就可以将区间 \([i+1,i+k]\) 中的所有数变为某个 \(b_j\) ,使得最长不下降子序列的长度为 \(m+k\) ,所以现在我们的关键问题就是为了求取 \(m\)

一般这种问题就是要设置一个前缀和一个后缀,表示含义如下:

\(f_1[i]\)表示中 \(a[1 \sim i]\)\(a[i]\) 结尾的最长不下降子序列的长度
\(f_2[i]\) 表示 \(a[i \sim n]\) 中以 \(a[i]\) 开头的最长不下降子序列的长度

这两个数组显然可以用权值线段树权值树状数组预处理出来:

\(f_1[i]\) :就是每次在加入 \(a[i]\) 之前,先看一下线段树中以小于等于 \(a[i]\) 的值 结尾的最长不下降子序列的长度的最大值,然后在这个基础上+1即可得到

\(f_2[i]\) :这个要从后往前遍历,这个是在每次加入 \(a[i]\) 之前,先看一下线段树中以大于等于 \(a[i]\) 的值开头的最长不下降子序列的长度的最大值,然后在这个基础上+1即可得到

注意当求出这个值后要用 \(f\) 数组对权值线段树树状数组进行更新

那么我们枚举前半段区间的最长不下降子序列端点 \(i\) ,那么也就代表最长不下降子序列是由 \(a[1 \sim i]\) 中的一部分和 \([i+1 \sim i+k]\) 中的全部以及 \(a[i+k+1 \sim n]\) 中的一部分组成,由于我们枚举的前半段区间的最长不下降子序列的 末尾 ,那么我们就要在区间 \([i+k+1,n]\) 中找到以大于等于 \(a[i]\) 的值开头的最长不下降子序列的长度最大值,这个直接在求解 \(f_2[]\)过程中刚好可以利用权值线段树得到。

答案还有可能就是只有两段区间,这个要分两种情况,一种是只有 \(a[1 \sim i]\) 中的一部分和 \([i+1 \sim i+k]\) 中的全部,或者是只有 \([i+1 \sim i+k]\) 中的全部以及 \(a[i+k+1 \sim n]\) 中的一部分组成,这两种情况直接用for循环遍历一遍即可得到,无非就是一种只用到 \(f_1[]\),另一种只用到 \(f_2[]\)

细节见代码:

//权值树状数组
#include <bits/stdc++.h>
#define int long long
using namespace std;
int read()
{
	int x = 0,f = 1;char ch = getchar();
	while(ch < '0'||ch > '9'){if(ch == '-') f = -1;ch = getchar();}
	while(ch >= '0'&&ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}
	return x*f;
}
const int maxn = 1e5+10;
int f[maxn],a[maxn],g[maxn];//f[i]:长度为i的不下降子序列,最长末尾元素
int tr[maxn];
int sz;int xs[maxn];
int lb(int x)
{
	return x&(-x);
}
void add(int x,int c)
{
	for(int i = x;i <= sz;i += lb(i))  tr[i] = max(tr[i],c);
}
int query(int x)
{
	int ret = 0;
	for(int i = x;i;i -= lb(i)) ret = max(ret,tr[i]);
	return ret;
}
void add1(int x,int c)
{
	for(int i = x;i;i -= lb(i)) tr[i] = max(tr[i],c);
}
int query1(int x)
{
	int ret = 0;
	for(int i = x;i <= sz;i+=lb(i)) ret = max(ret,tr[i]);
	return ret;
}
signed main()
{
	int n = read(),k = read();
	for(int i = 1;i <= n;i++) a[i] = read(),xs[++sz] = a[i];
	sort(xs+1,xs+sz+1);
	sz = unique(xs+1,xs+sz+1)-xs-1;
    //求f_i:以a_i结尾的最长不下降子序列的最大值
	for(int i = 1;i <= n;i++)
	{
		int x = lower_bound(xs+1,xs+sz+1,a[i]) - xs;
		f[i] = query(x)+1;
		add(x,f[i]);
	}
	memset(tr,0,sizeof tr);
	int ans = 0;
    //求g_i:以a_i开头的最长不下降子序列的最大值
	for(int i = n;i >= 1;i--)
	{
		int x = lower_bound(xs+1,xs+sz+1,a[i]) - xs;
		g[i] = query1(x)+1;
		add1(x,g[i]);
		if(i > k)
		{
			int y = lower_bound(xs+1,xs+sz+1,a[i-k-1])-xs;
            //query(y)在g数组中找到此时(从后往前枚举到i时)以≥a[i]数值开头的最大值
			ans = max(ans,f[i-k-1]+k+query1(y));
			ans = max(ans,k+g[i]);
		}
		if(i+k <= n) ans= max(ans,f[i]+k);
	}
	cout<<ans;
}
//权值线段树
#include <bits/stdc++.h>
#define int long long
using namespace std;
int read()
{
	int x = 0,f = 1;char ch = getchar();
	while(ch < '0'||ch > '9'){if(ch == '-') f = -1;ch = getchar();}
	while(ch >= '0'&&ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}
	return x*f;
}
const int maxn = 1e5+10;
int f[maxn],a[maxn],g[maxn];//f[i]:长度为i的不下降子序列,最长末尾元素
int tr[maxn<<2];
int sz;
int xs[maxn];
void build(int p,int l,int r)
{
	if(l == r) 
	{
		tr[p] = 0;
		return;
	}
	int mid = (l+r)>>1;
	build(p<<1,l,mid);
	build(p<<1|1,mid+1,r);
}
void modify(int p,int l,int r,int x,int k)
{
	if(l == r && l == x)
	{
		tr[p] = max(tr[p],k);
		return;
	}
	int mid = (l+r)>>1;
	if(x<=mid) modify(p<<1,l,mid,x,k);
	else modify(p<<1|1,mid+1,r,x,k);
	tr[p] = max(tr[p<<1],tr[p<<1|1]);
}
int query(int p,int l,int r,int L,int R)
{
	if(L <= l&&r <= R) return tr[p];
	int mid = (l+r)>>1;
	int anss = 0;
	if(L <= mid) anss = max(anss,query(p<<1,l,mid,L,R));
	if(R > mid) anss = max(anss,query(p<<1|1,mid+1,r,L,R));
	return anss;
}
signed main()
{
	int n = read(),k = read();
	for(int i = 1;i <= n;i++) a[i] = read(),xs[++sz] = a[i];
	sort(xs+1,xs+sz+1);
	sz = unique(xs+1,xs+sz+1)-xs-1;
	build(1,1,sz);
	for(int i = 1;i <= n;i++)
	{
		int x = lower_bound(xs+1,xs+sz+1,a[i]) - xs;
		f[i] = query(1,1,sz,1,x)+1;
		modify(1,1,sz,x,f[i]);
		//
	}
	memset(tr,0,sizeof tr);
	int ans = 0;
	build(1,1,sz);
	for(int i = n;i >= 1;i--)
	{
		int x = lower_bound(xs+1,xs+sz+1,a[i]) - xs;
		g[i] = query(1,1,sz,x,sz)+1;
	//	cout<<x<<" "<<g[i]<<endl;
		modify(1,1,sz,x,g[i]);
		if(i > k)
		{
			int y = lower_bound(xs+1,xs+sz+1,a[i-k-1])-xs;
			ans = max(ans,f[i-k-1]+k+query(1,1,sz,y,sz));
			ans = max(ans,k+g[i]);
		}
		if(i+k <= n) ans= max(ans,f[i]+k);
	}
	cout<<ans;
}