P5175 数列

发布时间 2023-08-29 23:44:09作者: Scorilon

Updated

2023.07.05 修正了一处笔误,在此感谢@DWT8125

题解

首先先推一下柿子,因为数据范围很大,所以考虑矩阵加速递推。

根据题意给的递推式,可得:

\[\begin{aligned} a_i^2 &= (x \times a_{i-1}+y \times a_{i-2})^2\\ &= x^2\times a_{i-1}^2+y^2\times a_{i-2}^2+2xy\times a_{i-1} \times a_{i-2} \end{aligned}\]

由此我们可以构造出答案矩阵:

\[T=\begin{bmatrix} a_i^2 & a_{i-1}^2 & a_i \times a_{i-1} &ans_i\\ \end{bmatrix}\]

其中 \(ans_i=\sum_{j=1}^i a_j^2\).

那么由于

\[\begin{cases} x^2\times a_{i-1}^2+y^2\times a_{i-2}^2+2xy\times a_{i-1} \times a_{i-2}+0 \times ans_{i-1}=a_i^2\\ 1\times a_{i-1}^2+0\times a_{i-2}^2+0\times a_{i-1} \times a_{i-2}+0 \times ans_{i-1}=a_{i-1}^2\\ x\times a_{i-1}^2+0\times a_{i-2}^2+y\times a_{i-1} \times a_{i-2}+0 \times ans_{i-1}=a_{i} \times a_{i-1}\\ x^2\times a_{i-1}^2+y^2\times a_{i-2}^2+2xy\times a_{i-1} \times a_{i-2}+1 \times ans_{i-1}=ans_i\\ \end{cases}\]

可得

\[\begin{bmatrix} a_{i-1}^2 & a_{i-2}^2 & a_{i-1} \times a_{i-2} &ans_{i-1}\\ \end{bmatrix} \times \begin{bmatrix} x^2 & 1 & x & x^2\\ y^2 & 0 & 0 & y^2\\ 2xy & 0 & y & 2xy\\ 0 & 0 & 0 & 1\\ \end{bmatrix}=T\]

因为我们已知第一第二项,所以只需得到转移矩阵的 \(n-2\) 次方即可,此时要特判 \(n=1\)\(n=2\) 的情况,防止程序出错。

然后再套一下矩阵加速的板子就可以。

单次时间复杂度为 \(O(\log n \times m^3)\),其中 \(m\) 为矩阵边长,总时间复杂度为 \(O(T\log n \times m^3)\).

#include <cstdio>
#include <iostream>
#include <cstring>

using namespace std;

#define ll long long

const ll mod=1e9+7;

ll t,n,a1,a2,x,y;

struct matrix {
	int n,m;
	ll e[5][5];
	void clean() {
		for(int i=1;i<=n;i++) {
			for(int j=1;j<=m;j++) {
				e[i][j]=0;
			}
		}
	}
};

matrix mul(matrix a,matrix b) {
	matrix ans;
	ans.n=a.n;
	ans.m=b.m;
	ans.clean();
	for(int i=1;i<=a.n;i++) {
		for(int j=1;j<=b.m;j++) {
			for(int k=1;k<=a.m;k++) {
				ans.e[i][j]=(ans.e[i][j]+(a.e[i][k]*b.e[k][j])%mod)%mod;
			}
		}
	}
	return ans;
}

matrix quickly_pow(matrix a,ll k) {
	matrix ans;
	ans.n=a.n;
	ans.m=a.m;
	for(int i=1;i<=ans.n;i++) {
		for(int j=1;j<=ans.m;j++) {
			if(i==j) ans.e[i][j]=1;
			else ans.e[i][j]=0;
		}
	}
	while(k) {
		if(k&1) ans=mul(ans,a);
		a=mul(a,a);
		k>>=1;
	}
	return ans;
}

int main() {
	scanf("%lld",&t);
	while(t--) {
		scanf("%lld%lld%lld%lld%lld",&n,&a1,&a2,&x,&y);
		if(n==1) {//特判
			printf("%lld\n",a1*a1%mod);
			continue;
		} else if(n==2) {
			printf("%lld\n",(a1*a1%mod+a2*a2%mod)%mod);
			continue;
		}
		matrix init;
		init.n=4;init.m=4;
		init.clean();
		init.e[1][1]=x*x%mod;init.e[1][2]=1;init.e[1][3]=x;init.e[1][4]=x*x%mod;
		init.e[2][1]=y*y%mod;init.e[2][4]=y*y%mod;
		init.e[3][1]=2*x%mod*y%mod;init.e[3][3]=y;init.e[3][4]=2*x%mod*y%mod;
		init.e[4][4]=1;
		init=quickly_pow(init,n-2);
		matrix ans;
		ans.n=1;ans.m=4;
		ans.clean();
		ans.e[1][1]=a2*a2%mod;ans.e[1][2]=a1*a1%mod;ans.e[1][3]=a1*a2%mod;ans.e[1][4]=(a1*a1%mod+a2*a2%mod)%mod;
		ans=mul(ans,init);
		printf("%lld\n",ans.e[1][4]);
	}
	return 0;
}