P9032 [COCI2022-2023#1] Neboderi

发布时间 2023-11-09 09:43:41作者: LHLeisus

原题链接

最暴力的想法是枚举每一个区间进行计算,这样子的复杂度是 \(O(n^3)\),可以用前缀和以及 ST 表优化到 \(O(n^2\log n)\)

满分做法:

注:以下的 \(\gcd(l,r)\) 均指 \([l,r]\) 这个区间所有数的 \(\gcd\)

对于一个长度为 \(V\) 的区间,\(\gcd(l,l),\gcd(l,l+1),\gcd(l,l+2),\dots,\gcd(l,r)\) 至多只有 \(\log h_l\) 种不同的取值,并且是递减的。证明比较简单,一个数 \(x\) 和另一个数求 \(\gcd\),结果一定是 \(\le x\),如果小于 \(x\),也就是说少了一个公因子,这个公因子最小是 \(2\),故极限的情况是每次都除以 \(2\),直到 \(\gcd=1\),这样是至多 \(\log h_l\) 种取值。

也就是说,每个区间会被分成 \(\log\) 个子区间,对于每个子区间,\(\forall i\in\) 子区间,\(\gcd(h_l,h_i)\) 都是一样的,而为了让区间和最大,一定取右端点。我们只需要记录每个子区间的右端点和当前子区间对应的 \(\gcd\) 取值即可。

对于每个点 \(i\),我们都需要求出 \(i\)\(n\) 的断点,显然对于每个 \(i\) 的断点集合都是不一样的,直接求是 \(O(n^2\log n)\) 的(求 \(\gcd\) 还有 \(O(\log n)\))。但是总有一些规律,比如,对于 \(i+1\) 断点集合里的断点,他们在 \(i\) 中可能成为断点,但是不在 \(i+1\) 中的一定不会成为断点。于是我们可以倒序计算,将断点和对应的 \(\gcd\) 作为二元组存进 vector,每一次在 \(vec[i+1]\) 中加入 \(i\),并且将所有断点的 \(\gcd\)\(h_i\) 求一次 \(\gcd\),这时候会出现一些相同的,我们保留最靠右的。由于最多只会有 \(\log\) 个断点,还要求 \(\gcd\) 这样的复杂度是 \(O(n\log^2n)\)。查询时直接遍历断点集合,取区间长度 \(\ge k\) 的进行计算即可,复杂度 \(O(n\log n)\)。最终复杂度为 \(O(n\log^2n)\)

code:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<utility>
#include<vector>
#include<queue>
#include<bitset>
#include<map>
#define int long long
#define FOR(i,a,b) for(register int i=a;i<=b;i++)
#define ROF(i,a,b) for(register int i=a;i>=b;i--)
#define mp(a,b) make_pair(a,b)
#define pll pair<long long,long long>
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
inline int read();
typedef long long ll;
typedef double db;
const int N=1e6+5;
const int INF=0x3f3f3f3f;
int n,m,k;
int h[N],s[N];
int gcd(int a,int b){return b?gcd(b,a%b):a;}
int calc(int x,int y){return s[y]-s[x-1];}
vector<pii>vec[N];
int ans=0;
signed main()
{
   n=read(),k=read();
   FOR(i,1,n) h[i]=read(),s[i]=s[i-1]+h[i];
   vec[n].push_back(mp(n,h[n]));
   vector<pii>temp;
   ROF(i,n-1,1){
   	temp.clear();
   	temp.push_back(mp(i,h[i]));
   	for(auto v:vec[i+1]) temp.push_back(v);
   	int las=INF;
   	ROF(j,temp.size()-1,0){
   		auto &v=temp[j];
   		v.se=gcd(v.se,h[i]);
   		if(v.se==las) v.fi=INF;
   		else las=v.se;
   	}
   	for(auto v:temp) if(v.fi!=INF) vec[i].push_back(v);
   }
   int ans=0;
   FOR(i,1,n){
   	for(auto v:vec[i]){
   		if(v.fi-i+1>=k)ans=max(ans,v.se*calc(i,v.fi));
   	}
   }
   printf("%lld",ans);
   return 0;
}


inline int read()
{
   int x=0,f=1;char ch=getchar();
   while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
   while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
   return f*x;
}