线段树-多个懒标记pushdown

发布时间 2023-11-16 12:45:27作者: cxy8

P3373 【模板】线段树 2

这里需要用到两个懒标记,一个懒标记为add,记录加,另一个懒标记为mul,记录乘。
我们需要规定一个优先级,然后考虑如何将懒标记下传。
这里无非有两种顺序,一种是先乘后加,另一种是先加后乘。
我们先看先加后乘

\[(sum + add1) * mul1 \]

当我们的懒标记$ add2 、 mul2 $下传

\[(sum + add1) * mul1 + add2 \quad (1) \]

\[(sum + add1) * mul1 * mul2 \quad (2) \]

我们更新懒标记仍需要化成

\[(sum + add) * mul \]

对于(1)

\[(sum + add1 + \frac{add2} {mul1}) * mul1 \quad (1) \]

可以看出如果要下传add标记就需要add2必须满足mul1,这个显然不容易满足
对于(2)

\[(sum + add1) * mul1 * mul2 \quad (2) \]

我们之间将mul标记更新为$ mul1 * mul2 $ 即可
显然先加后乘,下传加号标记是不好处理的
我们再看先乘后加

\[sum * mul1 + add1 \]

当懒标记$ add2 或 mul2 $下传

\[sum * mul1 + add1 + add2 \quad (1) \]

\[(sum * mul1 + add1) * mul2 \quad (2) \]

对于(1)
我们只需要将add标记更新为$ add1 + add2$即可
对于(2)

\[sum * mul1 * mul2 + add1 * mul2 \quad (2) \]

我们需要将add标记跟新为$ add1 * mul2 , 将mul标记更新为 mul1 * mul2$
我们发现先乘后加的情况,懒标记的下传都很容易做到,所以我们选择先乘后加这种顺序。
下面是代码:


#include <bits/stdc++.h> 
#define LL long long 
#define ls p<<1
#define rs p<<1|1
#define endl '\n'
#define PII pair<int, int>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
ll n, q, mod, a[N];
struct node
{
	ll l, r, sum, mul, add;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define mul(x) tree[x].mul
	#define add(x) tree[x].add
	#define sum(x) tree[x].sum
} tree[4 * N];

void pushup(int p)
{
	sum(p) = (sum(ls) + sum(rs)) % mod;
}

void build(int p, ll l, ll r)
{
	l(p) = l; r(p) = r; mul(p) = 1, add(p) = 0;
	if(l == r)	
	{
		sum(p) = a[l] % mod;
		return;
	}
	int mid = (l(p) + r(p)) >> 1;
	build(ls, l(p), mid), build(rs, mid + 1, r(p));
	pushup(p);	 
}

void pushdown(int p)
{
	sum(ls) =  (sum(ls) * mul(p) + (r(ls) - l(ls) + 1) * add(p)) % mod;
	sum(rs) =  (sum(rs) * mul(p) + (r(rs) - l(rs) + 1) * add(p)) % mod;
	
	mul(ls) =  (mul(ls) * mul(p)) % mod;
	mul(rs) =  (mul(rs) * mul(p)) % mod;
	
	add(ls) =  (add(ls) * mul(p) + add(p)) % mod;
	add(rs) =  (add(rs) * mul(p) + add(p)) % mod;
	
	add(p) = 0; mul(p) = 1;
}

void modify(int p, ll l, ll r, int op, ll x)
{
	if(l(p) >= l && r(p) <= r)
	{
		if(op == 1)
		{
			sum(p) = (sum(p) * x) % mod;
			mul(p) = (mul(p) * x) % mod;
			add(p) = (add(p) * x) % mod;
		}
		else if(op == 2) 
		{ 
			sum(p) = (sum(p) + (r(p) - l(p) + 1) * x) % mod;
			add(p) = (add(p) + x) % mod;
		}
		return;
	}
	pushdown(p);
	int mid = (l(p) + r(p)) >> 1;
	if(l <= mid)	modify(ls, l, r, op, x);
	if(r > mid)	modify(rs , l, r, op, x);
	pushup(p);
}

ll query(int p, ll l, ll r)
{
	if(l(p) >= l && r(p) <= r)	return sum(p);
	pushdown(p);
	int mid = (l(p) + r(p)) >> 1;
	ll val = 0;
	if(l <= mid)	val = (val + query(ls, l, r)) % mod;
	if(r > mid)		val = (val + query(rs, l, r)) % mod;
	return val;
}

