【学习笔记】AVL树类模板(参考STL红黑树的实现及 pb_ds 库的模板参数及函数命名)

发布时间 2023-07-01 00:05:47作者: xb2

嵌套类 PairCompare 以及一些 typedef

模板参数命名参考 pb_ds 库,当第二个模板参数不为 NullType 时,结点值域类型为 std::pair<const Key, Mapped>,否则为 const Key。

public:
  class PairCompare {
  public:
    PairCompare() : key_compare_(KeyCompare()) {}

    bool operator()(
      const std::pair<const Key, Mapped>& x,
      const std::pair<const Key, Mapped>& y) const {
      return key_compare_(x.first, y.first);
    }

    bool operator()(const Key& key, const std::pair<const Key, Mapped>& y) const {
      return key_compare_(key, y.first);
    }

    bool operator()(const std::pair<const Key, Mapped>& x, const Key& key) const {
      return key_compare_(x.first, key);
    }

  protected:
    KeyCompare key_compare_;
  };

public:
  using HgtType = int;            // height is stored as int, which keeps balance-factor arithmetic simple
  using SizeType = unsigned int;  // subtree sizes are stored as unsigned int

  using ValueType = typename std::conditional<
    std::is_same<Mapped, NullType>::value,
    const Key, std::pair<const Key, Mapped>>::type;  // type of the value stored in a node

  using ValueCompare = typename std::conditional<
    std::is_same<Mapped, NullType>::value,
    KeyCompare, PairCompare>::type;                  // comparator type matching ValueType

 

嵌套类 NodeBase 及其派生类 Node

protected:
  struct NodeBase {  // base class of the node types; it has no value field, which lives in the derived Node
    NodeBase(HgtType hgt, SizeType sz, NodeBase* p, NodeBase* l, NodeBase* r)
      : height(hgt), size(sz), parent(p), left(l), right(r) {}
    HgtType height;
    SizeType size;
    NodeBase* parent;  // following the STL red-black tree source, parent links let every operation run without recursion
    NodeBase* left;
    NodeBase* right;
  };

  // A real tree node: the NodeBase bookkeeping fields plus the stored value,
  // which is copy-constructed from `val`.
  struct Node : public NodeBase {
    Node(HgtType hgt, SizeType sz, NodeBase* p,
         NodeBase* l, NodeBase* r, const ValueType& val)
      : NodeBase(hgt, sz, p, l, r), value(val) {}
    ValueType value;
  };

 

成员变量

protected:
  NodeBase* header_;  // STL-style implementation trick: while a root exists, header and root are each other's parent
  ValueCompare value_compare_;

 

静态函数,由其他函数及迭代器的成员函数调用

在静态函数中判断指针 x 指向的是 header 的条件为 x->size == 0U

寻找中序遍历顺序中的上一个节点或下一个节点

protected:
  // Returns the in-order predecessor of x.
  // A node with size == 0 is the header; its predecessor is whatever its
  // `right` pointer refers to.
  static NodeBase* prev_node(NodeBase* x) {
    if (x->size == 0U) return x->right;

    if (x->left != nullptr) {
      // Predecessor is the rightmost node of the left subtree.
      NodeBase* cur = x->left;
      while (cur->right != nullptr) cur = cur->right;
      return cur;
    }

    // No left subtree: climb while we are a left child.
    while (x == x->parent->left) x = x->parent;
    // If the climb stopped on a node whose left pointer equals its parent,
    // stay there; otherwise the predecessor is the parent.
    return x->left != x->parent ? x->parent : x;
  }

  // Returns the in-order successor of x.
  // A node with size == 0 is the header; its successor is whatever its
  // `left` pointer refers to.
  static NodeBase* next_node(NodeBase* x) {
    if (x->size == 0U) return x->left;

    if (x->right != nullptr) {
      // Successor is the leftmost node of the right subtree.
      NodeBase* cur = x->right;
      while (cur->left != nullptr) cur = cur->left;
      return cur;
    }

    // No right subtree: climb while we are a right child.
    while (x == x->parent->right) x = x->parent;
    // If the climb stopped on a node whose right pointer equals its parent,
    // stay there; otherwise the successor is the parent.
    return x->right != x->parent ? x->parent : x;
  }

 

平衡因子

  // Balance factor of x: height of the left subtree minus height of the right
  // subtree, where a missing child counts as height 0.
  static HgtType factor(const NodeBase* const x) {
    HgtType left_hgt = 0;
    HgtType right_hgt = 0;
    if (x->left != nullptr) left_hgt = x->left->height;
    if (x->right != nullptr) right_hgt = x->right->height;
    return left_hgt - right_hgt;
  }

 

更新节点的高度及子树大小

  // Recomputes x's height and subtree size from its children.
  // A missing child contributes height 0 and size 0.
  static void update_info(NodeBase* const x) {
    const NodeBase* const lc = x->left;
    const NodeBase* const rc = x->right;

    const HgtType lh = (lc != nullptr) ? lc->height : 0;
    const HgtType rh = (rc != nullptr) ? rc->height : 0;
    const SizeType ls = (lc != nullptr) ? lc->size : 0U;
    const SizeType rs = (rc != nullptr) ? rc->size : 0U;

    x->height = std::max(lh, rh) + 1;
    x->size = 1U + ls + rs;
  }

 

