AVL Tree

AVL tree implementation in C++ (insert, remove, search, in-order and level-order traversal)

#include <algorithm>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;

// One node of the AVL tree. `height` is the height of the subtree rooted at
// this node, where a lone leaf has height 1 and an empty subtree height 0.
struct Node {
    int key;
    int height;
    Node* left;
    Node* right;

    // explicit: an int must not silently convert into a Node.
    explicit Node(int key) : key(key), height(1), left(nullptr), right(nullptr) {  }
};

// Self-balancing binary search tree (AVL tree) of int keys.
// Duplicate keys are allowed; on insert they go into the right subtree.
// The tree owns every node it allocates and frees them in its destructor,
// so copying is disabled to prevent double deletion.
class AVLTree {
    Node * root;
    int nodeNum;
public:
    AVLTree() : root(nullptr), nodeNum(0) { }

    // Frees every node still held by the tree.
    ~AVLTree() { destroy(root); }

    // The tree owns raw node pointers; a shallow copy would double-free them.
    AVLTree(const AVLTree&) = delete;
    AVLTree& operator=(const AVLTree&) = delete;

    // Number of nodes currently stored (each duplicate counts separately).
    int getNodeNum() const { return nodeNum; }

    // Height of the subtree rooted at `node`; 0 for an empty subtree.
    int get_height(Node *node) const
    {
        if (node == nullptr)
        {
            return 0;
        }
        return node->height;
    }

    // Balance factor: height(left) - height(right); 0 for an empty subtree.
    // A value outside [-1, 1] means the node needs rebalancing.
    int get_balance(Node *node) const
    {
        if (node == nullptr)
        {
            return 0;
        }
        return get_height(node->left) - get_height(node->right);
    }

    // Recomputes node->height from its children's cached heights.
    void update_height(Node *node)
    {
        node->height = std::max(get_height(node->left), get_height(node->right)) + 1;
    }

    // Left rotation around `node`; its right child becomes the new subtree
    // root, which is returned. Requires node->right != nullptr.
    Node *rotate_left(Node *node)
    {
        Node *right = node->right;
        node->right = right->left;
        right->left = node;
        // `node` is now below `right`, so its height must be refreshed first.
        update_height(node);
        update_height(right);
        return right;
    }

    // Right rotation around `node`; its left child becomes the new subtree
    // root, which is returned. Requires node->left != nullptr.
    Node *rotate_right(Node *node)
    {
        Node *left = node->left;
        node->left = left->right;
        left->right = node;
        // `node` is now below `left`, so its height must be refreshed first.
        update_height(node);
        update_height(left);
        return left;
    }

    // Restores the AVL property at `node` (whose height is assumed current)
    // and returns the root of the possibly rotated subtree.
    Node *balance(Node *node)
    {
        int balance_factor = get_balance(node);
        if (balance_factor > 1)
        {
            // Left-heavy. If the left child leans right (Left-Right case),
            // rotate it left first so a single right rotation suffices.
            if (get_balance(node->left) < 0)
            {
                node->left = rotate_left(node->left);
            }
            return rotate_right(node);
        }
        else if (balance_factor < -1)
        {
            // Right-heavy. If the right child leans left (Right-Left case),
            // rotate it right first so a single left rotation suffices.
            if (get_balance(node->right) > 0)
            {
                node->right = rotate_right(node->right);
            }
            return rotate_left(node);
        }
        return node;
    }

    // Inserts `key` into the tree (duplicates allowed).
    void insert(int key) {
        root = insert(root, key);
    }

    // Inserts `key` into the subtree rooted at `node` and returns the new
    // subtree root after rebalancing. Increments the node count.
    Node * insert(Node *node, int key)
    {
        if (node == nullptr)
        {
            nodeNum++;
            return new Node(key);
        }
        if (key < node->key)
        {
            node->left = insert(node->left, key);
        }
        else
        {
            node->right = insert(node->right, key);
        }
        update_height(node);
        return balance(node);
    }

    // Removes one node holding `key`; does nothing if the key is absent.
    void remove(int key){
        root = remove(root, key);
    }

