平衡二分探索木 (AVL Tree)
概要
AVL 木という平衡二分探索木の実装.
各頂点で自身を根とする部分木の高さ (葉との距離の最大値) 管理する. AVL 木では,左右の子の高さの差が高々 1 であるということを不変条件とする.
高さ $h$ の AVL 木であって最も偏っている (深さが最大の頂点と最小の頂点の深さの差が最大である) ようなものが 格納している要素数を $f(h)$ とおく. すると $f(h)=f(h-1)+f(h-2)+1$ という漸化式が成り立つ. これは明らかにフィボナッチ数列よりも速く増加するので,$f(h)$ の逆関数 $g(n)$, すなわち要素数 $n$ の最も偏っている AVL 木の高さは $O(\log_\phi n)$ となり ($\phi$ は黄金比),確かに平衡が保たれていることが分かる.
使い方
lower_bound はソート済みの列にしか使えない.
計算量
insert_at, erase_at, merge, split, lower_bound, 単一ノードの更新が全て $O(\log n)$.
私の実装では treap より少し速い. しかし実装は重いので,ICPC で特に理由がなければ treap の方が良い.
#define NDEBUG 1
を忘れないこと.
実装
// Enable NDEBUG!
using key_type = int;
enum { L, R };
struct avl_node {
key_type key;
std::array<avl_node *, 2> ch;
int size;
int height;
static avl_node *const nil;
avl_node() : avl_node(key_type()) {}
avl_node(key_type key) : avl_node(key, nil, nil, 1, 1) {}
avl_node(const key_type key, avl_node *left, avl_node *right, int size, int height)
: key(key), ch({{left, right}}), size(size), height(height) {}
// void *operator new(size_t) {
// static int p = 0;
// static avl_node pool[400010];
// return pool + p++;
// }
};
avl_node *const avl_node::nil = new avl_node(key_type(), nullptr, nullptr, 0, 0);
avl_node *const nil = avl_node::nil;
namespace init_nil {
int _ = (nil->ch = {{nil, nil}}, 0);
}
using np = avl_node *;
using cnp = const avl_node *;
np update(np n) {
n->size = n->ch[L]->size + 1 + n->ch[R]->size;
n->height = std::max(n->ch[L]->height, n->ch[R]->height) + 1;
return n;
}
template <int dir>
np rotate(np n) {
assert(n->ch[!dir] != nil);
np root = n->ch[!dir];
n->ch[!dir] = root->ch[dir];
root->ch[dir] = n;
update(n);
update(root);
return root;
}
int balance_factor(np n) {
assert(n != nil);
return n->ch[R]->height - n->ch[L]->height;
}
np balance(np n) {
assert(abs(balance_factor(n)) <= 2);
if (balance_factor(n) == +2) {
if (balance_factor(n->ch[R]) < 0) n->ch[R] = rotate<R>(n->ch[R]);
return rotate<L>(n);
} else if (balance_factor(n) == -2) {
if (balance_factor(n->ch[L]) > 0) n->ch[L] = rotate<L>(n->ch[L]);
return rotate<R>(n);
} else {
return n;
}
}
np insert_at(np n, int k, key_type x) {
assert(0 <= k && k <= n->size);
if (n == nil) return new avl_node(x);
int sl = n->ch[L]->size;
if (k <= sl)
n->ch[L] = insert_at(n->ch[L], k, x);
else
n->ch[R] = insert_at(n->ch[R], k - sl - 1, x);
return balance(update(n));
}
template <int dir>
std::pair<np, np> remove_most(np n) {
assert(n != nil);
if (n->ch[dir] != nil) {
np most;
std::tie(n->ch[dir], most) = remove_most<dir>(n->ch[dir]);
return {balance(update(n)), most};
} else {
np res = n->ch[!dir];
n->ch[!dir] = nil;
return {res, update(n)};
}
}
std::pair<np, key_type> remove_at(np n, int k) {
assert(n != nil);
int sl = n->ch[L]->size;
if (k < sl) {
key_type most;
std::tie(n->ch[L], most) = remove_at(n->ch[L], k);
return {balance(update(n)), most};
}
if (k == sl) {
if (n->ch[R] == nil) {
return {n->ch[L], n->key};
} else {
np most;
std::tie(n->ch[R], most) = remove_most<L>(n->ch[R]);
most->ch = n->ch;
// delete n;
return {balance(update(most)), n->key};
}
} else {
key_type most;
std::tie(n->ch[R], most) = remove_at(n->ch[R], k - sl - 1);
return {balance(update(n)), most};
}
}
np merge_with_root(np l, np root, np r) {
// Members of `root` except root->key may not be valid for performance
// reason.
if (abs(l->height - r->height) <= 1) {
root->ch = {{l, r}};
return update(root);
} else if (l->height > r->height) {
l->ch[R] = merge_with_root(l->ch[R], root, r);
return balance(update(l));
} else {
r->ch[L] = merge_with_root(l, root, r->ch[L]);
return balance(update(r));
}
}
np merge(np l, np r) {
if (l == nil) return r;
if (r == nil) return l;
np m;
if (l->height > r->height)
std::tie(r, m) = remove_most<L>(r);
else
std::tie(l, m) = remove_most<R>(l);
return merge_with_root(l, m, r);
}
std::pair<np, np> split_at(np n, int k) {
assert(0 <= k && k <= n->size);
if (n == nil) return {nil, nil};
int sl = n->ch[L]->size;
np l = n->ch[L];
np r = n->ch[R];
n->ch[L] = n->ch[R] = nil;
// Members of avl_node passed to `merge` must be valid, but ones for
// `merge_with_root` doesn't have to.
np nl, nr;
if (k < sl) {
std::tie(nl, nr) = split_at(l, k);
return {nl, merge_with_root(nr, n, r)};
} else if (k == sl) {
update(n);
return {l, merge(n, r)};
} else {
std::tie(nl, nr) = split_at(r, k - sl - 1);
return {merge_with_root(l, n, nl), nr};
}
}
void update_at(np n, int k, key_type x) {
assert(0 <= k && k < n->size);
int sl = n->ch[L]->size;
if (k < sl)
update_at(n->ch[L], k, x);
else if (k == sl)
n->key = x;
else
update_at(n->ch[R], k - sl - 1, x);
update(n);
}
np lower_bound(np n, key_type x) {
if (n == nil) return nil;
if (x <= n->key) {
np res = lower_bound(n->ch[L], x);
return res != nil ? res : n;
} else {
return lower_bound(n->ch[R], x);
}
}
template <typename Iterator>
np build(Iterator left, Iterator right) {
int n = right - left;
Iterator mid = left + n / 2;
if (n == 0) return nil;
np l = build(left, mid);
np r = build(mid + 1, right);
np m = new avl_node(*mid);
m->ch = {{l, r}};
return update(m);
}
void to_a(np n, std::vector<key_type> &v) {
if (n == nil) return;
to_a(n->ch[L], v);
v.push_back(n->key);
to_a(n->ch[R], v);
}
std::vector<key_type> to_a(np n) {
std::vector<key_type> res;
to_a(n, res);
return res;
}
namespace test {
int real_height(np n) {
if (n == nil) return 0;
return 1 + std::max(real_height(n->ch[L]), real_height(n->ch[R]));
}
int real_size(np n) {
if (n == nil) return 0;
return 1 + real_size(n->ch[L]) + real_size(n->ch[R]);
}
bool verify(np n) {
if (n == nil) {
if (n->ch[L] != nil) return false;
if (n->ch[R] != nil) return false;
return true;
}
np l = n->ch[L];
np r = n->ch[R];
if (n->size != real_size(n)) return false;
if (n->height != real_height(n)) return false;
if (n->size != l->size + 1 + r->size) return false;
if (n->height != std::max(l->height, r->height) + 1) return false;
if (abs(balance_factor(n)) >= 2) return false;
return true;
}
}