用于处理 LL 型失衡的右旋以及用于处理 RR 型失衡的左旋

  // Right rotation (used for the LL imbalance): x's left child y becomes the
  // root of this subtree and x becomes y's right child. Returns y.
  static NodeBase* rotate_right(NodeBase* const x) {
    NodeBase* const y = x->left;
    x->left = y->right;  // y's former right subtree becomes x's left subtree
    if (y->right != nullptr) y->right->parent = x;
    y->parent = x->parent;
    // Re-link from above. size == 0 marks the header, whose `parent` field holds the root.
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->right) x->parent->right = y;
    else x->parent->left = y;
    y->right = x;
    x->parent = y;
    update_info(x), update_info(y);  // x is now below y, so x is refreshed first
    return y;
  }

  // Left rotation (used for the RR imbalance): x's right child y becomes the
  // root of this subtree and x becomes y's left child. Returns y.
  static NodeBase* rotate_left(NodeBase* const x) {
    NodeBase* const y = x->right;
    x->right = y->left;  // y's former left subtree becomes x's right subtree
    if (y->left != nullptr) y->left->parent = x;
    y->parent = x->parent;
    // Re-link from above. size == 0 marks the header, whose `parent` field holds the root.
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->left) x->parent->left = y;
    else x->parent->right = y;
    y->left = x;
    x->parent = y;
    update_info(x), update_info(y);  // x is now below y, so x is refreshed first
    return y;
  }

 

双旋

先将左子节点左旋,再将该节点右旋,处理 LR 型失衡

先将右子节点右旋,再将该节点左旋,处理 RL 型失衡

这里我特意将双旋操作都整理在一个函数里

  // Double rotation for the LR imbalance, done in one step.
  // With y = x->left and z = y->right, z becomes the root of this subtree,
  // y its left child and x its right child. Returns z.
  static NodeBase* rotate_left_then_right(NodeBase* const x) {
    NodeBase* y = x->left;
    NodeBase* z = y->right;
    x->left = z->right;   // z's right subtree goes under x
    y->right = z->left;   // z's left subtree goes under y
    if (z->right != nullptr) z->right->parent = x;
    if (z->left != nullptr) z->left->parent = y;
    z->parent = x->parent;
    // Re-link from above. size == 0 marks the header, whose `parent` field holds the root.
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->right) x->parent->right = z;
    else x->parent->left = z;
    z->right = x;
    z->left = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);  // children first, then the new subtree root
    return z;
  }

  // Double rotation for the RL imbalance, done in one step.
  // With y = x->right and z = y->left, z becomes the root of this subtree,
  // x its left child and y its right child. Returns z.
  static NodeBase* rotate_right_then_left(NodeBase* const x) {
    NodeBase* y = x->right;
    NodeBase* z = y->left;
    x->right = z->left;   // z's left subtree goes under x
    y->left = z->right;   // z's right subtree goes under y
    if (z->left != nullptr) z->left->parent = x;
    if (z->right != nullptr) z->right->parent = y;
    z->parent = x->parent;
    // Re-link from above. size == 0 marks the header, whose `parent` field holds the root.
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->left) x->parent->left = z;
    else x->parent->right = z;
    z->left = x;
    z->right = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);  // children first, then the new subtree root
    return z;
  }

 

将某个节点和其中序遍历的后继节点的位置进行对调,只有删除节点时会用到它,把它单独写是因为如果不单独写的话,删除节点的方法太长了,感觉不太好看。。。

// Exchanges the tree positions of x and y by relinking pointers (the stored
// values are not touched). The code dereferences x->left unconditionally and
// assumes y has no left child, so the caller must pass a node x with a left
// child and y as x's in-order successor. x's own height/size fields are left
// stale; only y receives x's former values.
static void swap_with_successor(NodeBase* x, NodeBase* y) {
    y->height = x->height;  // y takes over x's height and subtree size
    y->size = x->size;
    y->left = x->left;
    x->left->parent = y;
    x->left = nullptr;
    if (y == x->right) {
      // y is x's direct right child: x simply drops below y.
      x->right = y->right;
      if (y->right != nullptr) y->right->parent = x;
      y->parent = x->parent;
      y->right = x;
      x->parent = y;
    } else {
      // y sits deeper in x's right subtree: swap parents and right children.
      NodeBase* tmp = x->parent;
      x->parent = y->parent;
      y->parent = tmp;
      tmp = x->right;
      x->right = y->right;
      y->right = tmp;
      x->parent->left = x;  // x is linked in as the left child of y's former parent
      if (x->right != nullptr) x->right->parent = x;
      y->right->parent = y;
    }
    // Hook y into x's former parent. size == 0 marks the header, whose `parent` field holds the root.
    if (y->parent->size == 0U) y->parent->parent = y;
    else if (x == y->parent->left) y->parent->left = y;
    else y->parent->right = y;
  }

 

从某个节点开始向上回溯,根据平衡因子采取调整策略,并更新所经过的节点的信息,直到某个节点的高度没有变或回溯到尽头为止,且不去管停止回溯时上面还有多少节点的信息没更新。

  // Walks upward from x, rebalancing and refreshing height/size on each node.
  // Stops as soon as a node's height is unchanged after processing and returns
  // that node's parent, i.e. the first ancestor whose info was NOT refreshed;
  // the caller is responsible for fixing the sizes from there up. If the walk
  // reaches the header (size == 0), the header is returned.
  static NodeBase* backtrack(NodeBase* x) {
    for (HgtType orig_hgt{}, bal_factor{}; x->size != 0U; x = x->parent) {
      orig_hgt = x->height;  // height recorded before this node is processed
      bal_factor = factor(x);
      switch (bal_factor) {
        case 2: {  // left-heavy: LR needs the double rotation, otherwise a single right rotation
          x = factor(x->left) < 0
            ? rotate_left_then_right(x)
            : rotate_right(x);
        } break;
        case -2: {  // right-heavy: RL needs the double rotation, otherwise a single left rotation
          x = factor(x->right) > 0
            ? rotate_right_then_left(x)
            : rotate_left(x);
        } break;
        default: update_info(x);  // still balanced: just refresh height and size
      }
      if (x->height == orig_hgt) return x->parent;
    }
    return x;
  }

 

