洛谷 P3369 【模板】普通平衡树

发布时间：2023-05-05 14:29:29　作者：__Accelerator

有旋Treap模板

#include <bits/stdc++.h>

using namespace std;

// One treap node. Equal keys are not stored as separate nodes; instead a
// single node carries a repetition counter.
struct Node {
	Node *ch[2] = {nullptr, nullptr};	// ch[0]: left child, ch[1]: right child
	int val, rank;				// val: stored key, rank: heap priority drawn from rand()
	int rep_cnt = 1;			// number of copies of val held by this node
	int siz = 1;				// total number of keys (copies included) in this subtree

	// Builds a leaf node holding one copy of `val` with a freshly drawn priority.
	Node(int val) : val(val), rank(rand()) {}

	// Recomputes siz from this node's own copies plus the sizes of its
	// existing children. The children's siz fields must already be up to date.
	void upd_siz() {
		int total = rep_cnt;
		for (Node *child : ch) {
			if (child != nullptr) total += child->siz;
		}
		siz = total;
	}
};

// Rotation direction. The numeric value doubles as the index (into Node::ch)
// of the child that is lifted to become the new subtree root, so this must
// stay an unscoped enum with exactly these values.
enum rot_type {
	RT = 0,	// right rotation: the left child (ch[0]) moves up
	LF = 1	// left rotation: the right child (ch[1]) moves up
};

// Rotates the subtree rooted at `cur` so that its child `cur->ch[dir]`
// becomes the new subtree root and the old root becomes that node's
// `!dir` child. `cur` is a reference to the parent's link, so the parent
// (or the tree root pointer) is re-pointed in place.
//
// Preconditions: `cur` and `cur->ch[dir]` are non-null.
// Postcondition: the siz fields of both nodes involved are correct.
void _rotate(Node *&cur, rot_type dir) {
	Node *lifted = cur->ch[dir];

	// The lifted node hands its inner subtree to the old root, then adopts
	// the old root as its own child.
	cur->ch[dir] = lifted->ch[!dir];
	lifted->ch[!dir] = cur;

	// Sizes must be rebuilt bottom-up: the old root is now the lower node,
	// so it has to be refreshed before the node that sits above it.
	cur->upd_siz();
	lifted->upd_siz();

	cur = lifted;
}

// Inserts one copy of `val` into the treap rooted at `cur`.
//
// If the key already exists, only its repetition counter (and size) grows.
// Otherwise a new leaf is created and rotated upward while its rank is
// smaller than its parent's, i.e. the tree is kept a min-heap on rank (the
// same ordering _del relies on when choosing which child to lift).
// `cur` is a reference to the parent's link so the subtree root can be
// replaced in place.
void _insert(Node *&cur, int val) {
	if (cur == nullptr) {
		cur = new Node(val);
		return;
	}
	if (cur->val == val) {
		cur->rep_cnt++;
		cur->siz++;
		return;
	}

	if (val < cur->val) {
		_insert(cur->ch[0], val);
		// Restore the min-heap property if the left child now outranks us.
		if (cur->ch[0]->rank < cur->rank) {
			_rotate(cur, RT);
		}
	}
	else {
		_insert(cur->ch[1], val);
		// Same min-heap rule on the right side.
		if (cur->ch[1]->rank < cur->rank) {
			_rotate(cur, LF);
		}
	}
	// Whether or not a rotation happened, the subtree under `cur` gained one
	// key, so its size has to be refreshed.
	cur->upd_siz();
}

// Removes one copy of `val` from the treap rooted at `cur`.
//
// If `val` is not present (including the empty-tree case) the tree is left
// unchanged. If the key is stored more than once, only its repetition
// counter is decremented. Otherwise the node is unlinked and freed:
//   - with at most one child, that child (or nullptr) takes its place;
//   - with two children, the child with the smaller rank is rotated up
//     (keeping the min-heap order on rank) and deletion continues in the
//     subtree the doomed node was pushed into.
void _del(Node *&cur, int val) {
	// Ran off the tree: the value is not stored, nothing to delete.
	if (cur == nullptr) return;

	if (cur->val != val) {
		// Index 1 (right) when val is larger than this key, 0 (left) otherwise.
		_del(cur->ch[cur->val < val], val);
		cur->upd_siz();
		return;
	}

	if (cur->rep_cnt > 1) {
		cur->rep_cnt--;
		cur->siz--;
		return;
	}

	Node *left = cur->ch[0];
	Node *right = cur->ch[1];

	if (left != nullptr && right != nullptr) {
		// Lift the child with the smaller rank, then keep chasing the doomed
		// node, which now sits on the opposite side of the new subtree root.
		rot_type dir = left->rank < right->rank ? RT : LF;
		_rotate(cur, dir);
		_del(cur->ch[!dir], val);
		cur->upd_siz();
		return;
	}

	// Zero or one child: splice the node out. The promoted child's subtree is
	// untouched, so its size is already correct.
	Node *doomed = cur;
	cur = (left != nullptr) ? left : right;
	delete doomed;
}

