【题解】NOIP2021 - 方差

发布时间 2023-11-05 21:36:49作者: KiharaTouma

NOIP2021 - 方差

https://www.luogu.com.cn/problem/P7962

想当年我第一次站在 noip 赛场上,过了 T1 剩下三题就一题不会了……幸好这题拿了点分水了个一等。


观察操作:若对于连续的三个数 \(a,b,c\),对 \(b\) 进行一次操作后就变成了 \(a,a+c-b,c\)。求出两个数组的差分数组:\(b-a,c-b\)\(c-b,b-a\)。可以发现一次操作相当于差分数组中的邻项交换。

然后又因为将 \(a\) 数组全体加减一个数方差 \(D\) 是不变的……所以我们可以求出长度为 \(n-1\) 的差分数组 \(d\),然后考虑 \(d\) 数组。

我们对方差进行一个推式子,最后得到了

\[n^2D=\sum_{i=1}^{n-1}id_i*(n-i)d_i+2\sum_{i=1}^{n-1}\sum_{j=i+1}^{n-1}id_i*(n-j)d_j \]

感性理解,要使这个最小,那就要将大的 \(d\) 与偏远部分的下标对应,故最优解的差分数组应该是单谷的。

考虑先对 \(d\) 数组从小到大排序,然后 dp。这个式子不好去 dp,回归本源:

\[n^2D=n\sum a_i^2-(\sum a_i)^2 \]

\(f_i\) 表示考虑到前 \(i\)\(d\) 时后面那串的最小值?发现状态中信息不太够。把当前 \(x=\sum a_i\) 放到状态里就可以了。\(f_{i,x}\) 表示前 \(i\)\(d\),构成的 \(a\) 数组和为 \(x\) 时最小的 \(\sum a^2\)

易得转移。\(d_i\) 放在右边(\(s=\sum_{j=1}^i d_j\)):

\[f_{i-1,x} + s^2 \to f_{i,x+s} \]

放在左边:

\[f_{i-1,x} + 2xd_i + id_i^2 \to f_{i,x+id_i} \]

复杂度 \(O(n^2a)\),观察到当 \(n>a\) 时差分数组会有大量 \(0\),而这些 \(0\) 一定出现在排序后的 \(d\) 数组开头,那么这个时候 \(s\) 也为 \(0\),不管是放在哪边都不会产生转移。所以跳过就行。注意这里跳过只是跳过 \(d_i\),而 \(i\) 还是要加的。那么现在复杂度就是 \(O(na^2)\),注意实现,数组要滚动,不要用 memset

//P7962
#include <bits/stdc++.h>
using namespace std;

const int N = 1e4 + 10, M = 610;
int n, a[N], m, mx;
typedef long long ll;
ll f[2][N*M], d[N], s[N];

int main(){
	scanf("%d", &n);
	for(int i = 1; i <= n; ++ i){
		scanf("%d", &a[i]);
	}
	for(int i = 2; i <= n; ++ i){
		d[i-1] = a[i] - a[i-1];
	}
	sort(d + 1, d + n);
	for(int i = 1; i < n; ++ i){
		s[i] = s[i-1] + d[i];
	}
	memset(f, 0x3f, sizeof(f));
	f[0][0] = 0;
	int p = 0;
	for(int i = 1; i < n; ++ i){
		if(!d[i]){
			continue;
		}
		p ^= 1;
		for(int j = 0; j <= mx; ++ j){
			if(f[p^1][j] == 0x3f3f3f3f3f3f3f3f){
				continue;
			}
			f[p][j+s[i]] = min(f[p][j+s[i]], f[p^1][j] + s[i] * s[i]);
			f[p][j+i*d[i]] = min(f[p][j+i*d[i]], f[p^1][j] + i * d[i] * d[i] + 2 * j * d[i]);
			mx = max(mx, j + (int)max(s[i], i * d[i]));
			f[p^1][j] = 0x3f3f3f3f3f3f3f3f;
		}
	}
	ll ans = 9e18;
	for(int j = 0; j <= mx; ++ j){
		if(f[p][j] < 0x3f3f3f3f3f3f3f3f){
			ans = min(ans, f[p][j] * n - (ll)j * j);
		}
	}
	printf("%lld\n", ans);
	return 0;
}