参考 STL 实现的迭代器

public:
  /// Bidirectional iterator over the tree in in-order sequence, modelled on
  /// the STL tree iterators. It wraps a NodeBase pointer; dereferencing treats
  /// that pointer as a Node, so the header position must not be dereferenced.
  class Iterator {
  public:
    Iterator() : baseptr_(nullptr) {}
    explicit Iterator(NodeBase* x) : baseptr_(x) {}

    ValueType& operator*() const {
      Node* const node = static_cast<Node*>(baseptr_);
      return node->value;
    }

    ValueType* operator->() const { return &(**this); }

    /// Pre-decrement: step to the in-order predecessor.
    Iterator& operator--() {
      baseptr_ = AVLTree::prev_node(baseptr_);
      return *this;
    }

    /// Post-decrement: step back, return the previous position.
    Iterator operator--(int) {
      const Iterator before(baseptr_);
      baseptr_ = AVLTree::prev_node(baseptr_);
      return before;
    }

    /// Pre-increment: step to the in-order successor.
    Iterator& operator++() {
      baseptr_ = AVLTree::next_node(baseptr_);
      return *this;
    }

    /// Post-increment: step forward, return the previous position.
    Iterator operator++(int) {
      const Iterator before(baseptr_);
      baseptr_ = AVLTree::next_node(baseptr_);
      return before;
    }

    friend bool operator==(const Iterator& x, const Iterator& y) {
      return x.baseptr_ == y.baseptr_;
    }

    friend bool operator!=(const Iterator& x, const Iterator& y) {
      return !(x.baseptr_ == y.baseptr_);
    }

  protected:
    friend AVLTree;  // the tree reads baseptr_ directly
    NodeBase* baseptr_;
  };

 

成员函数

begin() -- 指向第一个元素的迭代器

last() -- 指向最后一个元素的迭代器

end() -- 指向最后一个元素的下一个位置(即 header)的迭代器

没有元素时,这三个方法返回的迭代器是相同的

  // header_->left tracks the first (leftmost) element, header_->right the last
  // (rightmost) one, and the header itself is the past-the-end position.
  // In an empty tree both pointers refer to the header, so all three coincide.
  Iterator begin() const { return Iterator(header_->left); }
  Iterator last() const { return Iterator(header_->right); }
  Iterator end() const { return Iterator(header_); }

 

判断是否为空树及树中元素个数

  // The tree is empty exactly when the header has no root attached.
  bool empty() const { return nullptr == header_->parent; }

  // Number of stored elements: the root's subtree size, or 0 without a root.
  SizeType size() const {
    const NodeBase* const root = header_->parent;
    return root == nullptr ? 0U : root->size;
  }

 

插入元素,不允许该元素与树中已有元素重复

// Inserts val unless an equivalent element is already present.
// Returns {iterator to the new element, true} on insertion, or
// {iterator to the existing equivalent element, false} otherwise.
std::pair<Iterator, bool> insert_unique(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    bool cmp = true;
    while (x != nullptr) {
      y = x;
      cmp = value_compare_(val, static_cast<Node*>(x)->value);
      x = cmp ? x->left : x->right;  // go left when val is less than the current node, otherwise always go right; the duplicate check happens afterwards
    }
    NodeBase* z = y;
    if (cmp) z = prev_node(z);                                               // if the insertion point is y's left child, move z to y's predecessor
    if (z != header_ && !value_compare_(static_cast<Node*>(z)->value, val))  // duplicate check
      return std::pair<Iterator, bool>(Iterator(z), false);
    x = new Node(1, 1U, y, nullptr, nullptr, val);
    if (y == header_) {
      // First element: it becomes root, leftmost and rightmost at once.
      y->parent = x;
      y->left = y->right = x;
    } else if (cmp) {
      y->left = x;
      if (header_->left == y) header_->left = x;  // new leftmost element
    } else {
      y->right = x;
      if (header_->right == y) header_->right = x;  // new rightmost element
    }
    // backtrack() refreshes nodes until a height stops changing; the remaining
    // ancestors only need their subtree size bumped by one.
    y = backtrack(y);
    while (y != header_) {
      ++y->size;
      y = y->parent;
    }
    return std::pair<Iterator, bool>(Iterator(x), true);
  }

 

插入元素,允许该元素与树中已有元素重复

  std::pair<Iterator, bool> insert_equal(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      y = x;
      ++x->size;
      x = value_compare_(val, static_cast<Node*>(x)->value) ? x->left : x->right;
    }
    x = new Node(1, 1U, y, nullptr, nullptr, val);
    if (y == header_) {
      y->parent = x;
      y->left = x;
      y->right = x;
    } else if (value_compare_(val, static_cast<Node*>(y)->value)) {
      y->left = x;
      if (header_->left == y) header_->left = x;
    } else {
      y->right = x;
      if (header_->right == y) header_->right = x;
    }
    backtrack(y);
    return std::pair<Iterator, bool>(Iterator(x), true);
  }

 

