算法学习Day24回溯算法、基础组合

发布时间 2024-01-07 17:14:27作者: HQWQF

Day24回溯算法、基础组合

By HQWQF 2024/01/07

笔记


第77题. 组合

给定两个整数 nk,返回范围 [1, n] 中所有可能的 k 个数的组合。

你可以按 任何顺序 返回答案。

示例 1:

输入: n = 4, k = 2

输出:

[

  [2,4],

  [3,4],

  [2,3],

  [1,2],

  [1,3],

  [1,4],

]

回溯算法

我们可以用这题理解什么是回溯算法。

对于这个问题的示例1,比较容易可以想到这样的代码:

int n = 4;
for (int i = 1; i <= n; i++) {
    //j = i + 1避免组合重复
    for (int j = i + 1; j <= n; j++) {
        cout << i << " " << j << endl;
    }
}

但是如果我们改变一下数据,要求k=3,那么这里的for循环就需要3个了,以此类推k是多少就需要几个for循环。

可以发现这里涉及可变的代码长度,联想到使用递归来承载这些代码。

在本题嵌套循环中,对于每一个外层循环到的元素,我们对其展开下一层的循环,最深一层的循环执行一次就得出一个结果,最深一层的循环循环一遍后,上一层的再循环一次,又继续最深一层的循环,这些过程可以理解为一颗树

最深一层的循环循环一遍后,上一层的再循环一次,又继续最深一层的循环,这样回到上一层的过程就是回溯算法中回溯二字的由来

对于回溯算法,我们有一个代码模板:

void backtracking(参数) {
    if (终止条件) {
        //存放结果,相当于最深一层的循环循环一遍后;
        return;
    }

    for (//选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
        //处理节点;
        backtracking(路径,选择列表); // 递归
        //回溯,撤销处理结果,相当于回到上层循环
    }
}

对于本题,需要存放结果时是结果数组达到目标长度,所以终止条件是数组达到目标长度。

另外我们还需要在参数中实现迭代法中j = i + 1避免组合重复的操作,所以每深入一层递归都要用一个参数去传递类似的值。

回溯算法代码

class Solution {
private:
    vector<vector<int>> result; // 存放符合条件结果的集合
    vector<int> path; // 用来存放符合条件结果
    void backtracking(int n, int k, int startIndex) {
        if (path.size() == k) {
            result.push_back(path);
            return;
        }
        for (int i = startIndex; i <= n; i++) {
            path.push_back(i); // 处理节点
            backtracking(n, k, i + 1); // 递归
            path.pop_back(); // 回溯,撤销处理的节点
        }
    }
public:
    vector<vector<int>> combine(int n, int k) {
        result.clear(); // 可以不写
        path.clear();   // 可以不写
        backtracking(n, k, 1);
        return result;
    }
};

可以发现,其实回溯算法和暴力法没有什么区别,只是使用了递归实现以应对不同数据而已。

剪枝

在上面的代码中我们使用了startIndex作为避免组合重复的手段,但是在过程树中的某些分支,startIndex已经取得过大以至于这些分支继续下去到末端也不会满足path.size() == k的条件,也就是凑不满k个元素,对于这些分支我们可以剪掉以减少工作量。

k - path.size()就是我们这条支离拼购k个元素还差的元素数

n - i就是我们还能放的元素数

如果这条支继续下去最终能拼够k个元素,那么我们还能放的元素数就必须大于等于这条支离拼购k个元素还差的元素数。

然后在代码中这个等式左边要+1,因为在代码中, i的初始值是1而k - path.size()初始值是0,这两个并不是同步的。

所以有:

n - i + 1 >= k - path.size()
//换种写法
i <=n -(k - path.size()) + 1

剪枝代码

class Solution {
private:
    vector<vector<int>> result;
    vector<int> path;
    void backtracking(int n, int k, int startIndex) {
        if (path.size() == k) {
            result.push_back(path);
            return;
        }
        for (int i = startIndex; i <= n - (k - path.size()) + 1; i++) { // 优化的地方
            path.push_back(i); // 处理节点
            backtracking(n, k, i + 1);
            path.pop_back(); // 回溯,撤销处理的节点
        }
    }
public:

    vector<vector<int>> combine(int n, int k) {
        backtracking(n, k, 1);
        return result;
    }
};