AVL tree
概要
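find, insert, remove, merge, split に対応.いずれも $O(h) = O(\log n)$ ($h$ は木の高さ).RBST や Treap よりやや速い.TODO: set と multiset の切り替え機能の実装 (現状,`insert` / `remove` は重複を許さない set として動作し,multiset 用には自由関数 `insert_ms` / `remove_ms` / `remove_at_ms` を用いる).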
使い方
at (インデックスアクセス), remove_at, find, insert, remove, merge, split, split_at は通常通り.count_less(x) で x 未満 (x より真に小さい) の要素の数を取得.
実装
// Elements are compared through `PartialOrd` instead of `Ord` so that
// floating-point types (which do not implement `Ord`) can be stored.
// Comparing two incomparable values (e.g. a NaN) panics.
#[allow(dead_code)]
mod avl_tree {
    use std;
    use std::cmp::Ordering::*;
    use std::fmt::Debug;

    /// A possibly empty subtree.
    type Link<T> = Option<Box<Node<T>>>;

    /// Height-balanced binary search tree that also supports positional
    /// access (`at`, `remove_at`, `split_at`) through cached subtree sizes.
    ///
    /// `insert` / `remove` have set semantics (an equal element is never
    /// inserted twice). The free functions `insert_ms` / `remove_ms` /
    /// `remove_at_ms` operate on `root` with multiset semantics.
    #[derive(Debug)]
    pub struct AVLTree<T> {
        /// Root node; `None` when the tree is empty.
        pub root: Link<T>,
    }

    impl<T> Default for AVLTree<T> {
        fn default() -> Self {
            AVLTree::new()
        }
    }

    impl<T> AVLTree<T> {
        /// Creates an empty tree.
        pub fn new() -> Self {
            AVLTree { root: None }
        }

        /// Number of stored elements (read from the cached root size).
        pub fn len(&self) -> usize {
            size(&self.root)
        }

        /// Returns `true` when the tree holds no elements.
        pub fn is_empty(&self) -> bool {
            self.root.is_none()
        }

        /// Returns the `k`-th element in in-order (0-indexed), or `None`
        /// when `k >= len()`.
        pub fn at(&self, k: usize) -> Option<&T> {
            at(&self.root, k).map(|n| &n.val)
        }

        /// Removes and returns the `k`-th element (0-indexed), or `None`
        /// when `k >= len()`.
        pub fn remove_at(&mut self, k: usize) -> Option<T> {
            let (root, kth) = remove_at(self.root.take(), k);
            self.root = root;
            kth
        }

        /// Appends all elements of `right` after the elements of `self`,
        /// leaving `right` empty. Ordering is not checked: value-based
        /// operations afterwards assume every element of `self` is not
        /// greater than any element of `right`.
        pub fn merge(&mut self, right: &mut AVLTree<T>) {
            self.root = merge(self.root.take(), right.root.take());
        }

        /// Splits into (first `k` elements, remaining elements), leaving
        /// `self` empty. `k` must not exceed `len()` (debug-asserted).
        pub fn split_at(&mut self, k: usize) -> (AVLTree<T>, AVLTree<T>) {
            let (left, right) = split_at(self.root.take(), k);
            (AVLTree { root: left }, AVLTree { root: right })
        }
    }

    impl<T: PartialOrd> AVLTree<T> {
        /// Inserts `val`; does nothing if an equal element is already stored.
        pub fn insert(&mut self, val: T) {
            self.root = insert(self.root.take(), val);
        }

        /// Removes one element equal to `val`; returns whether one was found.
        pub fn remove(&mut self, val: &T) -> bool {
            let (root, found) = remove(self.root.take(), val);
            self.root = root;
            found
        }

        /// Returns a reference to a stored element equal to `val`, if any.
        pub fn find(&self, val: &T) -> Option<&T> {
            find(&self.root, val).map(|n| &n.val)
        }

        /// Splits into (elements `< val`, elements `>= val`), leaving `self` empty.
        pub fn split(&mut self, val: &T) -> (AVLTree<T>, AVLTree<T>) {
            let (left, right) = split(self.root.take(), val);
            (AVLTree { root: left }, AVLTree { root: right })
        }

        /// Returns the number of elements strictly less than `val`.
        pub fn count_less(&self, val: &T) -> usize {
            count_less(&self.root, val)
        }
    }

    impl<T: PartialOrd + Debug> AVLTree<T> {
        /// Checks (with `debug_assert!`, so only in debug builds) that cached
        /// sizes/heights are consistent, every node is AVL-balanced, and the
        /// in-order sequence is non-decreasing.
        pub fn check(&self) {
            if let Some(ref root) = self.root {
                root.verify(None, None);
            }
        }
    }

    #[derive(Debug)]
    pub struct Node<T> {
        left: Link<T>,
        val: T,
        right: Link<T>,
        /// Number of nodes in the subtree rooted here.
        size: usize,
        /// Height of the subtree rooted here (a single node has height 1).
        height: isize,
    }

    impl<T> Node<T> {
        /// Allocates a detached leaf holding `val`.
        fn new(val: T) -> Link<T> {
            Some(Box::new(Node {
                val,
                left: None,
                right: None,
                size: 1,
                height: 1,
            }))
        }

        /// Recomputes `size` and `height` from the children's cached values.
        #[inline]
        fn update(&mut self) {
            self.size = size(&self.left) + size(&self.right) + 1;
            self.height = std::cmp::max(height(&self.left), height(&self.right)) + 1;
        }
    }

    impl<T: PartialOrd> Node<T> {
        /// Recursive invariant check used by `AVLTree::check`.
        /// `lo` / `hi` are the (inclusive) bounds inherited from ancestors, so
        /// ordering is validated against the whole path, not just the parent.
        fn verify(&self, lo: Option<&T>, hi: Option<&T>) {
            debug_assert_eq!(self.size, size(&self.left) + size(&self.right) + 1);
            debug_assert_eq!(
                self.height,
                std::cmp::max(height(&self.left), height(&self.right)) + 1
            );
            // Checked on the node itself so that an absent child is never
            // passed to `balance_factor` (which panics on `None`).
            debug_assert!((height(&self.right) - height(&self.left)).abs() < 2);
            if let Some(lo) = lo {
                debug_assert!(*lo <= self.val);
            }
            if let Some(hi) = hi {
                debug_assert!(self.val <= *hi);
            }
            if let Some(ref l) = self.left {
                l.verify(lo, Some(&self.val));
            }
            if let Some(ref r) = self.right {
                r.verify(Some(&self.val), hi);
            }
        }
    }

    /// Total comparison on top of `PartialOrd`; panics on incomparable values.
    #[inline]
    fn compare<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
        a.partial_cmp(b)
            .expect("avl_tree: encountered incomparable values (e.g. NaN)")
    }

    /// Set insertion: an element equal to `val` is left untouched.
    fn insert<T: PartialOrd>(n: Link<T>, val: T) -> Link<T> {
        match n {
            None => Node::new(val),
            Some(mut n) => {
                match compare(&val, &n.val) {
                    Less => n.left = insert(n.left.take(), val),
                    Equal => (), // set semantics: duplicates are ignored
                    Greater => n.right = insert(n.right.take(), val),
                }
                n.update();
                balance(Some(n))
            }
        }
    }

    /// Removes one node equal to `val`. Returns (new subtree, whether removed).
    fn remove<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, bool) {
        let mut n = match n {
            None => return (None, false),
            Some(n) => n,
        };
        let found = match compare(val, &n.val) {
            Less => {
                let (left, found) = remove(n.left.take(), val);
                n.left = left;
                found
            }
            Equal => {
                let (left, right) = (n.left.take(), n.right.take());
                return (join_children(left, right), true);
            }
            Greater => {
                let (right, found) = remove(n.right.take(), val);
                n.right = right;
                found
            }
        };
        n.update();
        (balance(Some(n)), found)
    }

    /// Rebuilds a subtree from the two children of a removed node:
    /// the minimum of `right` (if any) becomes the new root.
    fn join_children<T>(left: Link<T>, right: Link<T>) -> Link<T> {
        match right {
            None => left,
            Some(r) => {
                let (rest, mut min) = remove_leftmost(Some(r));
                min.left = left;
                min.right = rest;
                min.update();
                balance(Some(min))
            }
        }
    }

    /// Detaches the leftmost (minimum) node of a non-empty subtree.
    /// Returns (modified subtree, detached node with no children).
    fn remove_leftmost<T>(n: Link<T>) -> (Link<T>, Box<Node<T>>) {
        let mut n = n.expect("remove_leftmost called on an empty subtree");
        if n.left.is_none() {
            let right = n.right.take();
            n.update();
            return (right, n);
        }
        let (left, min) = remove_leftmost(n.left.take());
        n.left = left;
        n.update();
        (balance(Some(n)), min)
    }

    /// Detaches the rightmost (maximum) node of a non-empty subtree.
    /// Returns (modified subtree, detached node with no children).
    fn remove_rightmost<T>(n: Link<T>) -> (Link<T>, Box<Node<T>>) {
        let mut n = n.expect("remove_rightmost called on an empty subtree");
        if n.right.is_none() {
            let left = n.left.take();
            n.update();
            return (left, n);
        }
        let (right, max) = remove_rightmost(n.right.take());
        n.right = right;
        n.update();
        (balance(Some(n)), max)
    }

    /// Concatenates two trees (all of `left` before all of `right`).
    ///
    /// The join direction is chosen by HEIGHT: a larger tree is not
    /// necessarily taller in an AVL tree, and descending the shorter side
    /// would produce a balance factor outside [-2, 2].
    fn merge<T>(left: Link<T>, right: Link<T>) -> Link<T> {
        match (left, right) {
            (None, right) => right,
            (left, None) => left,
            (left, right) => {
                if height(&left) > height(&right) {
                    let (right, mid) = remove_leftmost(right);
                    merge_into_left(left, Some(mid), right)
                } else {
                    let (left, mid) = remove_rightmost(left);
                    merge_into_right(left, Some(mid), right)
                }
            }
        }
    }

    /// Joins `left`, the single node `mid`, and `right` in that order.
    fn merge3<T>(left: Link<T>, mid: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&mid), 1);
        if height(&left) > height(&right) {
            merge_into_left(left, mid, right)
        } else {
            merge_into_right(left, mid, right)
        }
    }

    /// Join when `left` is at least as tall as `right`: walk down the right
    /// spine of `left` until the heights are within one, hang `mid` there,
    /// then rebalance on the way back up.
    fn merge_into_left<T>(left: Link<T>, mid: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&mid), 1);
        debug_assert!(height(&left) >= height(&right));
        if height(&left) > height(&right) + 1 {
            let mut left = left.expect("taller subtree cannot be empty");
            left.right = merge_into_left(left.right.take(), mid, right);
            left.update();
            balance(Some(left))
        } else {
            let mut mid = mid.expect("middle node must be present");
            mid.left = left;
            mid.right = right;
            mid.update();
            balance(Some(mid))
        }
    }

    /// Mirror image of `merge_into_left` for a `right` at least as tall as `left`.
    fn merge_into_right<T>(left: Link<T>, mid: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&mid), 1);
        debug_assert!(height(&left) <= height(&right));
        if height(&left) + 1 < height(&right) {
            let mut right = right.expect("taller subtree cannot be empty");
            right.left = merge_into_right(left, mid, right.left.take());
            right.update();
            balance(Some(right))
        } else {
            let mut mid = mid.expect("middle node must be present");
            mid.left = left;
            mid.right = right;
            mid.update();
            balance(Some(mid))
        }
    }

    /// Splits into (elements `< val`, elements `>= val`).
    /// Equal keys keep descending left so that duplicates (multiset use)
    /// all end up in the right part.
    fn split<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, Link<T>) {
        match n {
            None => (None, None),
            Some(mut n) => {
                let (l, r) = (n.left.take(), n.right.take());
                n.update();
                match compare(val, &n.val) {
                    Less | Equal => {
                        let (ll, lr) = split(l, val);
                        (ll, merge3(lr, Some(n), r))
                    }
                    Greater => {
                        let (rl, rr) = split(r, val);
                        (merge3(l, Some(n), rl), rr)
                    }
                }
            }
        }
    }

    /// Splits into (first `k` elements, remaining elements).
    fn split_at<T>(n: Link<T>, k: usize) -> (Link<T>, Link<T>) {
        debug_assert!(k <= size(&n));
        match n {
            None => (None, None),
            Some(mut n) => {
                let (l, r) = (n.left.take(), n.right.take());
                n.update();
                let sl = size(&l);
                match k.cmp(&sl) {
                    Less => {
                        let (ll, lr) = split_at(l, k);
                        (ll, merge3(lr, Some(n), r))
                    }
                    Equal => (l, merge(Some(n), r)),
                    Greater => {
                        let (rl, rr) = split_at(r, k - sl - 1);
                        (merge3(l, Some(n), rl), rr)
                    }
                }
            }
        }
    }

    /// Returns a node whose value equals `val`, if any.
    fn find<'a, T: PartialOrd>(n: &'a Link<T>, val: &T) -> Option<&'a Node<T>> {
        let mut cur = n;
        while let Some(node) = cur {
            match compare(val, &node.val) {
                Less => cur = &node.left,
                Equal => return Some(&**node),
                Greater => cur = &node.right,
            }
        }
        None
    }

    /// Returns the node at in-order position `k` (0-indexed), if any.
    fn at<T>(n: &Link<T>, k: usize) -> Option<&Node<T>> {
        let mut cur = n;
        let mut k = k;
        while let Some(node) = cur {
            let ls = size(&node.left);
            match k.cmp(&ls) {
                Less => cur = &node.left,
                Equal => return Some(&**node),
                Greater => {
                    k -= ls + 1;
                    cur = &node.right;
                }
            }
        }
        None
    }

    /// Removes the node at in-order position `k`. Returns (new subtree, its value).
    fn remove_at<T>(n: Link<T>, k: usize) -> (Link<T>, Option<T>) {
        let mut n = match n {
            None => return (None, None),
            Some(n) => n,
        };
        let sl = size(&n.left);
        let kth = match k.cmp(&sl) {
            Less => {
                let (left, kth) = remove_at(n.left.take(), k);
                n.left = left;
                kth
            }
            Equal => {
                let (left, right) = (n.left.take(), n.right.take());
                return (join_children(left, right), Some(n.val));
            }
            Greater => {
                let (right, kth) = remove_at(n.right.take(), k - sl - 1);
                n.right = right;
                kth
            }
        };
        n.update();
        (balance(Some(n)), kth)
    }

    /// Counts elements strictly less than `val`. On an equal key we keep
    /// descending left, which is required when duplicates are present.
    fn count_less<T: PartialOrd>(n: &Link<T>, val: &T) -> usize {
        let mut cur = n;
        let mut count = 0;
        while let Some(node) = cur {
            match compare(val, &node.val) {
                Less | Equal => cur = &node.left,
                Greater => {
                    count += size(&node.left) + 1;
                    cur = &node.right;
                }
            }
        }
        count
    }

    /// Restores the AVL property at the root of `n`, assuming both children
    /// are valid AVL trees whose heights differ by at most 2.
    fn balance<T>(n: Link<T>) -> Link<T> {
        match balance_factor(&n) {
            -1 | 0 | 1 => n,
            2 => {
                let mut n = n.unwrap();
                if balance_factor(&n.right) < 0 {
                    n.right = rotate_right(n.right.take());
                }
                rotate_left(Some(n))
            }
            -2 => {
                let mut n = n.unwrap();
                if balance_factor(&n.left) > 0 {
                    n.left = rotate_left(n.left.take());
                }
                rotate_right(Some(n))
            }
            bf => unreachable!("avl_tree: balance factor out of range: {}", bf),
        }
    }

    /// Left rotation: the right child becomes the new subtree root.
    #[inline]
    fn rotate_left<T>(q: Link<T>) -> Link<T> {
        let mut q = q.expect("rotate_left on an empty subtree");
        let mut p = q.right.take().expect("rotate_left needs a right child");
        q.right = p.left.take();
        q.update();
        p.left = Some(q);
        p.update();
        Some(p)
    }

    /// Right rotation: the left child becomes the new subtree root.
    #[inline]
    fn rotate_right<T>(p: Link<T>) -> Link<T> {
        let mut p = p.expect("rotate_right on an empty subtree");
        let mut q = p.left.take().expect("rotate_right needs a left child");
        p.left = q.right.take();
        p.update();
        q.right = Some(p);
        q.update();
        Some(q)
    }

    /// Cached node count of a subtree (0 for an empty one).
    #[inline]
    fn size<T>(n: &Link<T>) -> usize {
        match *n {
            None => 0,
            Some(ref n) => n.size,
        }
    }

    /// Cached height of a subtree (0 for an empty one).
    #[inline]
    fn height<T>(n: &Link<T>) -> isize {
        match *n {
            None => 0,
            Some(ref n) => n.height,
        }
    }

    /// `height(right) - height(left)` of a non-empty subtree; panics on `None`.
    #[inline]
    fn balance_factor<T>(n: &Link<T>) -> isize {
        match *n {
            None => panic!("avl_tree: balance_factor called on an empty subtree"),
            Some(ref n) => height(&n.right) - height(&n.left),
        }
    }

    /// Multiset insertion: `x` is placed before any existing equal elements.
    pub fn insert_ms<T: PartialOrd>(n: Link<T>, x: T) -> Link<T> {
        let (l, r) = split(n, &x);
        merge3(l, Node::new(x), r)
    }

    /// Removes the element at in-order position `k`; `(n, None)` when out of range.
    pub fn remove_at_ms<T>(n: Link<T>, k: usize) -> (Link<T>, Option<T>) {
        if k >= size(&n) {
            return (n, None);
        }
        let (l, r) = split_at(n, k);
        let (r, kth) = remove_leftmost(r);
        (merge(l, r), Some(kth.val))
    }

    /// Multiset removal of a single element equal to `val`.
    pub fn remove_ms<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, bool) {
        if find(&n, val).is_none() {
            return (n, false);
        }
        // The right part starts at the first element >= val, which equals val.
        let (l, r) = split(n, val);
        let (r, _) = remove_leftmost(r);
        (merge(l, r), true)
    }
}