查找某个元素,如果存在多个,则总是找到中序遍历过程中最先访问到的那个,当然也可以说是最先插入的那个,因为在 insert_equal 方法中,寻找插入位置过程中若遇到相等的元素一律往右走。

  // Returns an iterator to the first element (in in-order sequence) that is
  // equivalent to key, or end() when there is none.
  Iterator find(const Key& key) const {
    NodeBase* candidate = header_;
    for (NodeBase* cur = header_->parent; cur != nullptr; ) {
      if (value_compare_(static_cast<Node*>(cur)->value, key)) {
        cur = cur->right;  // current element is less than key
      } else {
        candidate = cur;   // not less than key: remember it and look further left
        cur = cur->left;
      }
    }
    // candidate is the first element not less than key; it matches only if
    // key is not less than it either.
    if (candidate != header_ &&
        !value_compare_(key, static_cast<Node*>(candidate)->value)) {
      return Iterator(candidate);
    }
    return Iterator(header_);
  }

 

lower_bound 和 upper_bound,不解释,容易注意到 lower_bound 和 find 相比仅仅是 return 后面的部分有差别。

  // Returns an iterator to the first element that is not less than key,
  // or end() when every element is less than key.
  Iterator lower_bound(const Key& key) const {
    NodeBase* candidate = header_;
    for (NodeBase* cur = header_->parent; cur != nullptr; ) {
      if (value_compare_(static_cast<Node*>(cur)->value, key)) {
        cur = cur->right;
      } else {
        candidate = cur;
        cur = cur->left;
      }
    }
    return Iterator(candidate);
  }

  // Returns an iterator to the first element that key is less than,
  // or end() when there is no such element.
  Iterator upper_bound(const Key& key) const {
    NodeBase* candidate = header_;
    for (NodeBase* cur = header_->parent; cur != nullptr; ) {
      if (value_compare_(key, static_cast<Node*>(cur)->value)) {
        candidate = cur;
        cur = cur->left;
      } else {
        cur = cur->right;
      }
    }
    return Iterator(candidate);
  }

 

根据排名查值,注意排名是从 0 开始的

  // Returns an iterator to the element with the given 0-based rank in
  // in-order sequence, or end() when order is out of range.
  Iterator find_by_order(SizeType order) const {
    for (NodeBase* cur = header_->parent; cur != nullptr; ) {
      const SizeType left_count = (cur->left != nullptr) ? cur->left->size : 0U;
      if (order == left_count) return Iterator(cur);
      if (order < left_count) {
        cur = cur->left;
      } else {
        // Skip the left subtree and the current node.
        order -= left_count + 1;
        cur = cur->right;
      }
    }
    return Iterator(header_);
  }

 

根据值查排名,注意事项同上

  // Returns the number of elements that compare less than key, i.e. the
  // 0-based rank key would have.
  SizeType order_of_key(const Key& key) const {
    SizeType rank = 0U;
    NodeBase* cur = header_->parent;
    while (cur != nullptr) {
      if (!value_compare_(static_cast<Node*>(cur)->value, key)) {
        cur = cur->left;
        continue;
      }
      // The current element and its whole left subtree are less than key.
      rank += (cur->left != nullptr) ? cur->left->size + 1U : 1U;
      cur = cur->right;
    }
    return rank;
  }

 

根据迭代器删除单个节点

注意:在 AVL 树中,如果某个节点仅有一棵子树,那么这棵子树一定仅有一个节点,否则就会违反“每个结点的平衡因子的绝对值不大于 1”这一性质。

// Erases the element at pos and returns an iterator to its in-order successor.
// pos must refer to an element, not to end().
//
// Fixes relative to the earlier version:
//  * header_->right is only updated when the erased node really was the
//    rightmost one (it used to be overwritten on every non-leftmost erase,
//    which broke last() and --end());
//  * when the root has a single child, that child's parent pointer is now
//    redirected to the header instead of dangling at the deleted root.
Iterator erase(Iterator pos) {
    assert(pos.baseptr_ != header_);                  // the header itself can never be erased
    NodeBase* x = (pos++).baseptr_;                   // AVLTree is a friend of Iterator
    if (x->left != nullptr && x->right != nullptr) {  // x has both subtrees:
      swap_with_successor(x, pos.baseptr_);           // swap it with its successor, then fall through
    } else if (x->parent == header_) {                // x is the root with at most one child
      // In an AVL tree a lone child of the root is necessarily a leaf.
      NodeBase* child = (x->left != nullptr ? x->left : x->right);
      if (child != nullptr) {
        child->parent = header_;
        header_->parent = child;
        header_->left = child;
        header_->right = child;
      } else {                                        // x was the only element
        header_->parent = nullptr;
        header_->left = header_;
        header_->right = header_;
      }
      delete x;
      return pos;
    }
    // From here on x is not the root and has at most one child.
    NodeBase* y = (x->left != nullptr ? x->left : x->right);
    if (x == x->parent->left) x->parent->left = y;
    else x->parent->right = y;
    if (y != nullptr) y->parent = x->parent;
    else y = x->parent;
    // y is now the node that replaces x as an extreme element, if x was one.
    if (header_->left == x) header_->left = y;
    else if (header_->right == x) header_->right = y;
    // backtrack() refreshes nodes until a height stops changing; the remaining
    // ancestors only need their subtree size reduced by one.
    y = backtrack(x->parent);
    while (y != header_) {
      --y->size;
      y = y->parent;
    }
    delete x;
    return pos;
  }

 

