[THUSCH2017] 大魔法师

发布时间 2023-11-11 20:13:10作者: 未抑郁的刘大狗

前期准备

1.熟练的掌握区间修改线段树
2.对矩阵乘法有部分的了解,知道如何使用
3.对卡常十分精通

题目大意

题目给定 \(n\) 个三元组,每个三元组包含 \(A\)\(B\)\(C\) 三个元素,一共进行 \(m\) 次操作,分别是下面七种之一:

1.令给定区间内,\(A_i=A_i+B_i\)
2.令给定区间内,\(B_i=B_i+C_i\)
3.令给定区间内,\(C_i=C_i+A_i\)
4.令给定区间内,\(A_i=A_i+v\)
5.令给定区间内,\(B_i=B_i\times v\)
6.令给定区间内,\(C_i=v\)
7.查询区间内每个元素 \(A\)\(B\)\(C\) 累加得到的和。

其中 \(1 \le n \le 2.5\times 10^5\)\(1 \le m \le 2.5\times 10^5\) ,元素中每个值对 \(998244353\) 取模。
题目时间限制 \(5\) 秒!!!

思路

因为题目要求写一个动态的区间修改和区间最大值,所以自然地就可以想到区间修改线段树
但是因为这道题目处理的是三元组,所以如果一个一个处理的话,线段树的 lazy 数组的会非常难写。

于是顺理成章的,就应该使用矩阵乘法给线段树进行优化。
每个三元组在运算的过程中都可以看做一个矩阵,而这 \(7\) 个操作就只需要推 \(6\) 个矩阵并写一个区间求和就结束了。

做法

操作 \(1\)\(2\)\(3\)

如果你已经做个一些题目的话,那么你应该可以顺理成章的推出后面 \(3\) 个式子

\[\begin{bmatrix} A_i+B_i & B_i & C_i \end{bmatrix}= \begin{bmatrix} A_i & B_i & C_i \end{bmatrix} \times \begin{bmatrix} 1 & 0 &0 \\ 1 & 1 & 0\\ 0 & 0 &1 \end{bmatrix} \]

\[\begin{bmatrix} A_i & B_i+C_i & C_i \end{bmatrix}= \begin{bmatrix} A_i & B_i & C_i \end{bmatrix} \times \begin{bmatrix} 1 & 0 &0 \\ 0 & 1 &0 \\ 0 & 1 &1 \end{bmatrix} \]

\[\begin{bmatrix} A_i & B_i & C_i+a_i \end{bmatrix}= \begin{bmatrix} A_i & B_i & C_i \end{bmatrix} \times \begin{bmatrix} 1 & 0 &1 \\ 0 & 1 &0 \\ 0 & 0 &1 \end{bmatrix} \]

操作\(4\)\(5\)\(6\)

现在你会发现这个给定的 \(v\) 不知道应该塞到哪里了,于是我们就应该添加辅助的维度
可以将原来的 \(\begin{bmatrix}A_i & B_i &C_i\end{bmatrix}\) 替换为 \(\begin{bmatrix}A_i & B_i &C_i &1\end{bmatrix}\) 辅助增加

\[\begin{bmatrix} A_i+v & B_i & C_i & 1 \end{bmatrix}= \begin{bmatrix} A_i & B_i & C_i & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 &0 &0\\ 0 & 0 &1 &0 \\ v & 0 & 0 & 0 \end{bmatrix} \]

\[\begin{bmatrix} A_i & B_i\times v & C_i & 1 \end{bmatrix}= \begin{bmatrix} A_i & B_i & C_i & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & v &0 &0\\ 0 & 0 &1 &0 \\ 0 & 0 & 0 & 0 \end{bmatrix} \]

操作 \(7\)

这个操作其实就可以直接写一个线段树的区间求和就可以了

死亡记录

AC Code


