Home > libalgo > 平衡二分探索木 (AVL Tree)

平衡二分探索木 (AVL Tree)

概要

AVL 木という平衡二分探索木の実装.

各頂点で自身を根とする部分木の高さ (葉との距離の最大値) 管理する. AVL 木では,左右の子の高さの差が高々 1 であるということを不変条件とする.

高さ $h$ の AVL 木であって最も偏っている (深さが最大の頂点と最小の頂点の深さの差が最大である) ようなものが 格納している要素数を $f(h)$ とおく. すると $f(h)=f(h-1)+f(h-2)+1$ という漸化式が成り立つ. これは明らかにフィボナッチ数列よりも速く増加するので,$f(h)$ の逆関数 $g(n)$, すなわち要素数 $n$ の最も偏っている AVL 木の高さは $O(\log_\phi n)$ となり ($\phi$ は黄金比),確かに平衡が保たれていることが分かる.

使い方

lower_bound はソート済みの列にしか使えない.

計算量

insert_at, erase_at, merge, split, lower_bound, 単一ノードの更新が全て $O(\log n)$.

私の実装では treap より少し速い. しかし実装は重いので,ICPC で特に理由がなければ treap の方が良い.

#define NDEBUG 1 を忘れないこと.

実装

// Enable NDEBUG!
using key_type = int;

enum { L, R };
struct avl_node {
    key_type key;
    std::array<avl_node *, 2> ch;
    int size;
    int height;

    static avl_node *const nil;

    avl_node() : avl_node(key_type()) {}
    avl_node(key_type key) : avl_node(key, nil, nil, 1, 1) {}
    avl_node(const key_type key, avl_node *left, avl_node *right, int size, int height)
        : key(key), ch({{left, right}}), size(size), height(height) {}
    // void *operator new(size_t) {
    //     static int p = 0;
    //     static avl_node pool[400010];
    //     return pool + p++;
    // }
};

avl_node *const avl_node::nil = new avl_node(key_type(), nullptr, nullptr, 0, 0);
avl_node *const nil = avl_node::nil;
namespace init_nil {
int _ = (nil->ch = {{nil, nil}}, 0);
}

using np = avl_node *;
using cnp = const avl_node *;

np update(np n) {
    n->size = n->ch[L]->size + 1 + n->ch[R]->size;
    n->height = std::max(n->ch[L]->height, n->ch[R]->height) + 1;
    return n;
}

template <int dir>
np rotate(np n) {
    assert(n->ch[!dir] != nil);
    np root = n->ch[!dir];
    n->ch[!dir] = root->ch[dir];
    root->ch[dir] = n;
    update(n);
    update(root);
    return root;
}

int balance_factor(np n) {
    assert(n != nil);
    return n->ch[R]->height - n->ch[L]->height;
}

np balance(np n) {
    assert(abs(balance_factor(n)) <= 2);
    if (balance_factor(n) == +2) {
        if (balance_factor(n->ch[R]) < 0) n->ch[R] = rotate<R>(n->ch[R]);
        return rotate<L>(n);
    } else if (balance_factor(n) == -2) {
        if (balance_factor(n->ch[L]) > 0) n->ch[L] = rotate<L>(n->ch[L]);
        return rotate<R>(n);
    } else {
        return n;
    }
}

np insert_at(np n, int k, key_type x) {
    assert(0 <= k && k <= n->size);
    if (n == nil) return new avl_node(x);
    int sl = n->ch[L]->size;
    if (k <= sl)
        n->ch[L] = insert_at(n->ch[L], k, x);
    else
        n->ch[R] = insert_at(n->ch[R], k - sl - 1, x);
    return balance(update(n));
}

template <int dir>
std::pair<np, np> remove_most(np n) {
    assert(n != nil);
    if (n->ch[dir] != nil) {
        np most;
        std::tie(n->ch[dir], most) = remove_most<dir>(n->ch[dir]);
        return {balance(update(n)), most};
    } else {
        np res = n->ch[!dir];
        n->ch[!dir] = nil;
        return {res, update(n)};
    }
}