删除所有等价于指定值的元素,返回值为删除的节点个数

  // Erases every element equivalent to key and returns how many were removed.
  SizeType erase(const Key& key) {
    SizeType removed = 0U;
    // find() yields the first equivalent element; erase(Iterator) hands back
    // the successor, so keep going while the successor is still equivalent.
    for (Iterator it = find(key);
         it.baseptr_ != header_ && !value_compare_(key, *it); ) {
      it = erase(it);
      ++removed;
    }
    return removed;
  }

 

用 O(1) 的时间交换两棵模板参数完全相同的 AVL 树,也就是仅交换 header 指针,当然这个不需要解释

  // Exchanges the contents of two trees in O(1) by swapping their header
  // pointers; no element is copied or moved.
  void swap(AVLTree& tr) {
    std::swap(header_, tr.header_);
  }

 

用非递归后序遍历的方式清除所有节点(当然不包括 header)

// Deletes every element with an iterative post-order walk (children before
// their parent). The header is kept and reset to the empty-tree state.
void clear() {
    NodeBase* x = header_->left;
    if (x == header_) return;  // empty tree: without this early return the loop below would never terminate
    // Descend to a leaf, which is the first node in post-order.
    while (x->left != nullptr || x->right != nullptr)
      x = x->left != nullptr ? x->left : x->right;
    NodeBase* y = x;  // y is the node to delete next
    // Reset the header first; node-to-node links are still intact for the walk.
    header_->parent = nullptr;
    header_->left = header_->right = header_;
    while (x != header_) {
      x = x->parent;
      if (y == x->left && x->right != nullptr) {
        // Finished the left subtree: the deepest leaf of the right subtree comes next.
        x = x->right;
        while (x->left != nullptr || x->right != nullptr)
          x = x->left != nullptr ? x->left : x->right;
      }
      delete y;
      y = x;
    }
  }

 

构造函数及析构函数

  // Creates an empty tree: allocates the header (height 0, size 0, no root)
  // and default-constructs the comparator.
  AVLTree()
    : header_(new NodeBase(0, 0U, nullptr, nullptr, nullptr))
    , value_compare_(ValueCompare()) {
    header_->left = header_;   // in an empty tree the header's left and right pointers refer to itself,
    header_->right = header_;  // so begin(), last() and end() all return the same iterator
  }

  // Builds a tree from an initializer list; duplicates are kept (insert_equal).
  // Delegating to the default constructor means the object counts as fully
  // constructed before the first insertion, so if an insertion throws the
  // destructor still runs and releases the header and the nodes added so far.
  AVLTree(std::initializer_list<ValueType> lst) : AVLTree() {
    for (const ValueType& val : lst) insert_equal(val);
  }

  ~AVLTree() {
    clear();         // first destroy every node except the header (the root's parent)
    delete header_;  // then release the header itself
  }

 

完整代码

#ifndef AVL_TREE_HPP
#define AVL_TREE_HPP

#include <algorithm>
#include <cassert>
#include <functional>
#include <initializer_list>
#include <type_traits>
#include <utility>

#ifndef NULL_TYPE_IN_CONTAINER
#define NULL_TYPE_IN_CONTAINER

namespace agl {

/// Tag type: pass it as the Mapped parameter to get a set-like container
/// whose nodes store only a key.
struct NullType {};

}  // namespace agl

#endif  // NULL_TYPE_IN_CONTAINER

namespace agl {

/// AVL tree with parent links and per-node subtree sizes (order statistics).
///
/// @tparam Key        key type.
/// @tparam Mapped     mapped type; with NullType the nodes store `const Key`,
///                    otherwise `std::pair<const Key, Mapped>`.
/// @tparam KeyCompare strict weak ordering on keys; it is always
///                    default-constructed.
///
/// Layout: a value-less header node acts as end(). header_->parent is the
/// root (nullptr when empty), header_->left the leftmost element and
/// header_->right the rightmost element; in an empty tree both refer to the
/// header itself. The header is the only node with size == 0, which is how
/// the static helpers recognise it. No operation uses recursion.
template<typename Key, typename Mapped, typename KeyCompare = std::less<Key>>
class AVLTree {
public:
  /// Comparator for pair-valued nodes: compares keys only and also accepts a
  /// bare key on either side.
  class PairCompare {
  public:
    PairCompare() : key_compare_(KeyCompare()) {}

    bool operator()(
      const std::pair<const Key, Mapped>& x,
      const std::pair<const Key, Mapped>& y) const {
      return key_compare_(x.first, y.first);
    }

    bool operator()(const Key& key, const std::pair<const Key, Mapped>& y) const {
      return key_compare_(key, y.first);
    }

    bool operator()(const std::pair<const Key, Mapped>& x, const Key& key) const {
      return key_compare_(x.first, key);
    }

  protected:
    KeyCompare key_compare_;
  };

public:
  using HgtType = int;            // signed, so balance factors can be computed directly
  using SizeType = unsigned int;  // subtree sizes

  /// Type of the value stored in each node.
  using ValueType = typename std::conditional<
    std::is_same<Mapped, NullType>::value,
    const Key, std::pair<const Key, Mapped>>::type;

