P5175 题解

发布时间 2023-07-09 18:35:38作者: Hadtsti

题意简述

给出数列 \({a_n}(1\le n\le10^{18})\) 的两项 \(a_1,a_2\) 与递推公式 \(a_n=xa_{n-1}+ya_{n-2}\),求:

\[S_n=\sum_{k=1}^{n}a_k^2\mod (10^9+7) \]

题目分析

一看见 \(1\le n\le 10^{18}\),就大概能知道要用 \(O(\log n)\) 级别的算法。再一看递推,就知道要用矩阵快速幂了。

题目所求的答案是前 \(n\) 项的 \(a_k^2\) 之和,如果我们能快速求出每个 \(a_k^2\),那么再用矩阵快速幂递推求和是十分简单的。

注意到 \(a_{k+1}=xa_k+ya_{k-1}\),那么就有 \(a_{k+1}^2=x^2a_k^2+y^2a_{k-1}^2+2xya_{k}a_{k-1}\),如果我们维护了 \(a_k^2\)\(a_{k-1}^2\),就可以递推地求出 \(a_{k+1}^2,a_{k+2}^2\) 等等。

还没完,注意到 \(a_{k+1}\) 展开后还有一项 \(2xya_{k}a_{k-1}\),去掉系数就还剩下 \(a_{k}a_{k-1}\)。维护它应该怎么处理呢?将 \(a_k\) 拆开得到 \(a_{k}a_{k-1}=(xa_{k-1}+ya_{k-2})a_{k-1}=xa_{k-1}^2+ya_{k-1}a_{k-2}\),其中 \(a_{k-1}^2\) 已经维护了,所剩的只有 \(ya_{k-1}a_{k-2}\)。这就意味着我们维护每一个 \(a_{k}a_{k-1}\) 就可以求出 \(a_{k+1}a_k\)\(a_{k+2}a_{k+1}\) 等等。

综上,我们使用答案矩阵
\(\begin{bmatrix} S_{k} & a_k^2 & a_{k-1}^2 & a_ka_{k-1} \end{bmatrix}\) 进行维护。初始值为
\(\begin{bmatrix} S_{2}=a_1^2+a_2^2 & a_2^2 & a_1^2 & a_2a_1 \end{bmatrix}\)

为了将其递推到 \(\begin{bmatrix} S_{k+1} & a_{k+1}^2 & a_{k}^2 & a_{k+1}a_k \end{bmatrix}\),需要有一个转移矩阵:

\[\begin{bmatrix} 1 & 0 & 0 & 0\\ x^2 & x^2 & 1 & x\\ y^2 & y^2 & 0 & 0\\ 2xy & 2xy & 0 & y \end{bmatrix}\]

我们一列一列地进行解释。

对于第 \(2\) 列,有 \(a_{k+1}^2=0S_k+x^2a_k^2+y^2a_{k-1}^2+2xya_ka_{k-1}\)

那么对于第 \(1\) 列,有 \(S_{k+1}=S_k+a_{k+1}^2=1S_k+x^2a_k^2+y^2a_{k-1}^2+2xya_ka_{k-1}\)

对于第 \(3\) 列,有 \(a_k^2=0S_k+1a_k^2+0a_{k-1}^2+0a_ka_{k-1}\),没什么可解释的。

对于第 \(4\) 列,有 \(a_{k+1}a_k=0S_k+xa_k^2+0a_{k-1}^2+ya_ka_{k-1}\)

现在我们只需要计算出转移矩阵的 \(n-2\) 次幂,再乘到答案矩阵上就好了。

需要注意的是要特判 \(n=1\),否则会挂掉。然后就是该开 long long 的地方一定要开,不然也会挂。

分析一下时间复杂度。一次矩阵乘法是 \(O(4^3)=O(64)\),因此矩阵快速幂的时间复杂度为 \(O(64\log n)\),总复杂度为 \(O(64T\log n)\),在 \(T\le 3\times 10^4,1\le n\le 10^{18}\),时限 \(1.50s\) 的情况下用个快读快写就可以过了。

代码实现

#include<bits/stdc++.h>
using namespace std;
const long long mod=1000000007;
int t;
long long n,x,y,a1,a2;
void rd(long long &x)
{
	x=0ll;
	char c=getchar();
	for(;c>'9'||c<'0';c=getchar());
	for(;c<='9'&&c>='0';c=getchar())
		x=(x<<3ll)+(x<<1ll)+c-'0';
}//long long快读 
void rd(int &x)
{
	x=0;
	char c=getchar();
	for(;c>'9'||c<'0';c=getchar());
	for(;c<='9'&&c>='0';c=getchar())
		x=(x<<3)+(x<<1)+c-'0';
}//int快读 
void pt(long long x)
{
	if(x>=10ll)
		pt(x/10ll);
	putchar(x%10ll+'0');
}//快写 
struct matrix
{
	int n,m;
	long long a[5][5];
	void init(long long k)
	{
		for(int i=1;i<=n;i++)
			for(int j=1;j<=m;j++)
				if(i!=j)
					a[i][j]=0ll;
				else
					a[i][j]=k;
	}//初始化 k=0零矩阵,k=1单位矩阵 
	friend matrix operator *(matrix a,matrix b)
	{
		matrix c;
		c.n=a.n,c.m=a.m;
		c.init(0ll);
		for(int i=1;i<=a.n;i++)
			for(int k=1;k<=a.m;k++)
				for(int j=1;j<=b.m;j++)
					c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]%mod)%mod;
		return c;
	}//矩阵乘法 
	friend matrix operator ^(matrix a,long long b)
	{
		matrix res;
		res.n=res.m=a.n;
		res.init(1ll);
		for(;b;b>>=1ll)
		{
			if(b&1ll)
				res=res*a;
			a=a*a;
		}
		return res;
	}//矩阵快速幂 
}ans,trans;
int main()
{
	rd(t);
	while(t--)
	{
		rd(n),rd(a1),rd(a2),rd(x),rd(y);
		if(n==1ll)//特判 n=1,不然会 T 得很惨
		{
			pt(a1*a1%mod);
			putchar('\n');
			continue;
		}
		ans.n=1,ans.m=4;
		ans.a[1][1]=(a2*a2%mod+a1*a1%mod)%mod,ans.a[1][2]=a2*a2%mod,ans.a[1][3]=a1*a1%mod,ans.a[1][4]=a1*a2%mod;//初始化答案矩阵 
		trans.n=trans.m=4;
		trans.init(0ll);
		trans.a[2][1]=trans.a[2][2]=x*x%mod,trans.a[3][1]=trans.a[3][2]=y*y%mod,trans.a[4][1]=trans.a[4][2]=2ll*x*y%mod;
		trans.a[1][1]=trans.a[2][3]=1ll;
		trans.a[2][4]=x,trans.a[4][4]=y;//初始化转移矩阵 
		trans=trans^(n-2ll);//转移矩阵快速幂 
		ans=ans*trans;//答案矩阵乘上转移矩阵 
		pt(ans.a[1][1]);
		putchar('\n');
	}
	return 0;
}