二叉搜索树(BST)

发布时间 2023-12-15 14:34:50作者: 会有续命晴空

\(INF\)表示无效值的常量,\(Node\)表示二叉搜索树的结点,\(key、count、left、right\)分别表示节点的键值、数量、左孩子、右孩子,\(size\)表示以该节点为根的子树大小,\(root\)表示树的根节点。

\(private\)中的函数解析见注释。

\(public\)中:\(empty\):判断树是否为空,时间复杂度:\(O(1)\)

以下操作时间复杂度相同,均为期望\(O(logn)\),最坏\(O(n)\)

\(Max\):求树中最大键值。 \(Min\):求树中最小键值。 \(insert(val)\):向树中插入键值为\(val\)的节点。

\(count(val)\):求键值为\(val\)的节点数量。 \(erase(val, cnt)\):删除\(cnt\)个键值为\(val\)的节点。

\(rank(val)\):求键值\(val\)的排名。(排名定义为比当前键值小的键值的个数加一) \(kth(val)\):求第\(k\)小的键值。

\(prev(val)\):求键值\(val\)的前驱。(前驱定义为小于\(val\),且最大的键值)

\(next(val)\):求键值\(val\)的后继。(后继定义为大于\(val\),且最小的键值)

class BinarySearchTree
{
private:
    int INF = 0x3f3f3f3f;
    struct Node 
    {
        int key, size = 1, count = 1;
        Node *left = nullptr, *right = nullptr;
        Node(int key) : key(key) {}
    };

    Node *root = nullptr;

    Node *FindMinNode(Node *root) // 在以root为根的子树中查找最小节点
    {
        if (!root) return root;
        while (root->left) root = root->left;
        return root;
    }

    Node *FindMaxNode(Node *root) // 在以root为根的子树中查找最大节点
    {
        if (!root) return root;
        while (root->right) root = root->right;
        return root;
    }

    int FindMin(Node *root) // 在以root为根的子树中查找最小键值
    {
        root = FindMinNode(root);
        return (root ? root->key : INF);
    }

    int FindMax(Node *root) // 在以root为根的子树中查找最大键值
    {
        root = FindMaxNode(root);
        return (root ? root->key : INF);
    }

    int count(Node *root, int val) // 在以root为根的子树中查找键值为val的数量
    {
        if (!root) return 0;
        if (root->key == val) return root->count;
        else if (val < root->key) return count(root->left, val);
        else return count(root->right, val);
    }

    Node *insert(Node *root, int val) // 在以root为根的子树中插入键值为val的节点
    {
        if (!root) return new Node(val);
        root->size++;
        if (val < root->key) root->left = insert(root->left, val);
        else if (val > root->key) root->right = insert(root->right, val);
        else root->count++;
        return root;
    }

    Node *erase(Node *root, int val, int cnt) // 在以root为根的子树中删除cnt个键值为val的节点
    {
        if (!root) return root;
        root->size -= cnt;
        if (val < root->key) root->left = erase(root->left, val, cnt);
        else if (val > root->key) root->right = erase(root->right, val, cnt);
        else 
        {
            if (root->count > cnt) root->count -= cnt;
            else 
            {
                if (!root->left)
                {
                    Node *temp = root->right;
                    delete root;
                    return temp;
                } 
                else if (!root->right) 
                {
                    Node *temp = root->left;
                    delete root;
                    return temp;
                } 
                else 
                {
                    Node *successor = FindMinNode(root->right);
                    root->key = successor->key;
                    root->count = successor->count;
                    root->right = erase(root->right, successor->key, successor->count);
                }
            }
        }
        return root;
    }

    int rank(Node *root, int val) // 计算键值为val的排名
    {
        if (!root) return 1;
        if (val == root->key) return (root->left ? root->left->size : 0) + 1;
        if (val < root->key) return rank(root->left, val);
        return rank(root->right, val) + root->count + (root->left ? root->left->size : 0);
    }

    int kth(Node *root, int k) // 计算排名为k的键值
    {
        if (!root) return INF;
        int leftSize = (root->left ? root->left->size : 0);
        if (leftSize >= k) return kth(root->left, k);
        if (leftSize + root->count >= k) return root->key;
        return kth(root->right, k - leftSize - root->count);
    }
public:
    bool empty(){return !root;}
    int Max(){return (root ? FindMax(root) : INF);}
    int Min(){return (root ? FindMin(root) : INF);}
    void insert(int val){root = insert(root, val);}
    int count(int val){return count(root, val);}
    void erase(int val, int cnt = 1){cnt = min(cnt, count(val)); if (cnt) root = erase(root, val, cnt);}
    int rank(int val){return rank(root, val);}
    int kth(int k){return kth(root, k);}
    int prev(int val){return kth(rank(val) - 1);}
    int next(int val){return kth(rank(val + 1));}
} BST;