P3959 [NOIP2017 提高组] 宝藏 题解

发布时间 2023-12-19 11:54:04作者: Creeper_l

原题链接:P3959

乍一看,感觉像是一道图论的最短路这类的题,但是细想发现用图论似乎不可做。再看到这道题的数据范围 \(n<=12\),立马就可以想到用状压 \(DP\),因为数据范围很状压/。

思路

  • 设计状态

首先来考虑状态的设计。如果按状压 \(DP\) 的套路来设的话,设 \(dp_{i,j}\) 表示选到第 \(i\) 个点,状态为 \(j\) 时的最小代价。但是这个状态有一些问题,因为每个点对应的 \(k\) 的值是会随着 \(i\) 的变化而变化的,这就使得状态不好转移。所以我们可以换一种状态:设 \(dp_{i,j}\) 表示选到第 \(i\) 层,状态为 \(j\) 时的最小代价。这里的“层”是指:将答案的生成树按照深度分成几层,根据定义,这样第 \(i\) 层的 \(k\) 的值就为 \(i-1\)。这样就方便我们进行转移了。

  • 预处理

在转移之前,我们还需要进行一些预处理。设 \(dis_{i,j}\) 表示点 \(i\) 到点 \(j\) 的距离,这可以在输入的时候直接预处理好。接下来设 \(d_{i,j}\) 表示点 \(i\) 到状态 \(j\) 中选了的点(二进制位为 \(1\))的距离中最近的距离是多少。我们可以这样预处理 \(d\) 数组(\(i\)\(j\) 中没选的点,\(k\)\(j\) 中选了的点):

for(int i = 1;i <= n;i++)
	for(int j = 1;j < (1 << n);j++)
		for(int k = 1;k <= n;k++)
			if(((1 << k - 1) & j) && !((1 << i - 1) & j))
				d[i][j] = min(d[i][j],dis[i][k]);

还需要将 \(dp_{1,i}\) 赋值为 \(0\),因为 \(k=0\),其余赋极大值。

  • 方程转移

方程的转移其实就比较简单了。

\(dp_{k,i}=min(dp_{k,i},dp_{k-1,j}+(k-1)\times\sum d_{p,j})\)

其中状态 \(j\) 是状态 \(i\) 的子集,\(p\) 则是状态 \(i\) 中有但是状态 \(j\) 中没有的点。

但是在转移过程中要要注意枚举子集的方式(如下),不然会 \(TLE\)

for(int j = i & (i - 1);j;j = (j - 1) & i)

普通枚举子集的时间复杂度是 \(O(4^{n})\) 的,而优化过后是 \(O(3^{n})\)。总时间复杂度 \(O(n \times 3^{n})\)

  • 计算答案

答案则是最终状态为 \((1<<n)-1\) 中代价最小的。

for(int i = 2;i <= n;i++) ans = min(ans,dp[i][(1 << n) - 1]);

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf 0x3f
#define inf_db 127
#define ls id << 1
#define rs id << 1 | 1
#define re register
typedef pair <int,int> pii;
const int MAXN = 13;
const int MAXM = (1 << 13) + 10;
int n,m,x,y,z,dis[MAXN][MAXN],dp[MAXN][MAXM],d[MAXN][MAXM],ans = 0x3f3f3f3f3f3f;
signed main()
{
	cin >> n >> m;
	memset(dp,inf,sizeof dp);
	memset(dis,inf,sizeof dis);
	memset(d,inf,sizeof d);
	for(int i = 1;i <= n;i++) dp[1][1 << i - 1] = 0; 
	for(int i = 1;i <= m;i++)
	{
		cin >> x >> y >> z;
		dis[x][y] = dis[y][x] = min(dis[x][y],z);
	}
	for(int i = 1;i <= n;i++)
		for(int j = 1;j < (1 << n);j++)
			for(int k = 1;k <= n;k++)
				if(((1 << k - 1) & j) && !((1 << i - 1) & j))
					d[i][j] = min(d[i][j],dis[i][k]);	
	for(int i = 1;i < (1 << n);i++)
		for(int j = i & (i - 1);j;j = (j - 1) & i)
		{
			int sum = 0;
			for(int k = 1;k <= n;k++)
				if((1 << k - 1) & (i - j))
				{
					if(d[k][j] > ans) sum = 0x3f3f3f3f;	
					else sum += d[k][j];
				} 
			for(int k = 2;k <= n;k++) dp[k][i] = min(dp[k][i],dp[k - 1][j] + sum * (k - 1));
		} 
	for(int i = 2;i <= n;i++) ans = min(ans,dp[i][(1 << n) - 1]);
	if(ans == 0x3f3f3f3f3f3f) ans = 0;
	cout << ans;
	return 0;
}