DFS 剪枝

发布时间 2023-10-21 09:03:42作者: 单南松

DFS

剪枝

\(DFS\) 是一种常见的算法,大部分情况下,很少会爆搜为正解的题目。因为 \(DFS\) 的时间复杂度特别高。

我们可以先写一段 dfs 的伪代码

int ans = 最坏情况,  now;  // now 为当前答案

void dfs(传入数值) 
{
    if (到达目的地) 
    {
        ans = 从当前解与已有解中选最优;
    }
    for (遍历所有可能性)
    {
        if (可行) 
        {
            进行操作;
            dfs(缩小规模);
            撤回操作;
        }        
    }
}

记忆化搜索

因为在搜索中,相同的传入值往往会带来相同的解,那我们就可以用数组来记忆

int g[N];  // 定义记忆化数组
int ans = 最坏情况, now;

void dfs(传入数值) 
{
    if (g[规模] != 无效数值) return;  // 或记录解,视情况而定
    if (到达目的地) ans = 从当前解与已有解中选最优;  // 输出解,视情况而定
    for (遍历所有可能性)
    {
         if (可行) 
        {
            进行操作;
            dfs(缩小规模);
            撤回操作;
        }       
    }
}

int main() 
{
    // ...
    memset(g, 无效数值, sizeof(g));  // 初始化记忆化数组
    // ...
}

最优性剪枝

在搜索中导致运行慢的原因还有一种,就是在当前解已经比已有解差时仍然在搜索,那么我们只需要判断一下当前解是否已经差于已有解。

int ans = 最坏情况, now;

void dfs(传入数值) 
{
    if (now比ans的答案还要差) return;
    if (到达目的地) ans = 从当前解与已有解中选最优;
    for (遍历所有可能性)
    {
        if (可行) 
        {
            进行操作;
            dfs(缩小规模);
            撤回操作;
        }        
    }
}

可行性剪枝

在搜索过程中当前解已经不可用了还继续搜索下去也是运行慢的原因。

int ans = 最坏情况, now;

void dfs(传入数值) 
{
    if (当前解已不可用) return;
    if (到达目的地) ans = 从当前解与已有解中选最优;
    for (遍历所有可能性)
    {
        if (可行) 
        {
            进行操作;
            dfs(缩小规模);
            撤回操作;
        }
    }
}

剪枝思路

剪枝思路有很多种,大多需要对于具体问题来分析,在此简要介绍几种常见的剪枝思路。

  • 极端法:考虑极端情况,如果最极端(最理想)的情况都无法满足,那么肯定实际情况搜出来的结果不会更优了。

  • 调整法:通过对子树的比较剪掉重复子树和明显不是最有「前途」的子树。

  • 数学方法:比如在图论中借助连通分量,数论中借助模方程的分析,借助不等式的放缩来估计下界等等。

例题

[NOIP2002 普及组] 选数

题目传送门
此题并不需要剪枝,注释代码

#include <bits/stdc++.h>

#define int long long
#define rint register int
#define endl '\n'

using namespace std;

const int N = 2e1 + 5;

int n, k;
int a[N];
int ans;

bool isprime(int p)
{
    for (rint i = 2; i * i <= p; i++)
    {
        if (p % i == 0)
        {
            return false;
        }
    }
    return true;
}

void dfs(int cnt, int sum, int start)
//cnt 代表现在选择了多少个数
//sum 表示当前的和
//start 表示从第几个数开始
{
    if (cnt == k)
    {
        if (isprime(sum))
        {
            ans++;
        }
        return ;
    }
    
    for (rint i = start; i <= n; i++)
    {
        dfs(cnt + 1, sum + a[i], i + 1);
    }
}

signed main()
{
    cin >> n >> k;
    
    for (rint i = 1; i <= n; i++)
    {
        cin >> a[i];
    }
    
    dfs(0, 0 ,1);
    
    cout << ans << endl;
    
    return 0;
}

小木棍

题目传送门
非常经典的剪枝题。

  • 可行性剪枝:记录上一次失败的值, 如果这次还是的话, 那么肯定是不能选择的
  • 最优性剪枝:比 maxx 小的一定不是最优,所以从 maxx 开始搜索

其他剪枝见注释代码

#include <bits/stdc++.h>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e2 + 5;

int n;
int a[N];
bool v[N];
int len, cnt;
int sum;
int maxx;

bool dfs(int x, int now, int last)
// x 为当前第几段了
//now 记录当前段的长度
//last 为上一次已经选过的值
{
    if (x > cnt)//如果当前搜到的段数已经在 cnt 之后,说明可行
    {
        return 1;
    }
    
    if (now == len)//如果长度刚刚好,直接下一次搜索
    {
        return dfs(x + 1, 0, 1);
    }

    int fail = 0; 
    //记录上一次失败的值,如果这次还是的话,那么肯定是不能选择的
    for (rint i = last; i <= n; i++)
    {
        if (!v[i] && now + a[i] <= len && fail != a[i])
        {
            v[i] = 1;
            if (dfs(x, now + a[i], i + 1)) 
            {
                return 1;
            }
            fail = a[i];
            v[i] = 0;
            if (now == 0 || now + a[i] == len)
            {
                return 0;
            }
        }
    }

    return 0;
}

int main()
{
    cin >> n;
    
    for (rint i = 1; i <= n; i++)
    {
        cin >> a[i];
        sum += a[i];
        maxx = max(maxx, a[i]);
        //找出最大值和所有数的和
    }

    sort(a + 1, a + n + 1);
    reverse(a + 1, a + n + 1);//从大到小排序

    for (len = maxx; len < sum; len++)
    //比maxx小的一定不是最优,所以从maxx开始搜索
    {
        if (sum % len)
        //一定分不出来
        {
            continue;
        }
        cnt = sum / len;
        memset(v, 0, sizeof v);
        if (dfs(1, 0, 1))
        {
            break;
        }
    }

    cout << len << endl;

    return 0;
}

[NOI1999] 生日蛋糕

题目传送门

此题更多的是数学上的剪枝 + 贪心

#include <bits/stdc++.h>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;
const int inf = 1e9;

int n, m;
int ans = inf;
int mins[N], minv[N];

void dfs(int c, int v, int s, int h, int r) 
//c 层数, v 体积, s 抹奶油的外表面积 , r 半径, h 高度
{
    if(c == 0)
    {
    	if(v == n) 
    	{
    	    ans = min(ans ,s);
    	}
    	return;
	}
	
    if(v + minv[c] > n) return;
    /*
    当前的体积 > n    return
    */
    if(s + mins[c] > ans) return;
    if(s + 2 * (n - v) / r > ans) return;
    /*
    当前的奶油面积 + 之后的奶油面积 > 现在已求出的的最小奶油面积  return
    */
    
    for (rint i = r - 1; i >= c; i--)
    { 
    	if(c == m)
    	{
    	    s = i * i;
    	}
    	int Maxh = min(h - 1, (n - v - minv[c - 1]) / (i * i)); 
        for (rint j = Maxh; j >= c; j--)
        {
            dfs(c - 1, v + i * i * j, s + 2 * i * j, j, i);            
        }
    }
}

signed main() 
{
    cin >> n >> m;
    
	int MaxR = sqrt(n);
	
	for(int i = 1; i <= n; i++)
	{
		minv[i] = minv[i-1] + i * i * i;
		mins[i] = mins[i-1] + 2 * i * i;
	} 
	
	dfs(m, 0, 0, n, MaxR); 
	
    if(ans == inf)
    {
        ans = 0;
    }
    
    cout << ans << endl;
    
    return 0;
}