  /// Comparator matching ValueType.
  using ValueCompare = typename std::conditional<
    std::is_same<Mapped, NullType>::value,
    KeyCompare, PairCompare>::type;

protected:
  /// Node bookkeeping without a value; the header is a plain NodeBase.
  struct NodeBase {
    NodeBase(HgtType hgt, SizeType sz, NodeBase* p, NodeBase* l, NodeBase* r)
      : height(hgt), size(sz), parent(p), left(l), right(r) {}
    HgtType height;
    SizeType size;
    NodeBase* parent;
    NodeBase* left;
    NodeBase* right;
  };

  /// A real element: bookkeeping plus the stored value.
  struct Node : public NodeBase {
    Node(HgtType hgt, SizeType sz, NodeBase* p,
         NodeBase* l, NodeBase* r, const ValueType& val)
      : NodeBase(hgt, sz, p, l, r), value(val) {}
    ValueType value;
  };

public:
  /// Bidirectional in-order iterator. Dereferencing end() is not allowed
  /// because the header carries no value.
  class Iterator {
  public:
    Iterator() : baseptr_(nullptr) {}
    explicit Iterator(NodeBase* x) : baseptr_(x) {}

    ValueType& operator*() const { return static_cast<Node*>(baseptr_)->value; }
    ValueType* operator->() const { return &(operator*()); }

    Iterator& operator--() {
      baseptr_ = AVLTree::prev_node(baseptr_);
      return *this;
    }

    Iterator operator--(int) {
      Iterator tmp = *this;
      --*this;
      return tmp;
    }

    Iterator& operator++() {
      baseptr_ = AVLTree::next_node(baseptr_);
      return *this;
    }

    Iterator operator++(int) {
      Iterator tmp = *this;
      ++*this;
      return tmp;
    }

    friend bool operator==(const Iterator& x, const Iterator& y) {
      return x.baseptr_ == y.baseptr_;
    }

    friend bool operator!=(const Iterator& x, const Iterator& y) {
      return x.baseptr_ != y.baseptr_;
    }

  protected:
    friend AVLTree;
    NodeBase* baseptr_;
  };

public:
  /// Creates an empty tree.
  AVLTree()
    : header_(new NodeBase(0, 0U, nullptr, nullptr, nullptr))
    , value_compare_(ValueCompare()) {
    header_->left = header_;
    header_->right = header_;
  }

  /// Builds a tree from an initializer list, keeping duplicates.
  /// Delegation makes the destructor run if an insertion throws.
  AVLTree(std::initializer_list<ValueType> lst) : AVLTree() {
    for (const ValueType& val : lst) insert_equal(val);
  }

  /// Deep copy. Elements are re-inserted in in-order sequence, which keeps the
  /// relative order of equivalent elements. Without this, the implicit copy
  /// would share the header and both destructors would free the same nodes.
  AVLTree(const AVLTree& other) : AVLTree() {
    for (Iterator it = other.begin(); it != other.end(); ++it) insert_equal(*it);
  }

  /// Move: takes over other's nodes and leaves other as a valid empty tree.
  AVLTree(AVLTree&& other) : AVLTree() { swap(other); }

  /// Copy/move assignment via copy-and-swap.
  AVLTree& operator=(AVLTree other) {
    swap(other);
    return *this;
  }

  ~AVLTree() {
    clear();
    delete header_;
  }

  /// First element, last element, and the past-the-end position. All three
  /// are equal for an empty tree.
  Iterator begin() const { return Iterator(header_->left); }
  Iterator last() const { return Iterator(header_->right); }
  Iterator end() const { return Iterator(header_); }

  bool empty() const { return header_->parent == nullptr; }
  SizeType size() const { return empty() ? 0U : header_->parent->size; }

  /// Inserts val unless an equivalent element exists.
  /// @return {iterator to the new element, true}, or
  ///         {iterator to the existing equivalent element, false}.
  std::pair<Iterator, bool> insert_unique(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    bool cmp = true;
    while (x != nullptr) {
      y = x;
      cmp = value_compare_(val, static_cast<Node*>(x)->value);
      x = cmp ? x->left : x->right;  // equal elements go right; duplicates are detected below
    }
    // The only possible duplicate is the in-order predecessor of the
    // insertion point: y itself when going right, y's predecessor when going left.
    NodeBase* z = y;
    if (cmp) z = prev_node(z);
    if (z != header_ && !value_compare_(static_cast<Node*>(z)->value, val))
      return std::pair<Iterator, bool>(Iterator(z), false);
    return std::pair<Iterator, bool>(Iterator(attach_leaf(y, cmp, val)), true);
  }

  /// Inserts val, allowing equivalent elements; a new element is placed after
  /// all existing equivalent ones. Always returns {iterator, true}.
  /// Nothing in the tree is modified before the node has been allocated, so a
  /// throwing allocation or copy leaves the tree unchanged.
  std::pair<Iterator, bool> insert_equal(const ValueType& val) {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    bool cmp = true;
    while (x != nullptr) {
      y = x;
      cmp = value_compare_(val, static_cast<Node*>(x)->value);
      x = cmp ? x->left : x->right;
    }
    return std::pair<Iterator, bool>(Iterator(attach_leaf(y, cmp, val)), true);
  }

