广州大学第十七届ACM大学生程序设计竞赛 L. 因子模仿 - hard version 线段树维护矩阵

发布时间 2023-04-11 22:34:58作者: qdhys

传送门

大致思路:

  观察发现,茉美香胜利会叠加对手所有状态,茉美香失败会被对手叠加所有状态。我们可以用矩阵[a1, a2, b1, b2]表示两个人的状态(其中a1, a2表示茉美香, b1, b2表示对手)茉美香赢了之后的状态是[a1 + b1, a2 + b2, b1, b2], 茉美香输了之后的状态是[a1, b1, a1 + b1, a2 + b2]。[a1, a2, b1, b2]乘上一个什么样的矩阵会变成[a1 + b1, a2 + b2, b1, b2],我们手推一下会发现是如下的矩阵,下面用win表示这个矩阵。

1, 0, 0, 0;
0, 1, 0, 0;
1, 0, 1, 0;
0, 1, 0, 1;

  [a1, a2, b1, b2]乘上一个什么样的矩阵会变成[a1, a2, a1 + b1, a2 + b2],我们手推一下会发现是如下的矩阵,下面用lost表示这个矩阵。

1, 0, 1, 0;
0, 1, 0, 1;
0, 0, 1, 0;
0, 0, 0, 1;

  我们观察题意发现区间修改和区间查询,所以考虑线段树维护矩阵,我们知道矩阵乘法具有结合律,我们对于茉美香赢赢输输赢这五个操作会得到乘法是[a1, a2, b1, b2] * win * win * lost * lost * win,那么我们会发现[a1, a2, b1, b2]只出现了一次,4 * 4大小的win或者lost出现了很多次,所以我们线段树上维护这样的4 * 4大小的矩阵乘法就行了,最后拿出来乘以[a1, a2, b1, b2]这个矩阵即可。但是这题要求翻转,如果每次翻转我们重新算就很复杂, 所以我们多维护一个矩阵,也就是考虑线段树维护两个矩阵,一个表示正确的状态,另一个表示翻转后的状态,这样的话我们每次进行修改操作,递归到一个被包含的区间可以直接交换两个矩阵并且打上一个懒标记,这样会大大降低时间复杂度。

#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
 
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define se second
#define AC main(void)
#define ls u << 1
#define rs u << 1 | 1
typedef std::pair<int, int> PII;

const int N = 1e6 + 10;
const int M = 4e5 + 10;

inline int max(int a, int b){return a > b ? a : b;}
inline int min(int a, int b){return a > b ? b : a;}


const int MOD = 1e9 + 7;
void add(int &a, int b) {
	a += b;
	if (a >= MOD) a -= MOD;
}
void del(int &a, int b) {
	a -= b;
	if (a < 0) a += MOD;
}
template <typename T>

struct Matrix{

	T a[4][4];

	Matrix<T> operator*(const Matrix<T>& B)const {

		Matrix<T> res;
		memset(res.a, 0, sizeof res.a);
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				for (int k = 0; k < 4; k++)
					add(res.a[i][j], 1ll * a[i][k] * B.a[k][j] % MOD);
			}
		}
		return res;
	}

	void print() {
		for (int i = 0; i < 4; i++)
			for (int j = 0; j < 4; j++) std::cout << a[i][j] << " \n"[j == 3];
	}
};

int n, m;

char str[N];
int a1, a2, b1, b2;

struct Node{
	Matrix<int> A, B;

	int tag;
}tr[N << 2];

inline void pushup(int u){
	tr[u].A = tr[ls].A * tr[rs].A;
	tr[u].B = tr[ls].B * tr[rs].B;
}