    // Removes one node holding `key` from the subtree rooted at `node` and
    // returns the new subtree root after rebalancing. Decrements the node
    // count only when a node is actually deleted.
    Node *remove(Node *node, int key)
    {
        if (node == nullptr)
        {
            return nullptr;
        }
        if (key < node->key)
        {
            node->left = remove(node->left, key);
        }
        else if (key > node->key)
        {
            node->right = remove(node->right, key);
        }
        else
        {
            if (node->left == nullptr && node->right == nullptr)
            {
                delete node;
                nodeNum--;
                return nullptr;
            }
            if (node->left == nullptr)
            {
                Node *right = node->right;
                delete node;
                nodeNum--;
                return right;
            }
            if (node->right == nullptr)
            {
                Node *left = node->left;
                delete node;
                nodeNum--;
                return left;
            }
            // Two children: copy the in-order successor (leftmost node of
            // the right subtree) into this node, then delete the successor
            // from the right subtree instead.
            Node *successor = node->right;
            while (successor->left != nullptr)
            {
                successor = successor->left;
            }
            node->key = successor->key;
            node->right = remove(node->right, successor->key);
        }
        update_height(node);
        return balance(node);
    }

    // Returns true if some node in the AVL tree holds `key`.
    // Read-only: the tree is not modified. Previously the result was
    // assigned to `root`, which discarded (and leaked) the rest of the tree.
    bool search(int key) const {
        return search(root, key) != nullptr;
    }

    // Returns a node holding `key` within the subtree rooted at `node`,
    // or nullptr if there is none.
    Node *search(Node *node, int key) const
    {
        if (node == nullptr || node->key == key)
        {
            return node;
        }
        if (key < node->key)
        {
            return search(node->left, key);
        }
        else
        {
            return search(node->right, key);
        }
    }

    // Prints all keys in ascending order, each followed by a space.
    void inorder() const {
        inorder(root);
    }

    // Prints the keys of the subtree rooted at `node` in ascending order.
    void inorder(Node *node) const
    {
        if (node != nullptr)
        {
            inorder(node->left);
            std::cout << node->key << " ";
            inorder(node->right);
        }
    }

    // Prints the whole tree level by level (see the overload below).
    void levelOrderTraversal() const {
        levelOrderTraversal(root);
    }

    // Prints the subtree rooted at `node` level by level, preceded by a blank
    // line. Each line holds one level; missing children are shown as "X" so
    // level i always has 2^i entries. Printing stops after the deepest level
    // that still contains a real node. Prints nothing for an empty subtree.
    // Empty slots are represented by nullptr (not a sentinel key), so every
    // key value, including -1, prints correctly and nothing is leaked.
    void levelOrderTraversal(Node *node) const
    {
        if (node == nullptr)
            return;

        std::cout << std::endl;

        std::vector<Node *> level{node};
        bool levelHasNode = true;
        while (levelHasNode)
        {
            std::vector<Node *> next;
            next.reserve(level.size() * 2);
            levelHasNode = false;

            for (Node *current : level)
            {
                if (current == nullptr)
                {
                    std::cout << "X ";
                    // Keep the grid shape: an empty slot has two empty children.
                    next.push_back(nullptr);
                    next.push_back(nullptr);
                }
                else
                {
                    std::cout << current->key << " ";
                    next.push_back(current->left);
                    next.push_back(current->right);
                    if (current->left != nullptr || current->right != nullptr)
                        levelHasNode = true;
                }
            }
            std::cout << std::endl;
            level.swap(next);
        }
    }

private:
    // Deletes every node of the subtree rooted at `node` (children first).
    static void destroy(Node *node)
    {
        if (node == nullptr)
            return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }
};

int main() {
    AVLTree avlTree;

    avlTree.insert(10);
    avlTree.insert(20);
    avlTree.insert(30);
    avlTree.insert(40);
    avlTree.insert(50);
    avlTree.insert(60);

    cout << "AVL tree after insertions: " << avlTree.getNodeNum() << endl;;
//    avlTree.inorder();
    avlTree.levelOrderTraversal();
    cout << endl;

    avlTree.remove(30);
    avlTree.remove(40);

    cout << "AVL tree after deletions: ";
//    avlTree.inorder();
    avlTree.levelOrderTraversal();
    cout << "Node number : " << avlTree.getNodeNum() << endl;

    if (avlTree.search(20)) {
        cout << "Key 20을 찾음" << endl;
    } else{
        cout << "Key 20이 AVL tree에 없음" << endl;
    }

    if (avlTree.search(30)) {
        cout << "Key 30을 찾음" << endl;
    }else{
        cout << "Key 30이 AVL tree에 없음" << endl;
    }

    return 0;
}

실행 결과

AVL tree after insertions: 6

40
20 50
10 30 X 60

AVL tree after deletions:
50
20 60
10 X X X
Node number : 4
Key 20을 찾음
Key 30이 AVL tree에 없음