  /// First element equivalent to key, or end().
  Iterator find(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key)) x = x->right;
      else y = x, x = x->left;
    }
    return (y == header_ || value_compare_(key, static_cast<Node*>(y)->value))
           ? Iterator(header_) : Iterator(y);
  }

  /// First element not less than key, or end().
  Iterator lower_bound(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key)) x = x->right;
      else y = x, x = x->left;
    }
    return Iterator(y);
  }

  /// First element that key is less than, or end().
  Iterator upper_bound(const Key& key) const {
    NodeBase* x = header_->parent;
    NodeBase* y = header_;
    while (x != nullptr) {
      if (value_compare_(key, static_cast<Node*>(x)->value)) y = x, x = x->left;
      else x = x->right;
    }
    return Iterator(y);
  }

  /// Element with the given 0-based rank, or end() if order >= size().
  Iterator find_by_order(SizeType order) const {
    NodeBase* x = header_->parent;
    SizeType lsize{};
    while (x != nullptr) {
      lsize = (x->left == nullptr ? 0U : x->left->size);
      if (order < lsize) x = x->left;
      else if (order == lsize) return Iterator(x);
      else order -= lsize + 1, x = x->right;
    }
    return Iterator(header_);
  }

  /// Number of elements less than key (the 0-based rank of key).
  SizeType order_of_key(const Key& key) const {
    SizeType res = 0U;
    NodeBase* x = header_->parent;
    while (x != nullptr) {
      if (value_compare_(static_cast<Node*>(x)->value, key))
        res += (x->left == nullptr ? 1U : x->left->size + 1U), x = x->right;
      else x = x->left;
    }
    return res;
  }

  /// Erases the element at pos (must not be end()) and returns an iterator to
  /// its in-order successor.
  Iterator erase(Iterator pos) {
    assert(pos.baseptr_ != header_);
    NodeBase* x = (pos++).baseptr_;
    if (x->left != nullptr && x->right != nullptr) {
      // Two children: trade places with the successor, after which x has at
      // most one (right) child and is handled by the common path below.
      swap_with_successor(x, pos.baseptr_);
    } else if (x->parent == header_) {
      // Root with at most one child. In an AVL tree that child is a leaf, so
      // it becomes root, leftmost and rightmost at once.
      NodeBase* child = (x->left != nullptr ? x->left : x->right);
      if (child != nullptr) {
        child->parent = header_;  // do not leave it pointing at the deleted root
        header_->parent = child;
        header_->left = child;
        header_->right = child;
      } else {
        header_->parent = nullptr;
        header_->left = header_;
        header_->right = header_;
      }
      delete x;
      return pos;
    }
    // Common path: x is not the root and has at most one child.
    NodeBase* y = (x->left != nullptr ? x->left : x->right);
    if (x == x->parent->left) x->parent->left = y;
    else x->parent->right = y;
    if (y != nullptr) y->parent = x->parent;
    else y = x->parent;
    // Only touch the extreme pointers when x really was an extreme element.
    if (header_->left == x) header_->left = y;
    else if (header_->right == x) header_->right = y;
    // backtrack() refreshes nodes until a height stops changing; the remaining
    // ancestors only need their subtree size reduced by one.
    y = backtrack(x->parent);
    while (y != header_) {
      --y->size;
      y = y->parent;
    }
    delete x;
    return pos;
  }

  /// Erases every element equivalent to key; returns the number removed.
  SizeType erase(const Key& key) {
    SizeType res = 0U;
    Iterator it = find(key);
    while (it.baseptr_ != header_ && !value_compare_(key, *it)) {
      it = erase(it);
      ++res;
    }
    return res;
  }

  /// Deletes all elements with an iterative post-order walk; the header stays.
  void clear() {
    NodeBase* x = header_->left;
    if (x == header_) return;  // empty tree: the loop below would not terminate
    while (x->left != nullptr || x->right != nullptr)
      x = x->left != nullptr ? x->left : x->right;
    NodeBase* y = x;  // node to delete next
    header_->parent = nullptr;
    header_->left = header_->right = header_;
    while (x != header_) {
      x = x->parent;
      if (y == x->left && x->right != nullptr) {
        // Left subtree done: continue with the deepest leaf of the right subtree.
        x = x->right;
        while (x->left != nullptr || x->right != nullptr)
          x = x->left != nullptr ? x->left : x->right;
      }
      delete y;
      y = x;
    }
  }

  /// O(1) exchange of contents: only the header pointers are swapped.
  void swap(AVLTree& tr) {
    NodeBase* tmp = header_;
    header_ = tr.header_;
    tr.header_ = tmp;
  }