inline void build(int u, int L, int R){
	memset(tr[u].A.a, 0, sizeof tr[u].A.a), memset(tr[u].B.a, 0, sizeof tr[u].B.a);
	if(L == R){
		if(str[L] == '1'){//win
			tr[u].A.a[0][0] = 1, tr[u].A.a[0][1] = 0, tr[u].A.a[0][2] = 0, tr[u].A.a[0][3] = 0;
			tr[u].A.a[1][0] = 0, tr[u].A.a[1][1] = 1, tr[u].A.a[1][2] = 0, tr[u].A.a[1][3] = 0;
			tr[u].A.a[2][0] = 1, tr[u].A.a[2][1] = 0, tr[u].A.a[2][2] = 1, tr[u].A.a[2][3] = 0;
			tr[u].A.a[3][0] = 0, tr[u].A.a[3][1] = 1, tr[u].A.a[3][2] = 0, tr[u].A.a[3][3] = 1;
			
			tr[u].B.a[0][0] = 1, tr[u].B.a[0][1] = 0, tr[u].B.a[0][2] = 1, tr[u].B.a[0][3] = 0;
			tr[u].B.a[1][0] = 0, tr[u].B.a[1][1] = 1, tr[u].B.a[1][2] = 0, tr[u].B.a[1][3] = 1;
			tr[u].B.a[2][0] = 0, tr[u].B.a[2][1] = 0, tr[u].B.a[2][2] = 1, tr[u].B.a[2][3] = 0;
			tr[u].B.a[3][0] = 0, tr[u].B.a[3][1] = 0, tr[u].B.a[3][2] = 0, tr[u].B.a[3][3] = 1;
		
		}else{
			tr[u].B.a[0][0] = 1, tr[u].B.a[0][1] = 0, tr[u].B.a[0][2] = 0, tr[u].B.a[0][3] = 0;
			tr[u].B.a[1][0] = 0, tr[u].B.a[1][1] = 1, tr[u].B.a[1][2] = 0, tr[u].B.a[1][3] = 0;
			tr[u].B.a[2][0] = 1, tr[u].B.a[2][1] = 0, tr[u].B.a[2][2] = 1, tr[u].B.a[2][3] = 0;
			tr[u].B.a[3][0] = 0, tr[u].B.a[3][1] = 1, tr[u].B.a[3][2] = 0, tr[u].B.a[3][3] = 1;
			
			tr[u].A.a[0][0] = 1, tr[u].A.a[0][1] = 0, tr[u].A.a[0][2] = 1, tr[u].A.a[0][3] = 0;
			tr[u].A.a[1][0] = 0, tr[u].A.a[1][1] = 1, tr[u].A.a[1][2] = 0, tr[u].A.a[1][3] = 1;
			tr[u].A.a[2][0] = 0, tr[u].A.a[2][1] = 0, tr[u].A.a[2][2] = 1, tr[u].A.a[2][3] = 0;
			tr[u].A.a[3][0] = 0, tr[u].A.a[3][1] = 0, tr[u].A.a[3][2] = 0, tr[u].A.a[3][3] = 1;
		
		}

		return ;
	}
	
	int mid = L + R >> 1;
	build(ls, L, mid);
	build(rs, mid + 1, R);
	
	pushup(u);
}

inline void pushdown(int u){
	if(tr[u].tag & 1){
		for(int i = 0; i < 4; i ++)
			for(int j = 0; j < 4; j ++){
				std::swap(tr[ls].A.a[i][j], tr[ls].B.a[i][j]);
				std::swap(tr[rs].A.a[i][j], tr[rs].B.a[i][j]);
			}
		tr[ls].tag ++;
		tr[rs].tag ++;
		tr[u].tag = 0;
	}
}

inline void modify(int u, int L, int R, int l, int r){
	if(L >= l && R <= r){
		for(int i = 0; i < 4; i ++)
			for(int j = 0; j < 4; j ++){
				std::swap(tr[u].A.a[i][j], tr[u].B.a[i][j]);
			}
		tr[u].tag ++;
		return ;
	}
	
	pushdown(u);
	
	int mid = L + R >> 1;
	if(l <= mid)	modify(ls, L, mid, l, r);
	if(r > mid)	modify(rs, mid + 1, R, l, r);
	
	pushup(u);
}

inline Matrix<int> query(int u, int L, int R, int l, int r){
	if(L >= l && R <= r)	return tr[u].A;
	
	pushdown(u);
	
	int mid = L + R >> 1;
	
	if(r <= mid)	return query(ls, L, mid, l, r);
	if(l > mid)	return query(rs, mid + 1, R, l, r);
	
	return query(ls, L, mid, l, r) * query(rs, mid + 1, R, l, r);

}

inline void solve(){
	std::cin >> n >> m;
	std::cin >> a1 >> b1 >> a2 >> b2;
	
	std::cin >> (str + 1);
	build(1, 1, n);
	
	while(m --){
		int op, l, r;
		std::cin >> op >> l >> r;
		if(op == 1){
			modify(1, 1, n, l, r);
		}else{
			Matrix<int> MY;
			memset(MY.a, 0, sizeof MY.a);
			MY.a[0][0] = a1, MY.a[0][1] = a2, MY.a[0][2] = b1, MY.a[0][3] = b2;
			Matrix<int> res = MY * query(1, 1, n, l, r);
		
			int ans = 0;
			add(ans, 1ll * res.a[0][0] * res.a[0][1] % MOD);
			del(ans, 1ll * res.a[0][2] * res.a[0][3] % MOD);
			std::cout << ans << '\n';
		}
	}
}

int main(void){
	std::ios::sync_with_stdio(false);
	std::cin.tie(0);
	std::cout.tie(0);

	int _ = 1;
	//std::cin >> _;
	while(_ --)	solve();
	
	return 0;
}