李超线段树学习笔记

发布时间 2023-09-04 19:56:56作者: 洛谷Augury

李超线段树学习笔记

P4097 【模板】李超线段树 / [HEOI2013] Segment

题意

要求在平面直角坐标系下维护两个操作:

  1. 在平面上加入一条线段。记第 \(i\) 条被插入的线段的标号为 \(i\)
  2. 给定一个数 \(k\),询问与直线 \(x = k\) 相交的线段中,交点纵坐标最大的线段的编号。

做法

首先,我们有一颗线段树。

线段树的节点维护的信息是:

  1. 左右孩子编号
  2. 没有在祖先节点出现过的,在这个节点的 \(mid\) 位置值最大的线段

现在我们要插入一条线段。我们将这条线段表示成 \(y=kx+b\) 的形式,考虑如何插入。

  1. 如果这个位置没有线段,那么就放在这里,返回
  2. 如果这个位置有线段,且当前线段更为优秀,就把当前线段和这个节点的线段交换,进行步骤 \(3\)
  3. 如果当前线段在左端点的值和这个节点的线段比更为优秀,那么就尝试用这条线段更新左子树的答案;如果右端点更优秀,就尝试更新右端点的答案

这样就可以对一个线段树节点插入线段了。

考虑到这是模板题,我们要对一个区间插入线段而不是一个线段树节点,所以我们要找到这个区间对应的线段树节点,然后对这 \(\log\) 个节点插入线段,这样的复杂度就是 \(\log^2\) 的。

查询的时候,我们就对这个点进行线段树查询,把从根到这个位置对应的节点都取个 \(\max\) 就行了。

修改 \(\log^2\),查询 \(\log\)

代码

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int ans=0;bool op=0;char ch=getchar();
	while(ch<'0'||'9'<ch){if(ch=='-')op=1;ch=getchar();}
	while('0'<=ch&&ch<='9'){ans=(ans<<1)+(ans<<3)+(ch^48);ch=getchar();}
	if(op)return -ans;
	return ans;
}
const int maxn=1e5+10;
const double esp=1e-6;
const int inf=39989;
const int awa=1e9;
int n;
struct seg{
	double k,b;
}s[maxn];
struct node{
	int l,r,val;
}tr[maxn];
int rt,ct,cs;
int lastans=0;
void add(int x0,int y0,int x1,int y1){
	double k,b;
	if(x0^x1){
		k=1.0*(y0-y1)/(x0-x1);
		b=y0-k*x0;
	}
	else{
		k=0;
		b=max(y0,y1);
	}
	s[++cs]=(seg){k,b};
}
double f(int u,int x){
	return 1.0*s[u].k*x+s[u].b;
}
int cmp(int u,int v,int x){
	double i=f(u,x),j=f(v,x);
	if(abs(i-j)<=esp)return u<v;
	if(i>j)return 1;
	else return -1;
}
void ins(int s,int t,int &p,int u){
	if(!p)p=++ct;
	int &v=tr[p].val;
	int mid=(s+t)>>1;
	if(cmp(u,v,mid)==1)swap(u,v);
	if(cmp(u,v,s)==1)ins(s,mid,tr[p].l,u);
	if(cmp(u,v,t)==1)ins(mid+1,t,tr[p].r,u);
}
void update(int l,int r,int s,int t,int &p,int u){
	if(!p)p=++ct;
	if(l<=s&&t<=r){
		ins(s,t,p,u);
		return;
	}
	int mid=(s+t)>>1;
	if(l<=mid)update(l,r,s,mid,tr[p].l,u);
	if(mid<r)update(l,r,mid+1,t,tr[p].r,u);
}
int query(int pos,int s,int t,int p){
	if(!p)return 0;
	if(s==t)return tr[p].val;
	int mid=(s+t)>>1;
	int ans=tr[p].val,tmp;
	if(pos<=mid)tmp=query(pos,s,mid,tr[p].l);
	else tmp=query(pos,mid+1,t,tr[p].r);
	if(cmp(ans,tmp,pos)==0)return min(ans,tmp);
	else if(cmp(ans,tmp,pos)==1)return ans;
	else return tmp;
}
signed main(){
	n=read();
	for(int i=1;i<=n;i++){
		int op=read();
		if(op){
			int x0=(read()+lastans-1)%inf+1;
			int y0=(read()+lastans-1)%awa+1;
			int x1=(read()+lastans-1)%inf+1;
			int y1=(read()+lastans-1)%awa+1;
			if(x0>x1)swap(x0,x1),swap(y0,y1);
			add(x0,y0,x1,y1);
			update(x0,x1,1,inf,rt,cs);
		}
		else{
			int x=(read()+lastans-1)%inf+1;
			lastans=query(x,1,inf,rt);
			cout<<lastans<<'\n';
		}
	}
	return 0;
}

结合斜率优化

现在我们有一个斜率优化式子:

\[dp_i=\max dp_j+y_j-x_j\times k_i \]

然后我们遇到诸多问题,比如这个不单调,那个不单调

正常的斜优就要上平衡树了,但实际上我们可以把这个东西变一变,把原先一个凸包上的点变成一个线段。

那么新的 \(dp\) 式子就是:

