原题链接:P3959
乍一看,感觉像是一道图论的最短路这类的题,但是细想发现用图论似乎不可做。再看到这道题的数据范围 \(n<=12\),立马就可以想到用状压 \(DP\),因为数据范围很状压/。
思路
- 设计状态
首先来考虑状态的设计。如果按状压 \(DP\) 的套路来设的话,设 \(dp_{i,j}\) 表示选到第 \(i\) 个点,状态为 \(j\) 时的最小代价。但是这个状态有一些问题,因为每个点对应的 \(k\) 的值是会随着 \(i\) 的变化而变化的,这就使得状态不好转移。所以我们可以换一种状态:设 \(dp_{i,j}\) 表示选到第 \(i\) 层,状态为 \(j\) 时的最小代价。这里的“层”是指:将答案的生成树按照深度分成几层,根据定义,这样第 \(i\) 层的 \(k\) 的值就为 \(i-1\)。这样就方便我们进行转移了。
- 预处理
在转移之前,我们还需要进行一些预处理。设 \(dis_{i,j}\) 表示点 \(i\) 到点 \(j\) 的距离,这可以在输入的时候直接预处理好。接下来设 \(d_{i,j}\) 表示点 \(i\) 到状态 \(j\) 中选了的点(二进制位为 \(1\))的距离中最近的距离是多少。我们可以这样预处理 \(d\) 数组(\(i\) 是 \(j\) 中没选的点,\(k\) 是 \(j\) 中选了的点):
for(int i = 1;i <= n;i++)
for(int j = 1;j < (1 << n);j++)
for(int k = 1;k <= n;k++)
if(((1 << k - 1) & j) && !((1 << i - 1) & j))
d[i][j] = min(d[i][j],dis[i][k]);
还需要将 \(dp_{1,i}\) 赋值为 \(0\),因为 \(k=0\),其余赋极大值。
- 方程转移
方程的转移其实就比较简单了。
\(dp_{k,i}=min(dp_{k,i},dp_{k-1,j}+(k-1)\times\sum d_{p,j})\)
其中状态 \(j\) 是状态 \(i\) 的子集,\(p\) 则是状态 \(i\) 中有但是状态 \(j\) 中没有的点。
但是在转移过程中要要注意枚举子集的方式(如下),不然会 \(TLE\)。
for(int j = i & (i - 1);j;j = (j - 1) & i)
普通枚举子集的时间复杂度是 \(O(4^{n})\) 的,而优化过后是 \(O(3^{n})\)。总时间复杂度 \(O(n \times 3^{n})\)。
- 计算答案
答案则是最终状态为 \((1<<n)-1\) 中代价最小的。
for(int i = 2;i <= n;i++) ans = min(ans,dp[i][(1 << n) - 1]);
Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf 0x3f
#define inf_db 127
#define ls id << 1
#define rs id << 1 | 1
#define re register
typedef pair <int,int> pii;
const int MAXN = 13;
const int MAXM = (1 << 13) + 10;
int n,m,x,y,z,dis[MAXN][MAXN],dp[MAXN][MAXM],d[MAXN][MAXM],ans = 0x3f3f3f3f3f3f;
signed main()
{
cin >> n >> m;
memset(dp,inf,sizeof dp);
memset(dis,inf,sizeof dis);
memset(d,inf,sizeof d);
for(int i = 1;i <= n;i++) dp[1][1 << i - 1] = 0;
for(int i = 1;i <= m;i++)
{
cin >> x >> y >> z;
dis[x][y] = dis[y][x] = min(dis[x][y],z);
}
for(int i = 1;i <= n;i++)
for(int j = 1;j < (1 << n);j++)
for(int k = 1;k <= n;k++)
if(((1 << k - 1) & j) && !((1 << i - 1) & j))
d[i][j] = min(d[i][j],dis[i][k]);
for(int i = 1;i < (1 << n);i++)
for(int j = i & (i - 1);j;j = (j - 1) & i)
{
int sum = 0;
for(int k = 1;k <= n;k++)
if((1 << k - 1) & (i - j))
{
if(d[k][j] > ans) sum = 0x3f3f3f3f;
else sum += d[k][j];
}
for(int k = 2;k <= n;k++) dp[k][i] = min(dp[k][i],dp[k - 1][j] + sum * (k - 1));
}
for(int i = 2;i <= n;i++) ans = min(ans,dp[i][(1 << n) - 1]);
if(ans == 0x3f3f3f3f3f3f) ans = 0;
cout << ans;
return 0;
}