#include <iostream>
#include <queue>
#include <vector>
#include <algorithm>
using namespace std;
// Node colors. The numeric values are used directly as indices into color_str.
enum Color { RED = 0, BLACK = 1 };
// One-letter printable labels for Color, indexed by the enum value:
// color_str[RED] == "R", color_str[BLACK] == "B".
string color_str[2] = {string("R"), string("B")};
// A single tree node: its key, its color, and links to its parent and children.
// A freshly constructed node is RED and not linked to any other node.
struct Node {
    int key;
    Color color = Color::RED;
    Node *parent = nullptr;
    Node *left = nullptr;
    Node *right = nullptr;

    Node(int key) : key(key) {}
};
class RedBlack {
Node *root;
int nodeNum;
public:
RedBlack() : root(nullptr), nodeNum(0) {}
void rotateLeft(Node *node)
{
Node *right_child = node->right;
Node *parent = node->parent;
node->right = right_child->left;
if (right_child->left != nullptr) {
right_child->left->parent = node;
}
right_child->parent = parent;
if (parent == nullptr) {
root = right_child;
} else if (node == parent->left) {
parent->left = right_child;
} else {
parent->right = right_child;
}
node->parent = right_child;
right_child->left = node;
}
void rotateRight(Node *node)
{
Node *left_child = node->left;
Node *parent = node->parent;
node->left = left_child->right;
if (left_child->right != nullptr) {
left_child->right->parent = node;
}
left_child->parent = parent;
if (parent == nullptr) {
root = left_child;
} else if (node == parent->left) {
parent->left = left_child;
} else {
parent->right = left_child;
}
node->parent = left_child;
left_child->right = node;
}
void insertFixup(Node *node)
{
while (node->parent != nullptr && node->parent->color == Color::RED) {
if (node->parent == node->parent->parent->left) {
Node *uncle = node->parent->parent->right;
if (uncle != nullptr && uncle->color == Color::RED) {
node->parent->color = Color::BLACK;
uncle->color = Color::BLACK;
node->parent->parent->color = Color::RED;
node = node->parent->parent;
} else {
if (node == node->parent->right) {
node = node->parent;
rotateLeft(node);
}
node->parent->color = Color::BLACK;
node->parent->parent->color = Color::RED;
rotateRight(node->parent->parent);
}
}
else {
Node *uncle = node->parent->parent->left;
if (uncle != nullptr && uncle->color == Color::RED) {
node->parent->color = Color::BLACK;
uncle->color = Color::BLACK;
node->parent->parent->color = Color::RED;
node = node->parent->parent;
} else {
if (node == node->parent->left) {
node = node->parent;
rotateRight(node);
}
node->parent->color = Color::BLACK;
node->parent->parent->color = Color::RED;
rotateLeft(node->parent->parent);
}
}
}
root->color = Color::BLACK;
}
void insert(int key)
{
Node *new_node = new Node(key);
if (root == nullptr){
root = new_node;
root->color = Color::BLACK;
nodeNum++;
return;
}
Node *current = root;
Node *parent = nullptr;
while (current != nullptr){
parent = current;
if (key < current->key){
current = current->left;
} else {
current = current->right;
}
}
new_node->parent = parent;
if (key < parent->key){
parent->left = new_node;
} else {
parent->right = new_node;
}
nodeNum++;
insertFixup(new_node);
}
Node *findKey(int key)
{
Node * cur = root;
while(cur != nullptr && cur->key != key){
cur = (key < cur->key) ? cur->left : cur->right;
}
return cur;
}
void transPlant(Node * src, Node *dest)
{
if (src->parent == nullptr) // root 이면
root = dest;
else if(src == src->parent->left) // parent의 왼쪽 자식이면
src->parent->left = dest;
else
src->parent->right = dest;
if(dest != nullptr)
dest->parent = src->parent;
}
void removeFixup(Node *node)
{
Node *brother = nullptr;
if (node == nullptr)
return;
while (node != root && node->color == Color::BLACK) {
// node가 parent의 왼쪽 자식인 경우
if (node == node->parent->left) {
brother = node->parent->right;
if (brother->left == nullptr)
brother->left = new Node(-1);
if (brother->right == nullptr)
brother->right = new Node(-1);
// case 1
if (brother->color == Color::RED) {
brother->color = Color::BLACK;
node->parent->color = Color::RED;
rotateLeft(node->parent);
brother = node->parent->right;
}
if (brother->left == nullptr && brother->right == nullptr) {
brother->color = Color::RED;
node = node->parent;
continue;
}
// case 2
if (brother->left->color == Color::BLACK && brother->right->color == Color::BLACK) {
brother->color = Color::RED;
node = node->parent;
} else {
// case 3
if (brother->right->color == Color::BLACK){
brother->left->color = Color::BLACK;
brother->color = Color::RED;
rotateRight(brother);
brother = node->parent->right;
}
// case 4
else {
brother->color = node->parent->color;
node->parent->color = Color::BLACK;
brother->right->color = Color::BLACK;
rotateLeft(node->parent);
node = root;
}
}
}
// node가 parent의 오른쪽 자식인 경우
else if (node == node->parent->right)
{
brother = node->parent->left;
if (brother->left == nullptr)
brother->left = new Node(-1);
if (brother->right == nullptr)
brother->right = new Node(-1);
// case 1
if (brother->color == Color::RED){
brother->color = Color::BLACK;
node->parent->color = Color::RED;
rotateRight(node->parent);
brother = node->parent->left;
}
if (brother->left == nullptr && brother->right == nullptr) {
brother->color = Color::RED;
node = node->parent;
continue;
}
// case 2
if (brother->left->color == Color::BLACK && brother->right->color == Color::BLACK) {
brother->color = Color::RED;
node = node->parent;
} else {
// case 3
if (brother->left->color == Color::BLACK) {
brother->right->color = Color::BLACK;
brother->color = Color::RED;
rotateLeft(brother);
brother = node->parent->left;
}
// case 4
else {
brother->color = node->parent->color;
node->parent->color = Color::BLACK;
brother->left->color = Color::BLACK;
rotateRight(node->parent);
node = root;
}
}
}
}
// brother->color = Color::BLACK;
node->color = Color::BLACK;
}
bool remove(int key)
{
Node *delNode = findKey(key);
Node *fixUpNode = delNode;
Color originalColor = delNode->color;
if (delNode == nullptr) // remove node가 없음
return false;
if (delNode->left == nullptr && delNode->right == nullptr)
{
if (originalColor == Color::BLACK)
removeFixup(fixUpNode);
}
else if (delNode->left == nullptr)
{
fixUpNode = delNode->right;
transPlant(delNode, delNode->right);
removeFixup(fixUpNode);
}
else if (delNode->right == nullptr)
{
fixUpNode = delNode->left;
transPlant(delNode, delNode->left);
removeFixup(fixUpNode);
}
else {
Node *successor = delNode->right;
while (successor->left != nullptr)
{
successor = successor->left;
}
fixUpNode = successor;
removeFixup(fixUpNode);
delNode->key = successor->key;
if (successor->right != nullptr)
{
if (successor == successor->parent->left)
successor->parent->left = successor->right;
else
successor->parent->right = successor->right;
}
else
{
if (successor == successor->parent->left)
successor->parent->left = nullptr;
else
successor->parent->right = nullptr;
}
delNode = successor;
}
if(delNode == delNode->parent->left)
delNode->parent->left = nullptr;
if (delNode == delNode->parent->right)
delNode->parent->left = nullptr;
delete delNode;
nodeNum--;
return true;
}
void inorderTraversal()
{
inorderTraversal(root);
}
void inorderTraversal(Node *node)
{
static int dept = 0;
if (node != nullptr)
{
inorderTraversal(node->left);
cout << node->key << "[" << color_str[node->color] << "] / ";
inorderTraversal(node->right);
}
}
void levelOrderTraversal()
{
levelOrderTraversal(root);
}
void levelOrderTraversal(Node *node)
{
vector<vector<Node>> res;
if (node == nullptr)
return;
queue<Node *> q;
q.push(node);
int num = nodeNum;
Node *temp = new Node(-1);
while (!q.empty() && num > 0)
{
int size = q.size();
vector<Node> brothers;
for (int i = 0; i < size; i++)
{
Node *front = q.front();
q.pop();
if (front->key != -1)
num--;
brothers.push_back(*front);
if (front->left != nullptr)
q.push(front->left);
else
q.push(temp);
if (front->right != nullptr)
q.push(front->right);
else
q.push(temp);
}
res.push_back(brothers);
}
vector<vector<Node>>::iterator itr;
vector<Node>::iterator itr_brother;
for (itr = res.begin(); itr != res.end(); itr++)
{
for (itr_brother = itr->begin(); itr_brother != itr->end(); itr_brother++)
{
if (itr_brother->key == -1)
cout << "X ";
else
cout << itr_brother->key << "[" << color_str[itr_brother->color] << "] ";
}
cout << endl;
}
}
};
/// Demo: builds a small tree, prints it in level order, then removes a few
/// keys and prints the tree after each removal.
int main()
{
    RedBlack tree;
    for (int key : {10, 12, 3, 4, 7, 6, 9, 11})
        tree.insert(key);
    tree.levelOrderTraversal();
    for (int key : {10, 4, 3}) {
        tree.remove(key);
        // "\n" (not "\\n") so a real blank line separates each dump.
        cout << "\nAfter remove " << key << endl;
        tree.levelOrderTraversal();
    }
}
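/*
Sample output of main(): the tree in level order, one tree level per line.
Each entry is key[color] (R = red, B = black); X marks an empty child slot.
10[B]
4[R] 12[B]
3[B] 7[B] 11[R] X
X X 6[R] 9[R] X X X X
After remove 10
11[B]
4[R] 12[B]
3[B] 7[B] X X
X X 6[R] 9[R] X X X X
After remove 4
11[B]
6[R] 12[B]
3[B] 7[B] X X
X X X 9[R] X X X X
After remove 3
11[B]
7[R] 12[B]
6[B] 9[B] X X
*/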