\[dp_i=\max dp_j+b_j+x_i\times k_j \]

这个东西直接李超树维护即可。

因为是全局插入,所以直接对根节点进行插入即可,复杂度 \(O(\log n)\)

合并

众所周知,普通的线段树是可以合并的。

现在我们来合并李超树。

直接看代码吧,就是先跑完孩子再合并自己。

int merge(int x,int y,int s,int t){
	if(!x||!y)return x|y;
	if(s==t){
		if(f(s,tr[x].v)>f(s,tr[y].v))tr[x].v=tr[y].v;
		return x;
	}
	int mid=(s+t)>>1;
	tr[x].l=merge(tr[x].l,tr[y].l,s,mid);
	tr[x].r=merge(tr[x].r,tr[y].r,mid+1,t);
	upd(s,t,x,tr[y].v);
	return x;
}

CF932F Escape Through Leaf

题意

有一颗 \(n\) 个节点的树(节点从 \(1\)\(n\) 依次编号),根节点为 \(1\)。每个节点有两个权值,第 \(i\) 个节点的权值为 \(a_i,b_i\)

你可以从一个节点跳到它的子树内任意一个节点上。从节点 \(x\) 跳到节点 \(y\) 一次的花费为 \(a_x\times b_y\)。跳跃多次走过一条路径的总费用为每次跳跃的费用之和。请分别计算出每个节点到达树的每个叶子节点的费用中的最小值。

注意:就算根节点的度数为 \(1\),根节点也不算做叶子节点。另外,不能从一个节点跳到它自己.

\(2\leq n\leq 10^5\)\(-10^5\leq a_i\leq 10^5\)\(-10^5\leq b_i\leq 10^5\)

dqa2022大佬在CF932F Escape Through Leaf 题解中证明了,这样合并的总复杂度是 \(O(n\log n)\) 的。

做法

容易发现这就是一个斜率优化的东西,但是要合并两个凸包,就很麻烦。

所以我们使用李超树。

全局插入线段的复杂度是 \(O(n\log n)\) 的,李超树合并的总复杂度也是 \(O(n\log n)\) 的,所以这道题的总复杂度就是 \(O(n\log n)\) 的。

然后就做完了。

代码

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
	int ans=0;bool op=0;char ch=getchar();
	while(ch<'0'||'9'<ch){if(ch=='-')op=1;ch=getchar();}
	while('0'<=ch&&ch<='9'){ans=(ans<<1)+(ans<<3)+(ch^48);ch=getchar();}
	if(op)return -ans;
	return ans;
}
const int maxn=1e5+10;
const int inf=1e15;
const int lim=1e5+10;
int n;
int a[maxn],b[maxn];
vector<int>g[maxn];
int dp[maxn];
struct seg{
	int k,b;
	seg(int kk=0,int bb=inf){
		k=kk,b=bb;
	}
};
struct node{
	int l,r;
	seg v;
}tr[maxn*20];
int rt[maxn],cn;
int f(int x,seg l){
	return l.k*x+l.b;
}
void upd(int s,int t,int &p,seg u){
	if(!p)p=++cn;
	seg &v=tr[p].v;
	int mid=(s+t)>>1;
	if(f(mid,v)>f(mid,u))swap(u,v);
	if(s==t)return;
	if(f(s,u)<f(s,v))upd(s,mid,tr[p].l,u);
	if(f(t,u)<f(t,v))upd(mid+1,t,tr[p].r,u);
}
int qry(int pos,int s,int t,int p){
	if(!p)return inf;
	if(s==t)return f(pos,tr[p].v);
	int mid=(s+t)>>1;
	int ans=f(pos,tr[p].v);
	if(pos<=mid)ans=min(ans,qry(pos,s,mid,tr[p].l));
	else ans=min(ans,qry(pos,mid+1,t,tr[p].r));
	return ans;
}
void update(int &p,seg u){
	u.b=u.b-lim*u.k;
	upd(0,lim<<1,p,u);
}
int query(int p,int x){
	return qry(x+lim,0,lim<<1,p);
}
int merge(int x,int y,int s,int t){
	if(!x||!y)return x|y;
	if(s==t){
		if(f(s,tr[x].v)>f(s,tr[y].v))tr[x].v=tr[y].v;
		return x;
	}
	int mid=(s+t)>>1;
	tr[x].l=merge(tr[x].l,tr[y].l,s,mid);
	tr[x].r=merge(tr[x].r,tr[y].r,mid+1,t);
	upd(s,t,x,tr[y].v);
	return x;
}
void dfs(int now,int fa){
	for(int nxt:g[now]){
		if(nxt==fa)continue;
		dfs(nxt,now);
		rt[now]=merge(rt[now],rt[nxt],0,lim<<1);
	}
	if(g[now].size()>1||!fa)dp[now]=query(rt[now],a[now]);
	update(rt[now],(seg){b[now],dp[now]});
}
signed main(){
	n=read();
	for(int i=1;i<=n;i++)a[i]=read();
	for(int i=1;i<=n;i++)b[i]=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read();
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs(1,0);
	for(int i=1;i<=n;i++)cout<<dp[i]<<' ';
	return 0;
}