#include <algorithm>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
using namespace std;
// One node of the AVL tree.
//   key    - the stored value
//   height - height of the subtree rooted here; a freshly created node is a
//            leaf, so it starts at 1 (an absent child counts as height 0)
//   left/right - child links, nullptr when the child is absent
struct Node {
    int key;
    int height = 1;
    Node* left = nullptr;
    Node* right = nullptr;

    Node(int key) : key(key) { }
};
class AVLTree {
Node * root;
int nodeNum;
public:
AVLTree() : root(nullptr), nodeNum(0) { }
int getNodeNum() { return nodeNum; }
int get_height(Node *node)
{
if (node == nullptr)
{
return 0;
}
return node->height;
}
int get_balance(Node *node)
{
if (node == nullptr)
{
return 0;
}
return get_height(node->left) - get_height(node->right);
}
void update_height(Node *node)
{
node->height = max(get_height(node->left), get_height(node->right)) + 1;
}
Node *rotate_left(Node *node)
{
Node *right = node->right;
node->right = right->left;
right->left = node;
update_height(node);
update_height(right);
return right;
}
Node *rotate_right(Node *node)
{
Node *left = node->left;
node->left = left->right;
left->right = node;
update_height(node);
update_height(left);
return left;
}
Node *balance(Node *node)
{
int balance_factor = get_balance(node);
if (balance_factor > 1)
{
if (get_balance(node->left) < 0)
{
node->left = rotate_left(node->left);
}
return rotate_right(node);
}
else if (balance_factor < -1)
{
if (get_balance(node->right) > 0)
{
node->right = rotate_right(node->right);
}
return rotate_left(node);
}
return node;
}
void insert(int key) {
root = insert(root, key);
}
Node * insert(Node *node, int key)
{
if (node == nullptr)
{
nodeNum++;
return new Node(key);
}
if (key < node->key)
{
node->left = insert(node->left, key);
}
else
{
node->right = insert(node->right, key);
}
update_height(node);
return balance(node);
}
void remove(int key){
root = remove(root, key);
}
Node *remove(Node *node, int key)
{
if (node == nullptr)
{
return nullptr;
}
if (key < node->key)
{
node->left = remove(node->left, key);
}
else if (key > node->key)
{
node->right = remove(node->right, key);
}
else
{
if (node->left == nullptr && node->right == nullptr)
{
delete node;
nodeNum--;
return nullptr;
}
if (node->left == nullptr)
{
Node *right = node->right;
delete node;
nodeNum--;
return right;
}
if (node->right == nullptr)
{
Node *left = node->left;
delete node;
nodeNum--;
return left;
}
Node *successor = node->right;
while (successor->left != nullptr)
{
successor = successor->left;
}
node->key = successor->key;
node->right = remove(node->right, successor->key);
}
update_height(node);
return balance(node);
}
// AVL 트리에서 특정 key 값을 가진 노드를 찾는 함수
bool search(int key){
root = search(root, key);
return root != nullptr;
}
Node *search(Node *node, int key)
{
if (node == nullptr || node->key == key)
{
return node;
}
if (key < node->key)
{
return search(node->left, key);
}
else
{
return search(node->right, key);
}
}
void inorder(){
inorder(root);
}
void inorder(Node *node)
{
if (node != nullptr)
{
inorder(node->left);
cout << node->key << " ";
inorder(node->right);
}
}
void levelOrderTraversal(){
levelOrderTraversal(root);
}
void levelOrderTraversal(Node *node)
{
vector<vector<Node>> res;
if (node == nullptr)
return;
queue<Node *> q;
q.push(node);
int num = nodeNum;
Node *temp = new Node(-1);
while (!q.empty() && num > 0)
{
int size = q.size();
vector<Node> v;
for (int i = 0; i < size; i++)
{
Node *front = q.front();
q.pop();
if (front->key != -1)
num--;
v.push_back(*front);
if (front->left != nullptr)
q.push(front->left);
else
q.push(temp);
if (front->right != nullptr)
q.push(front->right);
else
q.push(temp);
}
res.push_back(v);
}
vector<vector<Node>>::iterator itr;
vector<Node>::iterator itr2;
cout << endl;
for (itr = res.begin(); itr != res.end(); itr++)
{
for (itr2 = itr->begin(); itr2 != itr->end(); itr2++)
{
if (itr2->key == -1)
cout << "X ";
else
cout << itr2->key << " ";
}
cout << endl;
}
}
};
// Demo: builds a tree from ascending keys, prints it level by level, removes
// two keys, prints it again, and reports two lookups.
int main() {
    AVLTree avlTree;

    // Keys arrive in ascending order, which exercises the rebalancing path.
    for (int key : {10, 20, 30, 40, 50, 60})
        avlTree.insert(key);
    cout << "AVL tree after insertions: " << avlTree.getNodeNum() << endl;
    // avlTree.inorder();  // alternative: sorted single-line dump
    avlTree.levelOrderTraversal();
    cout << endl;

    for (int key : {30, 40})
        avlTree.remove(key);
    cout << "AVL tree after deletions: ";
    avlTree.levelOrderTraversal();
    cout << "Node number : " << avlTree.getNodeNum() << endl;

    cout << (avlTree.search(20) ? "Key 20을 찾음" : "Key 20이 AVL tree에 없음") << endl;
    cout << (avlTree.search(30) ? "Key 30을 찾음" : "Key 30이 AVL tree에 없음") << endl;
    return 0;
}
/*
실행 결과 (Execution result):
AVL tree after insertions: 6

40
20 50
10 30 X 60

AVL tree after deletions:
50
20 60
10 X X X
Node number : 4
Key 20을 찾음
Key 30이 AVL tree에 없음
*/