Treap

发布时间 2023-12-16 20:44:51作者: 会有续命晴空

前置:BST

\(Treap=BST+heap\),通过额外维护堆的性质来避免退化成链的问题。

\(BST\)中结构释义相同的部分不做解释。

\(priority\)表示优先级,即该节点在小根堆中的值,\(child[0]\)表示左孩子,\(child[1]\)表示右孩子。

\(public\)中,\(empty\):时间复杂度\(O(1)\)

其余操作时间复杂度均为\(O(logn)\)

class Treap
{
private:
    int INF = 0x3f3f3f3f;
    struct Node
    {
        int key, size = 1, count = 1, priority = rand();
        Node *child[2] = {nullptr, nullptr};
        Node *&left = child[0], *&right = child[1];
        Node(int key) : key(key) {}
        void UpdSize()
        {
            size = count + (left ? left->size : 0) + (right ? right->size : 0);
        }
    };

    Node *root = nullptr;

    enum rot_type{LF = 1, RT = 0}; // 左旋和右旋
    void rotate(Node *&root, rot_type dir) // 旋转以维护小根堆的性质
    {
        Node *tmp = root->child[dir];
        root->child[dir] = tmp->child[!dir];
        tmp->child[!dir] = root;
        root->UpdSize(), tmp->UpdSize();
        root = tmp;
    }

    Node *FindMinNode(Node *root) // 在以root为根的子树中查找最小节点
    {
        if (!root) return root;
        while (root->left) root = root->left;
        return root;
    }

    Node *FindMaxNode(Node *root) // 在以root为根的子树中查找最大节点
    {
        if (!root) return root;
        while (root->right) root = root->right;
        return root;
    }

    int FindMin(Node *root) // 在以root为根的子树中查找最小键值
    {
        root = FindMinNode(root);
        return (root ? root->key : INF);
    }

    int FindMax(Node *root) // 在以root为根的子树中查找最大键值
    {
        root = FindMaxNode(root);
        return (root ? root->key : INF);
    }

    int count(Node *root, int val) // 在以root为根的子树中查找键值为val的数量
    {
        if (!root) return 0;
        if (root->key == val) return root->count;
        else if (val < root->key) return count(root->left, val);
        else return count(root->right, val);
    }

    void insert(Node *&root, int &val) // 在以root为根的子树中插入键值为val的节点
    {
        if (!root) root = new Node(val);
        else if (val < root->key) 
        {
            insert(root->left, val);
            if (root->left->priority < root->priority) rotate(root, RT);
        }
        else if (val > root->key)
        {
            insert(root->right, val);
            if (root->right->priority < root->priority) rotate(root, LF);
        }
        else root->count++;
        if (root) root->UpdSize();
    }

    void erase(Node *&root, int val, int cnt) // 在以root为根的子树中删除cnt个键值为val的节点
    {
        if (val < root->key) erase(root->left, val, cnt);
        else if (val > root->key) erase(root->right, val, cnt);
        else 
        {
            if (root->count > cnt) root->count -= cnt;
            else 
            {
                Node *tmp = root;
                if (!root->left)
                {
                    root = tmp->right;
                    delete tmp;
                } 
                else if (!root->right) 
                {
                    root = tmp->left;
                    delete tmp;
                } 
                else 
                {
                    rot_type dir = (root->left->priority < root->right->priority ? RT : LF);
                    rotate(root, dir);
                    erase(root->child[!dir], val, cnt);
                }
            }
        }
        if (root) root->UpdSize();
    }

    int rank(Node *root, int val) // 计算键值为val的排名(排名定义为比当前键值小的键值的个数加一)
    {
        if (!root) return 1;
        if (val == root->key) return (root->child[0] ? root->child[0]->size : 0) + 1;
        if (val < root->key) return rank(root->child[0], val);
        return rank(root->child[1], val) + root->count + (root->child[0] ? root->child[0]->size : 0);
    }

    int kth(Node *root, int k) // 计算排名为k的键值
    {
        if (!root) return INF;
        int leftSize = (root->child[0] ? root->child[0]->size : 0);
        if (leftSize >= k) return kth(root->child[0], k);
        if (leftSize + root->count >= k) return root->key;
        return kth(root->child[1], k - leftSize - root->count);
    }
public:
    bool empty(){return !root;}
    int Max(){return (root ? FindMax(root) : INF);}
    int Min(){return (root ? FindMin(root) : INF);}
    void insert(int val){insert(root, val);}
    int count(int val){return count(root, val);}
    void erase(int val, int cnt = 1){cnt = min(cnt, count(val)); if (cnt) erase(root, val, cnt);}
    int rank(int val){return rank(root, val);}
    int kth(int k){return kth(root, k);}
    int prev(int val){return kth(rank(val) - 1);}
    int next(int val){return kth(rank(val + 1));}
} Tp;