std::pair<np, key_type> remove_at(np n, int k) {
    assert(n != nil);
    int sl = n->ch[L]->size;
    if (k < sl) {
        key_type most;
        std::tie(n->ch[L], most) = remove_at(n->ch[L], k);
        return {balance(update(n)), most};
    }
    if (k == sl) {
        if (n->ch[R] == nil) {
            return {n->ch[L], n->key};
        } else {
            np most;
            std::tie(n->ch[R], most) = remove_most<L>(n->ch[R]);
            most->ch = n->ch;
            // delete n;
            return {balance(update(most)), n->key};
        }
    } else {
        key_type most;
        std::tie(n->ch[R], most) = remove_at(n->ch[R], k - sl - 1);
        return {balance(update(n)), most};
    }
}

np merge_with_root(np l, np root, np r) {
    // Members of `root` except root->key may not be valid for performance
    // reason.
    if (abs(l->height - r->height) <= 1) {
        root->ch = {{l, r}};
        return update(root);
    } else if (l->height > r->height) {
        l->ch[R] = merge_with_root(l->ch[R], root, r);
        return balance(update(l));
    } else {
        r->ch[L] = merge_with_root(l, root, r->ch[L]);
        return balance(update(r));
    }
}

np merge(np l, np r) {
    if (l == nil) return r;
    if (r == nil) return l;
    np m;
    if (l->height > r->height)
        std::tie(r, m) = remove_most<L>(r);
    else
        std::tie(l, m) = remove_most<R>(l);
    return merge_with_root(l, m, r);
}

std::pair<np, np> split_at(np n, int k) {
    assert(0 <= k && k <= n->size);
    if (n == nil) return {nil, nil};
    int sl = n->ch[L]->size;
    np l = n->ch[L];
    np r = n->ch[R];
    n->ch[L] = n->ch[R] = nil;
    // Members of avl_node passed to `merge` must be valid, but ones for
    // `merge_with_root` doesn't have to.
    np nl, nr;
    if (k < sl) {
        std::tie(nl, nr) = split_at(l, k);
        return {nl, merge_with_root(nr, n, r)};
    } else if (k == sl) {
        update(n);
        return {l, merge(n, r)};
    } else {
        std::tie(nl, nr) = split_at(r, k - sl - 1);
        return {merge_with_root(l, n, nl), nr};
    }
}

void update_at(np n, int k, key_type x) {
    assert(0 <= k && k < n->size);
    int sl = n->ch[L]->size;
    if (k < sl)
        update_at(n->ch[L], k, x);
    else if (k == sl)
        n->key = x;
    else
        update_at(n->ch[R], k - sl - 1, x);
    update(n);
}

np lower_bound(np n, key_type x) {
    if (n == nil) return nil;
    if (x <= n->key) {
        np res = lower_bound(n->ch[L], x);
        return res != nil ? res : n;
    } else {
        return lower_bound(n->ch[R], x);
    }
}

template <typename Iterator>
np build(Iterator left, Iterator right) {
    int n = right - left;
    Iterator mid = left + n / 2;
    if (n == 0) return nil;
    np l = build(left, mid);
    np r = build(mid + 1, right);
    np m = new avl_node(*mid);
    m->ch = {{l, r}};
    return update(m);
}

void to_a(np n, std::vector<key_type> &v) {
    if (n == nil) return;
    to_a(n->ch[L], v);
    v.push_back(n->key);
    to_a(n->ch[R], v);
}

std::vector<key_type> to_a(np n) {
    std::vector<key_type> res;
    to_a(n, res);
    return res;
}

namespace test {
int real_height(np n) {
    if (n == nil) return 0;
    return 1 + std::max(real_height(n->ch[L]), real_height(n->ch[R]));
}

int real_size(np n) {
    if (n == nil) return 0;
    return 1 + real_size(n->ch[L]) + real_size(n->ch[R]);
}

bool verify(np n) {
    if (n == nil) {
        if (n->ch[L] != nil) return false;
        if (n->ch[R] != nil) return false;
        return true;
    }
    np l = n->ch[L];
    np r = n->ch[R];

    if (n->size != real_size(n)) return false;
    if (n->height != real_height(n)) return false;

    if (n->size != l->size + 1 + r->size) return false;
    if (n->height != std::max(l->height, r->height) + 1) return false;

    if (abs(balance_factor(n)) >= 2) return false;
    return true;
}
}

参考文献