protected:
  /// Allocates a leaf holding val, links it under `parent` (left side when
  /// as_left is true; `parent == header_` means the tree is empty), keeps the
  /// leftmost/rightmost pointers current, and rebalances. Returns the leaf.
  NodeBase* attach_leaf(NodeBase* parent, bool as_left, const ValueType& val) {
    NodeBase* leaf = new Node(1, 1U, parent, nullptr, nullptr, val);
    if (parent == header_) {
      parent->parent = leaf;
      parent->left = leaf;
      parent->right = leaf;
    } else if (as_left) {
      parent->left = leaf;
      if (header_->left == parent) header_->left = leaf;
    } else {
      parent->right = leaf;
      if (header_->right == parent) header_->right = leaf;
    }
    // Ancestors above the point where backtrack() stopped only gain one element.
    for (NodeBase* up = backtrack(parent); up != header_; up = up->parent) ++up->size;
    return leaf;
  }

  /// In-order predecessor; the predecessor of the header is the last element.
  static NodeBase* prev_node(NodeBase* x) {
    if (x->size == 0U) {
      x = x->right;
    } else if (x->left != nullptr) {
      x = x->left;
      while (x->right != nullptr) x = x->right;
    } else {
      while (x == x->parent->left) x = x->parent;
      if (x->left != x->parent) x = x->parent;
    }
    return x;
  }

  /// In-order successor; the successor of the header is the first element.
  static NodeBase* next_node(NodeBase* x) {
    if (x->size == 0U) {
      x = x->left;
    } else if (x->right != nullptr) {
      x = x->right;
      while (x->left != nullptr) x = x->left;
    } else {
      while (x == x->parent->right) x = x->parent;
      if (x->right != x->parent) x = x->parent;
    }
    return x;
  }

  /// Balance factor: left height minus right height (missing child = 0).
  static HgtType factor(const NodeBase* const x) {
    return (x->left == nullptr ? 0 : x->left->height)
         - (x->right == nullptr ? 0 : x->right->height);
  }

  /// Recomputes x's height and subtree size from its children.
  static void update_info(NodeBase* const x) {
    x->height = 1;
    x->size = 1U;
    if (x->left != nullptr) {
      x->height += x->left->height;
      x->size += x->left->size;
    }
    if (x->right != nullptr) {
      x->height = std::max(x->height, x->right->height + 1);
      x->size += x->right->size;
    }
  }

  /// Right rotation (LL case): x's left child becomes the subtree root.
  static NodeBase* rotate_right(NodeBase* const x) {
    NodeBase* const y = x->left;
    x->left = y->right;
    if (y->right != nullptr) y->right->parent = x;
    y->parent = x->parent;
    // size == 0 marks the header, whose `parent` field holds the root.
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->right) x->parent->right = y;
    else x->parent->left = y;
    y->right = x;
    x->parent = y;
    update_info(x), update_info(y);
    return y;
  }

  /// Left rotation (RR case): x's right child becomes the subtree root.
  static NodeBase* rotate_left(NodeBase* const x) {
    NodeBase* const y = x->right;
    x->right = y->left;
    if (y->left != nullptr) y->left->parent = x;
    y->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = y;
    else if (x == x->parent->left) x->parent->left = y;
    else x->parent->right = y;
    y->left = x;
    x->parent = y;
    update_info(x), update_info(y);
    return y;
  }

  /// Double rotation for the LR case in one step: z = x->left->right becomes
  /// the subtree root with x->left on its left and x on its right.
  static NodeBase* rotate_left_then_right(NodeBase* const x) {
    NodeBase* y = x->left;
    NodeBase* z = y->right;
    x->left = z->right;
    y->right = z->left;
    if (z->right != nullptr) z->right->parent = x;
    if (z->left != nullptr) z->left->parent = y;
    z->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->right) x->parent->right = z;
    else x->parent->left = z;
    z->right = x;
    z->left = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);
    return z;
  }

  /// Double rotation for the RL case in one step: z = x->right->left becomes
  /// the subtree root with x on its left and x->right on its right.
  static NodeBase* rotate_right_then_left(NodeBase* const x) {
    NodeBase* y = x->right;
    NodeBase* z = y->left;
    x->right = z->left;
    y->left = z->right;
    if (z->left != nullptr) z->left->parent = x;
    if (z->right != nullptr) z->right->parent = y;
    z->parent = x->parent;
    if (x->parent->size == 0U) x->parent->parent = z;
    else if (x == x->parent->left) x->parent->left = z;
    else x->parent->right = z;
    z->left = x;
    z->right = y;
    x->parent = y->parent = z;
    update_info(x), update_info(y), update_info(z);
    return z;
  }

  /// Exchanges the positions of x (which must have two children) and its
  /// in-order successor y by relinking. y inherits x's height and size; x's
  /// own height/size are left stale because x is about to be removed.
  static void swap_with_successor(NodeBase* x, NodeBase* y) {
    y->height = x->height;
    y->size = x->size;
    y->left = x->left;
    x->left->parent = y;
    x->left = nullptr;
    if (y == x->right) {
      // y is x's direct right child: x drops below y.
      x->right = y->right;
      if (y->right != nullptr) y->right->parent = x;
      y->parent = x->parent;
      y->right = x;
      x->parent = y;
    } else {
      // y is deeper in x's right subtree: swap parents and right children.
      NodeBase* tmp = x->parent;
      x->parent = y->parent;
      y->parent = tmp;
      tmp = x->right;
      x->right = y->right;
      y->right = tmp;
      x->parent->left = x;
      if (x->right != nullptr) x->right->parent = x;
      y->right->parent = y;
    }
    if (y->parent->size == 0U) y->parent->parent = y;
    else if (x == y->parent->left) y->parent->left = y;
    else y->parent->right = y;
  }

  /// Walks upward from x, rebalancing and refreshing height/size, and stops at
  /// the first node whose height did not change. Returns that node's parent
  /// (the first ancestor NOT refreshed), or the header if the walk reached it.
  static NodeBase* backtrack(NodeBase* x) {
    for (HgtType orig_hgt{}, bal_factor{}; x->size != 0U; x = x->parent) {
      orig_hgt = x->height;
      bal_factor = factor(x);
      switch (bal_factor) {
        case 2: {
          x = factor(x->left) < 0
            ? rotate_left_then_right(x)
            : rotate_right(x);
        } break;
        case -2: {
          x = factor(x->right) > 0
            ? rotate_right_then_left(x)
            : rotate_left(x);
        } break;
        default: update_info(x);
      }
      if (x->height == orig_hgt) return x->parent;
    }
    return x;
  }

protected:
  NodeBase* header_;             // value-less sentinel; see the class comment for its layout
  ValueCompare value_compare_;
};

}  // namespace agl

#endif  // AVL_TREE_HPP