#include<bits/stdc++.h>
#define int long long
#define m(s1) memset(s1.a,0,sizeof(s1.a))
const int mod=998244353;
const int N=1000005;
inline int read(){
	int x=0;
	char ch=getchar();
	while(ch>'9'||ch<'0') ch=getchar();
	while(ch<='9'&&ch>='0') x=(x<<1)+(x<<3)+ch-48,ch=getchar();
	return x;
}
struct node{
	int a[3][3],n,m;
	node(){memset(a,0,sizeof(a));}
	friend node operator + (const node a,const node b){
		node s; s.n=a.n,s.m=a.m;
		m(s);
		for(register int i=0;i<a.n;++i)
		for(register int j=0;j<a.m;++j)
			s.a[i][j]=(a.a[i][j]+b.a[i][j])%mod;
		return s;
	}
	friend node operator * (const node a,const node b){
		node s; s.n=a.n,s.m=b.m;
		m(s);
		for(register int i=0;i<a.n;++i)
		for(register int k=0;k<a.m;++k)
		for(register int j=0;j<b.m;++j)
			s.a[i][j]=(s.a[i][j]+a.a[i][k]*b.a[k][j]%mod)%mod;
		return s;
	}
	friend node operator * (const node a,const int b){
		node s; s.n=a.n,s.m=a.m;
		m(s);
		for(register int i=0;i<a.n;++i)
		for(register int j=0;j<a.m;++j)
			s.a[i][j]=a.a[i][j]*b%mod;
		return s;
	}
}s[8],sum[N],lazy1[N],lazy2[N];
inline void pre(){
	s[1].a[0][0]=s[1].a[1][1]=s[1].a[2][2]=s[1].a[1][0]=1;
	s[1].n=s[1].m=3;
	
	s[2].a[0][0]=s[2].a[1][1]=s[2].a[2][2]=s[2].a[2][1]=1;
	s[2].n=s[2].m=3;
	
	s[3].a[0][0]=s[3].a[1][1]=s[3].a[2][2]=s[3].a[0][2]=1;
	s[3].n=s[3].m=3;
	
	s[4].a[0][0]=-1;
	s[4].n=1,s[4].m=3;
	
	s[5].a[0][0]=s[5].a[2][2]=1;
	s[5].a[1][1]=-1;
	s[5].n=s[5].m=3;
	
	s[6].a[0][0]=s[6].a[1][1]=1;
	s[6].n=s[6].m=3;
	s[7].a[0][2]=-1;
	s[7].n=1,s[7].m=3;
}
int n,m;
inline void updata(int k,int l,int r){
	int mid=(l+r)/2;
	lazy2[k*2]=lazy2[k*2]*lazy2[k];
	lazy2[k*2+1]=lazy2[k*2+1]*lazy2[k];
	lazy1[k*2]=lazy1[k*2]*lazy2[k]+lazy1[k];
	lazy1[k*2+1]=lazy1[k*2+1]*lazy2[k]+lazy1[k];
	sum[k*2]=sum[k*2]*lazy2[k]+lazy1[k]*(mid-l+1);
	sum[k*2+1]=sum[k*2+1]*lazy2[k]+lazy1[k]*(r-mid);
	m(lazy1[k]),m(lazy2[k]);
	lazy2[k].a[0][0]=lazy2[k].a[1][1]=lazy2[k].a[2][2]=1;
}
inline void pre_lazy(int k){
	lazy1[k].n=1,lazy1[k].m=3;
	lazy2[k].n=3,lazy2[k].m=3;
	lazy2[k].a[0][0]=lazy2[k].a[1][1]=lazy2[k].a[2][2]=1;
}
void build(int k,int l,int r){
	sum[k].n=1,sum[k].m=3;
	pre_lazy(k);
	if(l==r){
		sum[k].a[0][0]=read();
		sum[k].a[0][1]=read();
		sum[k].a[0][2]=read();
		return;
	}
	int mid=(l+r)/2;
	build(k*2,l,mid);
	build(k*2+1,mid+1,r);
	sum[k]=sum[k*2]+sum[k*2+1];
}
void up(int k,int l,int r,int ll,int rr,node &v,bool flag){
	if(ll<=l&&rr>=r){
		if(flag==0) lazy1[k]=lazy1[k]*v,lazy2[k]=lazy2[k]*v,sum[k]=sum[k]*v;
		else lazy1[k]=lazy1[k]+v,sum[k]=sum[k]+v*(r-l+1);
		return ;
	}int mid=(l+r)/2;
	updata(k,l,r);
	if(ll<=mid) up(k*2,l,mid,ll,rr,v,flag);
	if(mid<rr) up(k*2+1,mid+1,r,ll,rr,v,flag);
	sum[k]=sum[k*2]+sum[k*2+1];
}
node ask(int k,int l,int r,int ll,int rr){
	if(ll<=l&&rr>=r) return sum[k];
	int mid=(l+r)/2;
	updata(k,l,r);
	node res; res.n=1,res.m=3;
	if(ll<=mid) res=ask(k*2,l,mid,ll,rr);
	if(mid<rr) res=res+ask(k*2+1,mid+1,r,ll,rr);
	return res;
}
signed main(){
	pre();
	n=read();
	build(1,1,n);
	m=read();
	for(register int i=1,op,l,r;i<=m;++i){
		op=read(),l=read(),r=read();
		if(op<=3) up(1,1,n,l,r,s[op],0);
		if(op==4) s[4].a[0][0]=read(),up(1,1,n,l,r,s[4],1);
		if(op==5) s[5].a[1][1]=read(),up(1,1,n,l,r,s[5],0);
		if(op==6) s[7].a[0][2]=read(),up(1,1,n,l,r,s[6],0),up(1,1,n,l,r,s[7],1);
		if(op==7){
			node ans=ask(1,1,n,l,r);
			printf("%lld %lld %lld\n",ans.a[0][0],ans.a[0][1],ans.a[0][2]);
		}
	}
	return 0;
}