嵌套类 PairCompare 以及一些类型别名(using 声明)
模板参数命名参考 pb_ds 库,当第二个模板参数不为 NullType 时,结点值域类型为 std::pair<const Key, Mapped>,否则为 const Key。
public:
  // Orders map elements by key alone. The two mixed overloads let a bare Key
  // be compared against an element, which the key-based lookups rely on.
  class PairCompare {
   public:
    PairCompare() : key_compare_(KeyCompare()) {}

    // element vs element
    bool operator()(const std::pair<const Key, Mapped>& lhs,
                    const std::pair<const Key, Mapped>& rhs) const {
      return key_compare_(lhs.first, rhs.first);
    }

    // key vs element
    bool operator()(const Key& lhs,
                    const std::pair<const Key, Mapped>& rhs) const {
      return key_compare_(lhs, rhs.first);
    }

    // element vs key
    bool operator()(const std::pair<const Key, Mapped>& lhs,
                    const Key& rhs) const {
      return key_compare_(lhs.first, rhs);
    }

   protected:
    KeyCompare key_compare_;
  };
public:
  // Heights are stored as int (signed) so that the balance factor,
  // height(left) - height(right), can be negative.
  using HgtType = int;
  // Subtree node counts.
  using SizeType = unsigned int;
  // Element type: `const Key` when Mapped is NullType (set-like tree),
  // otherwise `std::pair<const Key, Mapped>` (map-like tree).
  using ValueType = typename std::conditional<
      std::is_same<NullType, Mapped>::value,
      const Key,
      std::pair<const Key, Mapped>>::type;
  // Comparator matching ValueType: the raw KeyCompare for sets,
  // PairCompare (compare by key only) for maps.
  using ValueCompare = typename std::conditional<
      std::is_same<NullType, Mapped>::value,
      KeyCompare,
      PairCompare>::type;
嵌套类 NodeBase 及其派生类 Node
protected:
struct NodeBase { // 结点类的基类, 没有值域, 值域在其派生类 Node 内
NodeBase(HgtType hgt, SizeType sz, NodeBase* p, NodeBase* l, NodeBase* r)
: height(hgt), size(sz), parent(p), left(l), right(r) {}
HgtType height;
SizeType size;
NodeBase* parent; // 参考 STL 红黑树源码,借助双亲结点做到所有操作无递归
NodeBase* left;
NodeBase* right;
};
struct Node : public NodeBase {
Node(HgtType hgt, SizeType sz, NodeBase* p,
NodeBase* l, NodeBase* r, const ValueType& val)
: NodeBase(hgt, sz, p, l, r), value(val) {}
ValueType value;
};
成员变量
protected:
// Sentinel node, a trick borrowed from the STL: while a root exists, header_
// and the root are each other's parent. header_->left / header_->right point
// at the first / last element (see begin() and last()).
NodeBase* header_;
// Element comparator: KeyCompare for set-like trees, PairCompare otherwise.
ValueCompare value_compare_;
静态函数,由其他函数及迭代器的成员函数调用
在静态函数中,判断指针 x 指向 header 的条件是 x->size == 0U:header 的 size 恒为 0,而真实节点的子树至少包含它自身,size 至少为 1。
寻找中序遍历顺序中的上一个节点或下一个节点
protected:
// In-order predecessor of x. The header (the only node with size 0) acts as
// one-past-the-last, so prev_node(header) yields header->right (the last
// element, or the header itself when the tree is empty). Stepping back from
// the first element lands on the header.
static NodeBase* prev_node(NodeBase* x) {
if (x->size == 0U) {
x = x->right;
} else if (x->left != nullptr) {
// Predecessor is the rightmost node of the left subtree.
x = x->left;
while (x->right != nullptr) x = x->right;
} else {
// Climb while x is a left child; the ancestor reached from its right
// side is the predecessor.
while (x == x->parent->left) x = x->parent;
// STL-style guard: if the climb already ended on the header (root is the
// first element, so header->left == header->parent), stay there instead
// of moving on to the root.
if (x->left != x->parent) x = x->parent;
}
return x;
}
// In-order successor of x; mirror image of prev_node. next_node(header)
// yields header->left (the first element), and stepping forward from the
// last element lands on the header (end()).
static NodeBase* next_node(NodeBase* x) {
if (x->size == 0U) {
x = x->left;
} else if (x->right != nullptr) {
// Successor is the leftmost node of the right subtree.
x = x->right;
while (x->left != nullptr) x = x->left;
} else {
// Climb while x is a right child.
while (x == x->parent->right) x = x->parent;
// Same header guard as in prev_node, with left/right swapped.
if (x->right != x->parent) x = x->parent;
}
return x;
}
平衡因子
static HgtType factor(const NodeBase* const x) {
  // Balance factor = height(left) - height(right); a missing child counts as 0.
  const HgtType left_h = (x->left != nullptr) ? x->left->height : 0;
  const HgtType right_h = (x->right != nullptr) ? x->right->height : 0;
  return left_h - right_h;
}
更新节点的高度及子树大小
static void update_info(NodeBase* const x) {
  // Recompute x's cached height and subtree size from its children,
  // assuming the children's own values are already up to date.
  HgtType tallest_child = 0;
  SizeType count = 1U;  // x itself
  if (x->left != nullptr) {
    tallest_child = x->left->height;
    count += x->left->size;
  }
  if (x->right != nullptr) {
    tallest_child = std::max(tallest_child, x->right->height);
    count += x->right->size;
  }
  x->height = tallest_child + 1;
  x->size = count;
}
用于处理 LL 型失衡的右旋以及用于处理 RR 型失衡的左旋
// Single right rotation around x (repairs an LL imbalance): x's left child y
// becomes the subtree root and x becomes y's right child; y's old right
// subtree moves under x as its left subtree. The link from x's parent is
// redirected to y; if that parent is the header (size 0), the header's
// parent (root) pointer is updated instead. x is refreshed before y because
// y's new height/size depend on x's. Returns the new subtree root y.
static NodeBase* rotate_right(NodeBase* const x) {
NodeBase* const y = x->left;
x->left = y->right;
if (y->right != nullptr) y->right->parent = x;
y->parent = x->parent;
if (x->parent->size == 0U) x->parent->parent = y;
else if (x == x->parent->right) x->parent->right = y;
else x->parent->left = y;
y->right = x;
x->parent = y;
update_info(x), update_info(y);
return y;
}
// Single left rotation around x (repairs an RR imbalance); mirror image of
// rotate_right. Returns the new subtree root (x's former right child).
static NodeBase* rotate_left(NodeBase* const x) {
NodeBase* const y = x->right;
x->right = y->left;
if (y->left != nullptr) y->left->parent = x;
y->parent = x->parent;
if (x->parent->size == 0U) x->parent->parent = y;
else if (x == x->parent->left) x->parent->left = y;
else x->parent->right = y;
y->left = x;
x->parent = y;
update_info(x), update_info(y);
return y;
}
双旋
先将左子节点左旋,再将该节点右旋,处理 LR 型失衡
先将右子节点右旋,再将该节点左旋,处理 RL 型失衡
这里我特意将双旋操作都整理在一个函数里
// Left-right double rotation (repairs an LR imbalance), done in one pass:
// with y = x->left and z = y->right, z becomes the subtree root with y as its
// left child and x as its right child. z's old left subtree goes to y->right
// and its old right subtree goes to x->left. The parent link (or the header's
// root pointer) is redirected to z, and x, y, then z are refreshed.
// Returns the new subtree root z.
static NodeBase* rotate_left_then_right(NodeBase* const x) {
NodeBase* y = x->left;
NodeBase* z = y->right;
x->left = z->right;
y->right = z->left;
if (z->right != nullptr) z->right->parent = x;
if (z->left != nullptr) z->left->parent = y;
z->parent = x->parent;
if (x->parent->size == 0U) x->parent->parent = z;
else if (x == x->parent->right) x->parent->right = z;
else x->parent->left = z;
z->right = x;
z->left = y;
x->parent = y->parent = z;
update_info(x), update_info(y), update_info(z);
return z;
}
// Right-left double rotation (repairs an RL imbalance); mirror image of
// rotate_left_then_right with y = x->right and z = y->left.
// Returns the new subtree root z.
static NodeBase* rotate_right_then_left(NodeBase* const x) {
NodeBase* y = x->right;
NodeBase* z = y->left;
x->right = z->left;
y->left = z->right;
if (z->left != nullptr) z->left->parent = x;
if (z->right != nullptr) z->right->parent = y;
z->parent = x->parent;
if (x->parent->size == 0U) x->parent->parent = z;
else if (x == x->parent->left) x->parent->left = z;
else x->parent->right = z;
z->left = x;
z->right = y;
x->parent = y->parent = z;
update_info(x), update_info(y), update_info(z);
return z;
}
将某个节点与其中序遍历的后继节点对调位置。只有删除节点时会用到它;之所以单独写成一个函数,是因为不拆出来的话,删除节点的方法会过长,可读性较差。
// Exchanges the tree positions of x and its in-order successor y by relinking
// (values are not copied). Only called from erase when x has both children,
// so y is the leftmost node of x's right subtree and has no left child.
// Afterwards x sits in y's old slot with no left child; it is about to be
// unlinked, so its stale height/size are left as they are.
static void swap_with_successor(NodeBase* x, NodeBase* y) {
y->height = x->height; // y inherits x's height and subtree size along with x's position
y->size = x->size;
y->left = x->left;
x->left->parent = y;
x->left = nullptr;
if (y == x->right) {
// y is x's direct right child: x simply becomes y's right child.
x->right = y->right;
if (y->right != nullptr) y->right->parent = x;
y->parent = x->parent;
y->right = x;
x->parent = y;
} else {
// y is deeper (a left child of its parent): swap parents and right links.
NodeBase* tmp = x->parent;
x->parent = y->parent;
y->parent = tmp;
tmp = x->right;
x->right = y->right;
y->right = tmp;
x->parent->left = x;
if (x->right != nullptr) x->right->parent = x;
y->right->parent = y;
}
// Point x's former parent (or the header's root pointer) at y. The parent's
// child pointer still refers to x at this point, which tells which side to fix.
if (y->parent->size == 0U) y->parent->parent = y;
else if (x == y->parent->left) y->parent->left = y;
else y->parent->right = y;
}
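从某个节点开始向上回溯,根据平衡因子采取调整策略(旋转),并更新所经过节点的高度和子树大小,直到某个节点(旋转后则为新的子树根)的高度与原来相同,或回溯到 header 为止。函数返回停止处节点的双亲(回溯到尽头时返回 header)。它不负责停止处上方节点的 size:insert_unique 和 erase 会从返回的节点一路向上把 size 加 1 或减 1,insert_equal 则在向下查找插入位置时已经预先加 1。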
static NodeBase* backtrack(NodeBase* x) {
  // Walk upward until the header (size 0) is reached. At each node, rebalance
  // with the appropriate rotation (factor +/-2) or just refresh its cached
  // height/size. Stop as soon as a (sub)tree root's height is unchanged:
  // nothing above can be out of balance then. Return that root's parent.
  while (x->size != 0U) {
    const HgtType height_before = x->height;
    const HgtType balance = factor(x);
    if (balance == 2) {
      x = (factor(x->left) < 0) ? rotate_left_then_right(x) : rotate_right(x);
    } else if (balance == -2) {
      x = (factor(x->right) > 0) ? rotate_right_then_left(x) : rotate_left(x);
    } else {
      update_info(x);
    }
    if (x->height == height_before) return x->parent;
    x = x->parent;
  }
  return x;  // reached the header
}
参考 STL 实现的迭代器
public:
  // Bidirectional iterator over the elements in key order. It wraps a
  // NodeBase*; the header node plays the role of end().
  class Iterator {
   public:
    Iterator() : baseptr_(nullptr) {}
    explicit Iterator(NodeBase* x) : baseptr_(x) {}

    ValueType& operator*() const {
      Node* const node = static_cast<Node*>(baseptr_);
      return node->value;
    }
    ValueType* operator->() const { return &static_cast<Node*>(baseptr_)->value; }

    Iterator& operator++() {
      baseptr_ = AVLTree::next_node(baseptr_);
      return *this;
    }
    Iterator& operator--() {
      baseptr_ = AVLTree::prev_node(baseptr_);
      return *this;
    }
    Iterator operator++(int) {
      Iterator before = *this;
      operator++();
      return before;
    }
    Iterator operator--(int) {
      Iterator before = *this;
      operator--();
      return before;
    }

    friend bool operator==(const Iterator& x, const Iterator& y) {
      return x.baseptr_ == y.baseptr_;
    }
    friend bool operator!=(const Iterator& x, const Iterator& y) {
      return !(x == y);
    }

   protected:
    friend AVLTree;
    NodeBase* baseptr_;
  };
成员函数
begin() -- 指向第一个元素的迭代器
last() -- 指向最后一个元素的迭代器
end() -- 指向最后一个元素的下一个位置(即 header)的迭代器
没有元素时,这三个方法返回的迭代器是相同的
// First element; equals end() when the tree is empty.
Iterator begin() const {
  return Iterator(header_->left);
}
// Last element itself (not one-past-the-end); equals end() when empty.
Iterator last() const {
  return Iterator(header_->right);
}
// One past the last element: the header sentinel.
Iterator end() const {
  return Iterator(header_);
}
判断是否为空树及树中元素个数
// The tree is empty exactly when the header has no root.
bool empty() const { return !header_->parent; }
// Element count, read from the root's cached subtree size.
SizeType size() const {
  const NodeBase* const root = header_->parent;
  return (root == nullptr) ? 0U : root->size;
}
插入元素,不允许该元素与树中已有元素重复
// Inserts val unless an equivalent element already exists. Returns
// {iterator to the new node, true} or {iterator to the existing equivalent
// element, false}.
std::pair<Iterator, bool> insert_unique(const ValueType& val) {
NodeBase* x = header_->parent;
NodeBase* y = header_;
// Starts true so that, for an empty tree, z below becomes prev_node(header_) == header_.
bool cmp = true;
while (x != nullptr) {
y = x;
cmp = value_compare_(val, static_cast<Node*>(x)->value);
x = cmp ? x->left : x->right; // go left if val < node, otherwise right; duplicates are checked afterwards
}
// The only element that can be equivalent to val is the in-order predecessor
// of the insertion slot: y itself if we went right, y's predecessor if we
// went left.
NodeBase* z = y;
if (cmp) z = prev_node(z);
// z is known to be <= val, so "not z < val" means equivalent: reject the insert.
if (z != header_ && !value_compare_(static_cast<Node*>(z)->value, val))
return std::pair<Iterator, bool>(Iterator(z), false);
x = new Node(1, 1U, y, nullptr, nullptr, val);
if (y == header_) {
// First element: root, minimum and maximum at once.
y->parent = x;
y->left = y->right = x;
} else if (cmp) {
y->left = x;
if (header_->left == y) header_->left = x;
} else {
y->right = x;
if (header_->right == y) header_->right = x;
}
// backtrack refreshes sizes of the nodes it visits; every ancestor above the
// point where it stopped just gains one element.
y = backtrack(y);
while (y != header_) {
++y->size;
y = y->parent;
}
return std::pair<Iterator, bool>(Iterator(x), true);
}
插入元素,允许该元素与树中已有元素重复
std::pair<Iterator, bool> insert_equal(const ValueType& val) {
NodeBase* x = header_->parent;
NodeBase* y = header_;
while (x != nullptr) {
y = x;
++x->size;
x = value_compare_(val, static_cast<Node*>(x)->value) ? x->left : x->right;
}
x = new Node(1, 1U, y, nullptr, nullptr, val);
if (y == header_) {
y->parent = x;
y->left = x;
y->right = x;
} else if (value_compare_(val, static_cast<Node*>(y)->value)) {
y->left = x;
if (header_->left == y) header_->left = x;
} else {
y->right = x;
if (header_->right == y) header_->right = x;
}
backtrack(y);
return std::pair<Iterator, bool>(Iterator(x), true);
}
查找某个元素,如果存在多个,则总是找到中序遍历过程中最先访问到的那个,当然也可以说是最先插入的那个,因为在 insert_equal 方法中,寻找插入位置过程中若遇到相等的元素一律往右走。
Iterator find(const Key& key) const {
  // Lower-bound descent: remember the last node whose value is not less than
  // key. Among equivalent elements this is the first one in order.
  NodeBase* candidate = header_;
  NodeBase* cur = header_->parent;
  while (cur != nullptr) {
    if (!value_compare_(static_cast<Node*>(cur)->value, key)) {
      candidate = cur;
      cur = cur->left;
    } else {
      cur = cur->right;
    }
  }
  // The candidate matches only if key is not less than it either.
  const bool missing =
      candidate == header_ ||
      value_compare_(key, static_cast<Node*>(candidate)->value);
  return missing ? Iterator(header_) : Iterator(candidate);
}
lower_bound 和 upper_bound,不解释,容易注意到 lower_bound 和 find 相比仅仅是 return 后面的部分有差别。
// First element not less than key, or end() if there is none.
Iterator lower_bound(const Key& key) const {
  NodeBase* best = header_;
  for (NodeBase* cur = header_->parent; cur != nullptr;) {
    if (value_compare_(static_cast<Node*>(cur)->value, key)) {
      cur = cur->right;
    } else {
      best = cur;
      cur = cur->left;
    }
  }
  return Iterator(best);
}
// First element greater than key, or end() if there is none.
Iterator upper_bound(const Key& key) const {
  NodeBase* best = header_;
  for (NodeBase* cur = header_->parent; cur != nullptr;) {
    if (!value_compare_(key, static_cast<Node*>(cur)->value)) {
      cur = cur->right;
    } else {
      best = cur;
      cur = cur->left;
    }
  }
  return Iterator(best);
}
根据排名查值,注意排名是从 0 开始的
// Element with 0-based rank `order`, or end() if order >= size().
Iterator find_by_order(SizeType order) const {
  for (NodeBase* cur = header_->parent; cur != nullptr;) {
    const SizeType left_count = (cur->left != nullptr) ? cur->left->size : 0U;
    if (order == left_count) return Iterator(cur);
    if (order < left_count) {
      cur = cur->left;
    } else {
      // Skip the left subtree and cur itself.
      order -= left_count + 1;
      cur = cur->right;
    }
  }
  return Iterator(header_);
}
根据值查排名,注意事项同上
// Number of elements strictly less than key (its 0-based rank).
SizeType order_of_key(const Key& key) const {
  SizeType below = 0U;
  NodeBase* cur = header_->parent;
  while (cur != nullptr) {
    if (!value_compare_(static_cast<Node*>(cur)->value, key)) {
      cur = cur->left;
      continue;
    }
    // cur and its entire left subtree precede key.
    below += 1U + ((cur->left != nullptr) ? cur->left->size : 0U);
    cur = cur->right;
  }
  return below;
}
根据迭代器删除单个节点
注意:在 AVL 树中,如果某个节点仅有一棵子树,那么这棵子树一定仅有一个节点,否则就会违反“每个结点的平衡因子的绝对值都不大于 1”这一性质。
// Erases the element at pos and returns an iterator to its in-order successor
// (end() if it was the last element). pos must not be end().
Iterator erase(Iterator pos) {
  assert(pos.baseptr_ != header_);  // the header can never be erased
  NodeBase* x = (pos++).baseptr_;   // pos now holds the successor (return value)
  if (x->left != nullptr && x->right != nullptr) {
    // Two children: relink x with its successor so that x ends up with at
    // most one (right) child, then remove it like any other node below.
    swap_with_successor(x, pos.baseptr_);
  } else if (x->parent == header_) {
    // Root with at most one child. By the AVL property a lone child is a
    // leaf, so it becomes the root, the minimum and the maximum at once.
    NodeBase* const child = (x->left != nullptr) ? x->left : x->right;
    header_->parent = child;
    if (child != nullptr) {
      child->parent = header_;  // must not keep pointing at the freed node
      header_->left = header_->right = child;
    } else {
      header_->left = header_->right = header_;  // tree is now empty
    }
    // NodeBase has no virtual destructor: delete as Node so the value's
    // destructor runs and the correct type is deallocated.
    delete static_cast<Node*>(x);
    return pos;
  }
  // x is not the root and has at most one child: splice it out.
  NodeBase* const parent = x->parent;
  NodeBase* const child = (x->left != nullptr) ? x->left : x->right;
  if (x == parent->left) parent->left = child;
  else parent->right = child;
  if (child != nullptr) child->parent = parent;
  // If x was the minimum or the maximum, its replacement is its lone child
  // (a leaf) or, failing that, its parent. Erasing any other node must leave
  // header_->left / header_->right untouched.
  NodeBase* const replacement = (child != nullptr) ? child : parent;
  if (header_->left == x) header_->left = replacement;
  else if (header_->right == x) header_->right = replacement;
  // Rebalance upward; above the point where backtrack stops only sizes change.
  NodeBase* y = backtrack(parent);
  while (y != header_) {
    --y->size;
    y = y->parent;
  }
  delete static_cast<Node*>(x);
  return pos;
}
删除所有等价于指定值的元素,返回值为删除的节点个数
// Erases every element equivalent to key; returns how many were removed.
// find() yields the first equivalent element and erase(Iterator) returns its
// successor, so the run of equivalents is consumed front to back.
SizeType erase(const Key& key) {
  SizeType removed = 0U;
  for (Iterator it = find(key);
       it.baseptr_ != header_ && !value_compare_(key, *it);
       ++removed) {
    it = erase(it);
  }
  return removed;
}
用 O(1) 的时间交换两棵模板参数完全相同的 AVL 树,也就是仅交换 header 指针,当然这个不需要解释
// O(1) exchange of contents with another tree of the same type: only the
// header pointers trade places.
void swap(AVLTree& tr) {
  std::swap(header_, tr.header_);
}
用非递归后序遍历的方式清除所有节点(当然不包括 header)
void clear() {
NodeBase* x = header_->left;
if (x == header_) return; // 如果没有这条语句, 若树中本来就没有节点, 则会陷入无限循环
while (x->left != nullptr || x->right != nullptr)
x = x->left != nullptr ? x->left : x->right;
NodeBase* y = x;
header_->parent = nullptr;
header_->left = header_->right = header_;
while (x != header_) {
x = x->parent;
if (y == x->left && x->right != nullptr) {
x = x->right;
while (x->left != nullptr || x->right != nullptr)
x = x->left != nullptr ? x->left : x->right;
}
delete y;
y = x;
}
}
构造函数及析构函数
AVLTree()
: header_(new NodeBase(0, 0U, nullptr, nullptr, nullptr))
, value_compare_(ValueCompare()) {
header_->left = header_; // 当树为空时, 先令 header 的左指针和右指针都指向自己
header_->right = header_; // 此时 begin() last() end() 方法返回的的迭代器是相同的
}
AVLTree(std::initializer_list<ValueType> lst)
: header_(new NodeBase(0, 0U, nullptr, nullptr, nullptr))
, value_compare_(ValueCompare()) {
header_->left = header_;
header_->right = header_;
for (const ValueType& val : lst) insert_equal(val);
}
~AVLTree() {
clear(); // 先销毁除根节点的双亲结点之外的所有结点
delete header_; // 最后清除根节点的双亲结点
}
完整代码
#ifndef AVL_TREE_HPP
#define AVL_TREE_HPP
#include <algorithm>
#include <cassert>
#include <functional>
#include <initializer_list>
#include <type_traits>
#include <utility>
#ifndef NULL_TYPE_IN_CONTAINER
#define NULL_TYPE_IN_CONTAINER
namespace agl {
// Tag type: passing it as AVLTree's `Mapped` argument makes the tree
// set-like (elements are `const Key`).
struct NullType {};
}  // namespace agl
#endif  // NULL_TYPE_IN_CONTAINER
namespace agl {
// Non-recursive order-statistic AVL tree.
//
// Template parameter names follow pb_ds. When Mapped is NullType the element
// type is `const Key`; otherwise it is `std::pair<const Key, Mapped>`, ordered
// by key only.
//
// A sentinel header node (the only node whose size is 0) anchors the tree, as
// in the STL red-black tree:
//   header_->parent == root (nullptr when empty), root->parent == header_,
//   header_->left == first element, header_->right == last element
//   (both == header_ when empty).
// Every node caches its height and subtree size; the sizes power
// find_by_order / order_of_key.
template<typename Key, typename Mapped, typename KeyCompare = std::less<Key>>
class AVLTree {
 public:
  // Orders map elements by key only. The mixed overloads let a bare Key be
  // compared against an element (used by find, lower_bound, ...).
  class PairCompare {
   public:
    PairCompare() : key_compare_(KeyCompare()) {}
    bool operator()(
        const std::pair<const Key, Mapped>& x,
        const std::pair<const Key, Mapped>& y) const {
      return key_compare_(x.first, y.first);
    }
    bool operator()(const Key& key, const std::pair<const Key, Mapped>& y) const {
      return key_compare_(key, y.first);
    }
    bool operator()(const std::pair<const Key, Mapped>& x, const Key& key) const {
      return key_compare_(x.first, key);
    }

   protected:
    KeyCompare key_compare_;
  };

 public:
  using HgtType = int;            // signed, so balance factors can be negative
  using SizeType = unsigned int;  // subtree node count
  // Element type: `const Key` for set-like trees, else a key/mapped pair.
  using ValueType = typename std::conditional<
      std::is_same<Mapped, NullType>::value,
      const Key, std::pair<const Key, Mapped>>::type;
  // Element comparator matching ValueType.
  using ValueCompare = typename std::conditional<
      std::is_same<Mapped, NullType>::value,
      KeyCompare, PairCompare>::type;

 protected:
  // Links and bookkeeping shared by the header and the element nodes.
  struct NodeBase {
    NodeBase(HgtType hgt, SizeType sz, NodeBase* p, NodeBase* l, NodeBase* r)
        : height(hgt), size(sz), parent(p), left(l), right(r) {}
    HgtType height;
    SizeType size;
    NodeBase* parent;  // lets every operation walk upward without recursion
    NodeBase* left;
    NodeBase* right;
  };
  // Element node. NodeBase has no virtual destructor, so a Node must always be
  // deleted as a Node (see destroy_node), never through a NodeBase*.
  struct Node : public NodeBase {
    Node(HgtType hgt, SizeType sz, NodeBase* p,
         NodeBase* l, NodeBase* r, const ValueType& val)
        : NodeBase(hgt, sz, p, l, r), value(val) {}
    ValueType value;
  };

 public:
  // Bidirectional iterator in key order; the header plays the role of end().
  class Iterator {
   public:
    Iterator() : baseptr_(nullptr) {}
    explicit Iterator(NodeBase* x) : baseptr_(x) {}
    ValueType& operator*() const { return static_cast<Node*>(baseptr_)->value; }
    ValueType* operator->() const { return &(operator*()); }
    Iterator& operator--() {
      baseptr_ = AVLTree::prev_node(baseptr_);
      return *this;
    }
    Iterator operator--(int) {
      Iterator tmp = *this;
      --*this;
      return tmp;
    }
    Iterator& operator++() {
      baseptr_ = AVLTree::next_node(baseptr_);
      return *this;
    }
    Iterator operator++(int) {
      Iterator tmp = *this;
      ++*this;
      return tmp;
    }
    friend bool operator==(const Iterator& x, const Iterator& y) {
      return x.baseptr_ == y.baseptr_;
    }
    friend bool operator!=(const Iterator& x, const Iterator& y) {
      return x.baseptr_ != y.baseptr_;
    }

   protected:
    friend AVLTree;
    NodeBase* baseptr_;
  };

 public:
  // Empty tree: the header's left/right point at itself, so begin(), last()
  // and end() compare equal.
  AVLTree()
      : header_(new NodeBase(0, 0U, nullptr, nullptr, nullptr)),
        value_compare_(ValueCompare()) {
    header_->left = header_;
    header_->right = header_;
  }
  // Builds a tree from a list, keeping duplicates. Delegating to AVLTree()
  // means the destructor cleans up if an insertion throws.
  AVLTree(std::initializer_list<ValueType> lst) : AVLTree() {
    for (const ValueType& val : lst) insert_equal(val);
  }
  // Deep copy (the implicit member-wise copy would double-free). Re-inserting
  // in order preserves the relative order of equivalent elements.
  AVLTree(const AVLTree& other) : AVLTree() {
    for (Iterator it = other.begin(); it != other.end(); ++it) insert_equal(*it);
  }
  // Copy assignment via copy-and-swap.
  AVLTree& operator=(const AVLTree& other) {
    if (this != &other) {
      AVLTree copy(other);
      swap(copy);
    }
    return *this;
  }
  ~AVLTree() {
    clear();
    delete header_;
  }

  // begin(): first element; last(): last element; end(): the header.
  Iterator begin() const { return Iterator(header_->left); }
  Iterator last() const { return Iterator(header_->right); }
  Iterator end() const { return Iterator(header_); }

  bool empty() const { return header_->parent == nullptr; }
  SizeType size() const { return empty() ? 0U : header_->parent->size; }

  // Inserts val unless an equivalent element exists. Returns {new, true} or
  // {existing equivalent element, false}.
  std::pair<Iterator, bool> insert_unique(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    bool cmp = true;  // stays true for an empty tree: prev_node(header_) == header_
    while (x != nullptr) {
      y = x;
      cmp = value_compare_(val, static_cast<Node*>(x)->value);
      x = cmp ? x->left : x->right;  // ties go right; duplicates checked below
    }
    // The only candidate for an equivalent element is the in-order
    // predecessor of the insertion slot (y, or y's predecessor if we went left).
    NodeBase* z = y;
    if (cmp) z = prev_node(z);
    if (z != header_ && !value_compare_(static_cast<Node*>(z)->value, val))
      return std::pair<Iterator, bool>(Iterator(z), false);
    x = new Node(1, 1U, y, nullptr, nullptr, val);
    if (y == header_) {
      y->parent = x;
      y->left = y->right = x;
    } else if (cmp) {
      y->left = x;
      if (header_->left == y) header_->left = x;
    } else {
      y->right = x;
      if (header_->right == y) header_->right = x;
    }
    // Ancestors above the point where backtrack stopped gain one element.
    y = backtrack(y);
    while (y != header_) {
      ++y->size;
      y = y->parent;
    }
    return std::pair<Iterator, bool>(Iterator(x), true);
  }

  // Inserts val even if equivalent elements exist; it goes after them.
  std::pair<Iterator, bool> insert_equal(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      y = x;
      ++x->size;  // the new node will be in this subtree
      x = value_compare_(val, static_cast<Node*>(x)->value) ? x->left : x->right;
    }
    x = new Node(1, 1U, y, nullptr, nullptr, val);
    if (y == header_) {
      y->parent = x;
      y->left = x;
      y->right = x;
    } else if (value_compare_(val, static_cast<Node*>(y)->value)) {
      y->left = x;
      if (header_->left == y) header_->left = x;
    } else {
      y->right = x;
      if (header_->right == y) header_->right = x;
    }
    backtrack(y);
    return std::pair<Iterator, bool>(Iterator(x), true);
  }

  // First element equivalent to key (earliest inserted among equals), or end().
  Iterator find(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key)) x = x->right;
      else y = x, x = x->left;
    }
    return (y == header_ || value_compare_(key, static_cast<Node*>(y)->value))
        ? Iterator(header_) : Iterator(y);
  }

  // First element not less than key, or end().
  Iterator lower_bound(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key)) x = x->right;
      else y = x, x = x->left;
    }
    return Iterator(y);
  }

  // First element greater than key, or end().
  Iterator upper_bound(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(key, static_cast<Node*>(x)->value)) y = x, x = x->left;
      else x = x->right;
    }
    return Iterator(y);
  }

  // Element with 0-based rank `order`, or end() if out of range.
  Iterator find_by_order(SizeType order) const {
    NodeBase* x = header_->parent;
    SizeType lsize{};
    while (x != nullptr) {
      lsize = (x->left == nullptr ? 0U : x->left->size);
      if (order < lsize) x = x->left;
      else if (order == lsize) return Iterator(x);
      else order -= lsize + 1, x = x->right;
    }
    return Iterator(header_);
  }

  // Number of elements strictly less than key.
  SizeType order_of_key(const Key& key) const {
    SizeType res = 0U;
    NodeBase* x = header_->parent;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key))
        res += (x->left == nullptr ? 1U : x->left->size + 1U), x = x->right;
      else x = x->left;
    }
    return res;
  }

  // Erases the element at pos (not end()); returns its in-order successor.
  Iterator erase(Iterator pos) {
    assert(pos.baseptr_ != header_);
    NodeBase* x = (pos++).baseptr_;  // pos now holds the successor
    if (x->left != nullptr && x->right != nullptr) {
      // Two children: move x to its successor's slot (at most a right child).
      swap_with_successor(x, pos.baseptr_);
    } else if (x->parent == header_) {
      // Root with at most one child; a lone child is a leaf (AVL property)
      // and becomes root, minimum and maximum at once.
      NodeBase* const child = (x->left != nullptr) ? x->left : x->right;
      header_->parent = child;
      if (child != nullptr) {
        child->parent = header_;  // must not point at the freed node
        header_->left = header_->right = child;
      } else {
        header_->left = header_->right = header_;
      }
      destroy_node(x);
      return pos;
    }
    // x is not the root and has at most one child: splice it out.
    NodeBase* const parent = x->parent;
    NodeBase* const child = (x->left != nullptr) ? x->left : x->right;
    if (x == parent->left) parent->left = child;
    else parent->right = child;
    if (child != nullptr) child->parent = parent;
    // Only touch header_ if x was the minimum or the maximum.
    NodeBase* const replacement = (child != nullptr) ? child : parent;
    if (header_->left == x) header_->left = replacement;
    else if (header_->right == x) header_->right = replacement;
    NodeBase* y = backtrack(parent);
    while (y != header_) {
      --y->size;
      y = y->parent;
    }
    destroy_node(x);
    return pos;
  }

  // Erases every element equivalent to key; returns the number removed.
  SizeType erase(const Key& key) {
    SizeType res = 0U;
    Iterator it = find(key);
    while (it.baseptr_ != header_ && !value_compare_(key, *it)) {
      it = erase(it);
      ++res;
    }
    return res;
  }

  // Frees all element nodes (iterative post-order walk); keeps the header.
  void clear() {
    NodeBase* x = header_->left;
    if (x == header_) return;  // empty: the descent below would never end
    while (x->left != nullptr || x->right != nullptr)
      x = x->left != nullptr ? x->left : x->right;
    NodeBase* y = x;
    header_->parent = nullptr;
    header_->left = header_->right = header_;
    while (x != header_) {
      x = x->parent;
      if (y == x->left && x->right != nullptr) {
        x = x->right;
        while (x->left != nullptr || x->right != nullptr)
          x = x->left != nullptr ? x->left : x->right;
      }
      destroy_node(y);
      y = x;
    }
  }

  // O(1) swap with another tree of the same type. Comparators are always
  // default-constructed and never modified, so only headers are exchanged.
  void swap(AVLTree& tr) { std::swap(header_, tr.header_); }

 protected:
  // Deletes an element node as a Node so that the stored value is destroyed
  // (NodeBase has no virtual destructor).
  static void destroy_node(NodeBase* x) { delete static_cast<Node*>(x); }

  // In-order predecessor; the header (size 0) acts as one-past-the-last.
  static NodeBase* prev_node(NodeBase* x) {
    if (x->size == 0U) {
      x = x->right;
    } else if (x->left != nullptr) {
      x = x->left;
      while (x->right != nullptr) x = x->right;
    } else {
      while (x == x->parent->left) x = x->parent;
      if (x->left != x->parent) x = x->parent;
    }
    return x;
  }
  // In-order successor; mirror image of prev_node.
  static NodeBase* next_node(NodeBase* x) {
    if (x->size == 0U) {
      x = x->left;
    } else if (x->right != nullptr) {
      x = x->right;
      while (x->left != nullptr) x = x->left;
    } else {
      while (x == x->parent->right) x = x->parent;
      if (x->right != x->parent) x = x->parent;
    }
    return x;
  }
  // height(left) - height(right), missing children counting as 0.
  static HgtType factor(const NodeBase* const x) {
    return (x->left == nullptr ? 0 : x->left->height)
        - (x->right == nullptr ? 0 : x->right->height);
  }
  // Recomputes x's height and subtree size from its children.
  static void update_info(NodeBase* const x) {
    x->height = 1;
    x->size = 1U;
    if (x->left != nullptr) {
      x->height += x->left->height;
      x->size += x->left->size;
    }
    if (x->right != nullptr) {
      x->height = std::max(x->height, x->right->height + 1);
      x->size += x->right->size;
    }
  }
  // Right rotation (LL case); returns the new subtree root.
  static NodeBase* rotate_right(NodeBase* const x) {
    NodeBase* const y = x->left;
    x->left = y->right;
    if (y->right != nullptr) y->right->parent = x;
    y->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->right) x->parent->right = y;
    else x->parent->left = y;
    y->right = x;
    x->parent = y;
    update_info(x), update_info(y);
    return y;
  }
  // Left rotation (RR case); returns the new subtree root.
  static NodeBase* rotate_left(NodeBase* const x) {
    NodeBase* const y = x->right;
    x->right = y->left;
    if (y->left != nullptr) y->left->parent = x;
    y->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->left) x->parent->left = y;
    else x->parent->right = y;
    y->left = x;
    x->parent = y;
    update_info(x), update_info(y);
    return y;
  }
  // Left-right double rotation (LR case) in one pass; returns the new root.
  static NodeBase* rotate_left_then_right(NodeBase* const x) {
    NodeBase* y = x->left;
    NodeBase* z = y->right;
    x->left = z->right;
    y->right = z->left;
    if (z->right != nullptr) z->right->parent = x;
    if (z->left != nullptr) z->left->parent = y;
    z->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->right) x->parent->right = z;
    else x->parent->left = z;
    z->right = x;
    z->left = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);
    return z;
  }
  // Right-left double rotation (RL case) in one pass; returns the new root.
  static NodeBase* rotate_right_then_left(NodeBase* const x) {
    NodeBase* y = x->right;
    NodeBase* z = y->left;
    x->right = z->left;
    y->left = z->right;
    if (z->left != nullptr) z->left->parent = x;
    if (z->right != nullptr) z->right->parent = y;
    z->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->left) x->parent->left = z;
    else x->parent->right = z;
    z->left = x;
    z->right = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);
    return z;
  }
  // Relinks x (two children) with its in-order successor y (no left child).
  // Afterwards x sits in y's old slot; y takes x's height and size.
  static void swap_with_successor(NodeBase* x, NodeBase* y) {
    y->height = x->height;
    y->size = x->size;
    y->left = x->left;
    x->left->parent = y;
    x->left = nullptr;
    if (y == x->right) {
      x->right = y->right;
      if (y->right != nullptr) y->right->parent = x;
      y->parent = x->parent;
      y->right = x;
      x->parent = y;
    } else {
      NodeBase* tmp = x->parent;
      x->parent = y->parent;
      y->parent = tmp;
      tmp = x->right;
      x->right = y->right;
      y->right = tmp;
      x->parent->left = x;
      if (x->right != nullptr) x->right->parent = x;
      y->right->parent = y;
    }
    if (y->parent->size == 0U) y->parent->parent = y;
    else if (x == y->parent->left) y->parent->left = y;
    else y->parent->right = y;
  }
  // Rebalances and refreshes nodes upward from x until a subtree root's height
  // is unchanged or the header is reached; returns where it stopped (the
  // parent of that subtree root, or the header).
  static NodeBase* backtrack(NodeBase* x) {
    for (HgtType orig_hgt{}, bal_factor{}; x->size != 0U; x = x->parent) {
      orig_hgt = x->height;
      bal_factor = factor(x);
      switch (bal_factor) {
        case 2: {
          x = factor(x->left) < 0
              ? rotate_left_then_right(x)
              : rotate_right(x);
        } break;
        case -2: {
          x = factor(x->right) > 0
              ? rotate_right_then_left(x)
              : rotate_left(x);
        } break;
        default: update_info(x);
      }
      if (x->height == orig_hgt) return x->parent;
    }
    return x;
  }

 protected:
  NodeBase* header_;
  ValueCompare value_compare_;
};
}  // namespace agl
#endif // AVL_TREE_HPP