// Returns the 1-based rank of `val` in the treap rooted at `cur`: one plus
// the number of stored keys (copies included) that are strictly smaller
// than `val`. `val` does not have to be present; an empty tree yields 1.
int _query_rank(Node *&cur, int val) {
	int smaller = 0;	// keys known to be < val on the path walked so far
	Node *node = cur;

	while (node != nullptr) {
		const int left_cnt = (node->ch[0] != nullptr) ? node->ch[0]->siz : 0;

		if (node->val == val) {
			return smaller + left_cnt + 1;
		}

		if (node->val > val) {
			// Everything at this node and to its right is too large.
			node = node->ch[0];
			continue;
		}

		// node->val < val: this node and its whole left subtree are smaller.
		if (node->ch[1] == nullptr) {
			return smaller + node->siz + 1;
		}
		smaller += left_cnt + node->rep_cnt;
		node = node->ch[1];
	}

	return smaller + 1;
}

// Returns the key at 1-based position `rank` in sorted order (copies
// counted individually) within the treap rooted at `cur`. Returns 0 when
// the walk falls off the tree, e.g. for an empty tree or an out-of-range
// rank.
int _query_val(Node *&cur, int rank) {
	Node *node = cur;

	while (node != nullptr) {
		const int left_cnt = (node->ch[0] != nullptr) ? node->ch[0]->siz : 0;

		if (rank <= left_cnt) {
			// The wanted position lies inside the left subtree.
			node = node->ch[0];
			continue;
		}

		const int through_here = left_cnt + node->rep_cnt;
		if (rank <= through_here) {
			return node->val;
		}

		// Skip the left subtree and this node's copies, then search right.
		rank -= through_here;
		node = node->ch[1];
	}

	return 0;
}

// File-scope slot that _query_prev overwrites with every predecessor
// candidate it encounters. Kept (and still written) so that any code reading
// it after a query keeps working.
int q_pre_tmp;

// Returns the predecessor of `val` in the treap rooted at `cur`: the largest
// stored key strictly less than `val`. Returns -1 when no such key exists,
// which includes the empty-tree case (`cur == nullptr`).
int _query_prev(Node *cur, int val) {
	int best = -1;

	while (cur != nullptr) {
		if (cur->val < val) {
			// Valid candidate. Anything better must be larger, so look right.
			// Later candidates on this path are always larger than earlier
			// ones, so simply overwriting keeps the maximum.
			best = cur->val;
			q_pre_tmp = best;
			cur = cur->ch[1];
		}
		else {
			// Too large (or equal): only the left subtree can hold candidates.
			cur = cur->ch[0];
		}
	}

	return best;
}

// File-scope slot that _query_sufv overwrites with every successor candidate
// it encounters. Kept (and still written) so that any code reading it after
// a query keeps working.
int q_suf_tmp;

// Returns the successor of `val` in the treap rooted at `cur`: the smallest
// stored key strictly greater than `val`. Returns -1 when no such key
// exists, which includes the empty-tree case (`cur == nullptr`).
int _query_sufv(Node *cur, int val) {
	int best = -1;

	while (cur != nullptr) {
		if (cur->val > val) {
			// Valid candidate. Anything better must be smaller, so look left.
			// Later candidates on this path are always smaller than earlier
			// ones, so simply overwriting keeps the minimum.
			best = cur->val;
			q_suf_tmp = best;
			cur = cur->ch[0];
		}
		else {
			// Too small (or equal): only the right subtree can hold candidates.
			cur = cur->ch[1];
		}
	}

	return best;
}

// Driver: reads the number of operations, then one "op x" pair per
// operation, and dispatches on op:
//   1 insert x          2 delete one copy of x
//   3 print rank of x   4 print the key at rank x
//   5 print predecessor of x   6 print successor of x
// Each query result is printed on its own line. Unknown op codes are ignored.
int main() {
	// Input is fully batch-style, so untie/unsync the streams and avoid
	// flushing after every answer.
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);

	int n = 0;
	std::cin >> n;

	Node *root = nullptr;
	int op = 0;
	int x = 0;
	// Stop early on truncated or malformed input instead of acting on
	// values that were never read; a non-positive n runs zero operations.
	while (n-- > 0 && (std::cin >> op >> x)) {
		switch (op) {
			case 1:
				_insert(root, x);
				break;
			case 2:
				_del(root, x);
				break;
			case 3:
				std::cout << _query_rank(root, x) << '\n';
				break;
			case 4:
				std::cout << _query_val(root, x) << '\n';
				break;
			case 5:
				std::cout << _query_prev(root, x) << '\n';
				break;
			case 6:
				std::cout << _query_sufv(root, x) << '\n';
				break;
			default:
				break;
		}
	}
	return 0;
}