查看: 31|回复: 0

洛谷P3369题解:Treap数据结构从入门到精通

[复制链接]
  • TA的每日心情
    无聊
    昨天 10:23
  • 签到天数: 52 天

    [LV.5]常住居民I

    51

    主题

    0

    回帖

    86

    积分

    注册会员

    Rank: 2

    积分
    86
    发表于 2025-7-21 11:10:58 | 显示全部楼层 |阅读模式

    1. #include <iostream>
    2. #include <cstdlib>
    3. #include <climits>
    4. using namespace std;

    5. const int INF = 1e8; // 处理大数值范围

    6. struct Node {
    7.     int val, size, cnt;
    8.     int priority;
    9.     Node *l, *r;
    10.     Node(int v) : val(v), size(1), cnt(1), l(nullptr), r(nullptr) {
    11.         priority = rand() % INF; // 随机优先级保证平衡
    12.     }
    13. };

    14. class Treap {
    15. private:
    16.     Node *root;
    17.    
    18.     // 更新节点子树大小
    19.     void update(Node *node) {
    20.         if(!node) return;
    21.         node->size = node->cnt;
    22.         if(node->l) node->size += node->l->size;
    23.         if(node->r) node->size += node->r->size;
    24.     }
    25.    
    26.     // 左旋操作
    27.     void rotateLeft(Node *&node) {
    28.         Node *temp = node->r;
    29.         node->r = temp->l;
    30.         temp->l = node;
    31.         update(node);
    32.         update(temp);
    33.         node = temp;
    34.     }
    35.    
    36.     // 右旋操作
    37.     void rotateRight(Node *&node) {
    38.         Node *temp = node->l;
    39.         node->l = temp->r;
    40.         temp->r = node;
    41.         update(node);
    42.         update(temp);
    43.         node = temp;
    44.     }
    45.    
    46.     // 插入操作
    47.     void insert(Node *&node, int val) {
    48.         if(!node) {
    49.             node = new Node(val);
    50.             return;
    51.         }
    52.         if(val == node->val) {
    53.             node->cnt++; // 重复值计数
    54.         } else if(val < node->val) {
    55.             insert(node->l, val);
    56.             if(node->l->priority > node->priority)
    57.                 rotateRight(node); // 维护堆性质
    58.         } else {
    59.             insert(node->r, val);
    60.             if(node->r->priority > node->priority)
    61.                 rotateLeft(node); // 维护堆性质
    62.         }
    63.         update(node);
    64.     }
    65.    
    66.     // 删除操作
    67.     void remove(Node *&node, int val) {
    68.         if(!node) return;
    69.         if(val < node->val) {
    70.             remove(node->l, val);
    71.         } else if(val > node->val) {
    72.             remove(node->r, val);
    73.         } else {
    74.             if(node->cnt > 1) {
    75.                 node->cnt--; // 减少计数
    76.             } else {
    77.                 if(!node->l || !node->r) {
    78.                     Node *temp = node->l ? node->l : node->r;
    79.                     delete node;
    80.                     node = temp; // 单子树情况直接替换
    81.                 } else {
    82.                     // 选择优先级高的子树旋转
    83.                     if(node->l->priority > node->r->priority) {
    84.                         rotateRight(node);
    85.                         remove(node->r, val);
    86.                     } else {
    87.                         rotateLeft(node);
    88.                         remove(node->l, val);
    89.                     }
    90.                 }
    91.             }
    92.         }
    93.         if(node) update(node);
    94.     }
    95.    
    96.     // 获取排名
    97.     int getRank(Node *node, int val) {
    98.         if(!node) return 0;
    99.         if(val < node->val) return getRank(node->l, val);
    100.         int leftSize = node->l ? node->l->size : 0;
    101.         if(val == node->val) return leftSize + 1;
    102.         return leftSize + node->cnt + getRank(node->r, val);
    103.     }
    104.    
    105.     // 根据排名获取值
    106.     int getValue(Node *node, int rank) {
    107.         if(!node) return INF;
    108.         int leftSize = node->l ? node->l->size : 0;
    109.         if(rank <= leftSize) return getValue(node->l, rank);
    110.         if(rank <= leftSize + node->cnt) return node->val;
    111.         return getValue(node->r, rank - leftSize - node->cnt);
    112.     }
    113.    
    114.     // 获取前驱
    115.     int getPre(Node *node, int val) {
    116.         if(!node) return -INF;
    117.         if(node->val >= val) return getPre(node->l, val);
    118.         return max(node->val, getPre(node->r, val));
    119.     }
    120.    
    121.     // 获取后继
    122.     int getNext(Node *node, int val) {
    123.         if(!node) return INF;
    124.         if(node->val <= val) return getNext(node->r, val);
    125.         return min(node->val, getNext(node->l, val));
    126.     }

    127. public:
    128.     Treap() : root(nullptr) { srand(time(0)); }
    129.    
    130.     // 公开接口
    131.     void insert(int val) { insert(root, val); }
    132.     void remove(int val) { remove(root, val); }
    133.     int getRank(int val) { return getRank(root, val); }
    134.     int getValue(int rank) { return getValue(root, rank); }
    135.     int getPre(int val) { return getPre(root, val); }
    136.     int getNext(int val) { return getNext(root, val); }
    137. };

    138. int main() {
    139.     ios::sync_with_stdio(false);
    140.     cin.tie(0);
    141.    
    142.     Treap treap;
    143.     int n, opt, x;
    144.     cin >> n;
    145.     while(n--) {
    146.         cin >> opt >> x;
    147.         switch(opt) {
    148.             case 1: treap.insert(x); break;
    149.             case 2: treap.remove(x); break;
    150.             case 3: cout << treap.getRank(x) << '\n'; break;
    151.             case 4: cout << treap.getValue(x) << '\n'; break;
    152.             case 5: cout << treap.getPre(x) << '\n'; break;
    153.             case 6: cout << treap.getNext(x) << '\n'; break;
    154.         }
    155.     }
    156.     return 0;
    157. }
    复制代码


    来源:洛谷题解

    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    快速回复 返回顶部 返回列表