void solve()
{
	cin >> n >> mod;
	for(int i = 1; i <= n; ++ i)	cin >> a[i];
	cin >> q;
	build(1, 1, n);
	for(int i = 1; i <= q; ++ i)
	{
		int op; cin >> op;
		if(op == 1)
		{
			ll l, r, x;	cin >> l >> r >> x;
			modify(1, l, r, op, x);
		}
		else if(op == 2)
		{
			ll l, r, x; cin >> l >> r >> x;
			modify(1, l, r, op, x);
		}
		else
		{
			int l, r; cin >> l >> r;
			cout << query(1, l, r) << endl; 	
		}
	} 
}		
int main()
{
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
//	freopen("1.in", "r", stdin);
	solve(); 
	return 0;
}

P1253 扶苏的问题

这道题目也需要两个懒标记,一个把一个区间的所有数变成x的懒标记记为add1,另一个懒标记记为add2

特殊性:操作一会覆盖,假设有操作1,新增懒标记为$ add1 时, 我们就将 add2清空$
当新增一个操作2时,我们就直接增加 $ add2 标记即可 $

#include <bits/stdc++.h> 
#define LL long long 
#define ls p<<1
#define rs p<<1|1
#define PII pair<int, int>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
int n, q, a[N];
struct node
{
	int l, r, st;
	ll mx, add1, add2;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define add1(x) tree[x].add1
	#define add2(x) tree[x].add2
	#define st(x) tree[x].st
	#define mx(x) tree[x].mx
} tree[4 * N];

void pushup(int p)
{
	mx(p) = max(mx(ls), mx(rs));
}

void build(int p, int l, int r)
{
	l(p) = l; r(p) = r; st(p) = 0;
	if(l == r)	
	{
		mx(p) = a[l];
		return;
	}
	int mid = (l(p) + r(p)) >> 1;
	build(ls, l(p), mid), build(rs, mid + 1, r(p));
	pushup(p);	 
}

void pushdown(int p)
{
	if(st(p))
	{
		st(ls) = st(rs) = 1;
		add1(ls) = add1(rs) = add1(p);
		add2(ls) = add2(rs) = 0;
		mx(ls) = mx(rs) = add1(p);
		st(p) = add1(p) = 0;
	}
	mx(ls) += add2(p);	mx(rs) += add2(p);
	add2(ls) += add2(p);	add2(rs) += add2(p);
	add2(p) = 0;
}

void modify(int p, int l, int r, int op, int x)
{
	if(l(p) >= l && r(p) <= r)
	{
		if(op == 1)
		{
			mx(p) = x;
			st(p) = 1;
			add1(p) = x;
			add2(p) = 0;
		}
		else 
		{
			mx(p) += x;
			add2(p) += x;
		}
		return;
	}
	pushdown(p);
	int mid = (l(p) + r(p)) >> 1;
	if(l <= mid)	modify(ls, l, r, op, x);
	if(r > mid)	modify(rs , l, r, op, x);
	pushup(p);
}

ll query(int p, int l, int r)
{
	if(l(p) >= l && r(p) <= r)	return mx(p);
	pushdown(p);
	int mid = (l(p) + r(p)) >> 1;
	ll val = -1e18;
	if(l <= mid)	val = max(val, query(ls, l, r));
	if(r > mid)		val = max(val, query(rs, l, r));
//	pushup(p); 
	return val;
}

void solve()
{
	cin >> n >> q;
	for(int i = 1; i <= n; ++ i)	cin >> a[i];
	build(1, 1, n); 
	for(int i = 1; i <= q; ++ i)
	{
		int op; cin >> op;
		if(op == 1)
		{
			int l, r, x;	cin >> l >> r >> x;
			modify(1, l, r, op, x);
		}
		else if(op == 2)
		{
			int l, r, x; cin >> l >> r >> x;
			modify(1, l, r, op, x);
		}
		else
		{
			int l, r; cin >> l >> r;
			cout << query(1, l, r) << endl; 	
		}
	} 
	
}		
int main()
{
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
//	freopen("1.in", "r", stdin);
	solve(); 
	return 0;
}