Day24回溯算法、基础组合
By HQWQF 2024/01/07
笔记
第77题. 组合
给定两个整数 n
和 k
,返回范围 [1, n]
中所有可能的 k
个数的组合。
你可以按 任何顺序 返回答案。
示例 1:
输入: n = 4, k = 2
输出:
[
[2,4],
[3,4],
[2,3],
[1,2],
[1,3],
[1,4],
]
回溯算法
我们可以用这题理解什么是回溯算法。
对于这个问题的示例1,比较容易可以想到这样的代码:
int n = 4;
for (int i = 1; i <= n; i++) {
//j = i + 1避免组合重复
for (int j = i + 1; j <= n; j++) {
cout << i << " " << j << endl;
}
}
但是如果我们改变一下数据,要求k=3,那么这里的for循环就需要3个了,以此类推k是多少就需要几个for循环。
可以发现这里涉及可变的代码长度,联想到使用递归来承载这些代码。
在本题嵌套循环中,对于每一个外层循环到的元素,我们对其展开下一层的循环,最深一层的循环执行一次就得出一个结果,最深一层的循环循环一遍后,上一层的再循环一次,又继续最深一层的循环,这些过程可以理解为一颗树。
最深一层的循环循环一遍后,上一层的再循环一次,又继续最深一层的循环,这样回到上一层的过程就是回溯算法中回溯二字的由来。
对于回溯算法,我们有一个代码模板:
void backtracking(参数) {
if (终止条件) {
//存放结果,相当于最深一层的循环循环一遍后;
return;
}
for (//选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
//处理节点;
backtracking(路径,选择列表); // 递归
//回溯,撤销处理结果,相当于回到上层循环
}
}
对于本题,需要存放结果时是结果数组达到目标长度,所以终止条件是数组达到目标长度。
另外我们还需要在参数中实现迭代法中j = i + 1避免组合重复的操作,所以每深入一层递归都要用一个参数去传递类似的值。
回溯算法代码
class Solution {
private:
vector<vector<int>> result; // 存放符合条件结果的集合
vector<int> path; // 用来存放符合条件结果
void backtracking(int n, int k, int startIndex) {
if (path.size() == k) {
result.push_back(path);
return;
}
for (int i = startIndex; i <= n; i++) {
path.push_back(i); // 处理节点
backtracking(n, k, i + 1); // 递归
path.pop_back(); // 回溯,撤销处理的节点
}
}
public:
vector<vector<int>> combine(int n, int k) {
result.clear(); // 可以不写
path.clear(); // 可以不写
backtracking(n, k, 1);
return result;
}
};
可以发现,其实回溯算法和暴力法没有什么区别,只是使用了递归实现以应对不同数据而已。
剪枝
在上面的代码中我们使用了startIndex作为避免组合重复的手段,但是在过程树中的某些分支,startIndex已经取得过大以至于这些分支继续下去到末端也不会满足path.size() == k的条件,也就是凑不满k个元素,对于这些分支我们可以剪掉以减少工作量。
k - path.size()就是我们这条支离拼购k个元素还差的元素数
n - i就是我们还能放的元素数
如果这条支继续下去最终能拼够k个元素,那么我们还能放的元素数就必须大于等于这条支离拼购k个元素还差的元素数。
然后在代码中这个等式左边要+1,因为在代码中, i的初始值是1而k - path.size()初始值是0,这两个并不是同步的。
所以有:
n - i + 1 >= k - path.size()
//换种写法
i <=n -(k - path.size()) + 1
剪枝代码
class Solution {
private:
vector<vector<int>> result;
vector<int> path;
void backtracking(int n, int k, int startIndex) {
if (path.size() == k) {
result.push_back(path);
return;
}
for (int i = startIndex; i <= n - (k - path.size()) + 1; i++) { // 优化的地方
path.push_back(i); // 处理节点
backtracking(n, k, i + 1);
path.pop_back(); // 回溯,撤销处理的节点
}
}
public:
vector<vector<int>> combine(int n, int k) {
backtracking(n, k, 1);
return result;
}
};