Home > libalgo > AVL tree

AVL tree

概要

find, insert, remove (erase), merge, split に対応.全て $O(h) = O(\log n)$ ($h$ は高さ).RBST と Treap よりちょっと速い.TODO: set と multiset との切り替え機能の実装 (現状,multiset 版は木のリンクを直接操作する関数 insert_ms, remove_ms, remove_at_ms として提供).

使い方

at (インデックスアクセス), find, insert, remove (erase), remove_at, merge, split, split_at は通常通り.insert は set として振る舞い,既に等しい要素があれば何もしない.count_less(x) で x 未満の要素の数を取得 (x と等しい要素は数えない).

実装

// Using PartialOrd instead of Ord because float types do not implement Ord :(
// Comparisons unwrap partial_cmp, so incomparable values such as NaN panic.

#[allow(dead_code)]
mod avl_tree {
    //! Size-augmented AVL tree used as an ordered set.
    //!
    //! Every node caches the size and height of its subtree, so index-based
    //! access (`at`, `remove_at`, `split_at`) and value-based operations all run
    //! in O(log n). Elements only need `PartialOrd`; comparisons call
    //! `partial_cmp(..).unwrap()`, so incomparable values (e.g. NaN) panic.
    //! The free functions `insert_ms`, `remove_ms` and `remove_at_ms` operate on
    //! raw links and keep duplicate values (multiset behaviour).

    use std;
    use std::cmp::Ordering::*;
    use std::fmt::Debug;

    /// An owned subtree; `None` is the empty tree.
    type Link<T> = Option<Box<Node<T>>>;

    /// Ordered set backed by an AVL tree.
    pub struct AVLTree<T> {
        pub root: Link<T>,
    }

    impl<T> AVLTree<T> {
        /// Creates an empty tree.
        pub fn new() -> Self {
            AVLTree { root: None }
        }

        /// Returns the `k`-th smallest element (0-indexed), or `None` if `k` is
        /// not smaller than the number of elements.
        pub fn at(&self, k: usize) -> Option<&T> {
            at(&self.root, k).map(|n| &n.val)
        }

        /// Removes and returns the `k`-th smallest element (0-indexed), or
        /// returns `None` (leaving the contents unchanged) if `k` is out of range.
        pub fn remove_at(&mut self, k: usize) -> Option<T> {
            let (root, kth) = remove_at(self.root.take(), k);
            self.root = root;
            kth
        }

        /// Appends all elements of `right` after the elements of `self`,
        /// leaving `right` empty.
        ///
        /// Every element of `self` must precede every element of `right`;
        /// this is not checked.
        pub fn merge(&mut self, right: &mut AVLTree<T>) {
            self.root = merge(self.root.take(), right.root.take());
        }

        /// Splits into (first `k` elements, remaining elements), leaving `self`
        /// empty. `k` must not exceed the number of elements (checked only in
        /// debug builds).
        pub fn split_at(&mut self, k: usize) -> (AVLTree<T>, AVLTree<T>) {
            let (left, right) = split_at(self.root.take(), k);
            (AVLTree { root: left }, AVLTree { root: right })
        }
    }

    impl<T: PartialOrd> AVLTree<T> {
        /// Inserts `val`; does nothing if an equal element is already present.
        ///
        /// # Panics
        /// If `val` is incomparable with a stored element (e.g. NaN).
        pub fn insert(&mut self, val: T) {
            self.root = insert(self.root.take(), val);
        }

        /// Removes the element equal to `val`; returns whether one was removed.
        pub fn remove(&mut self, val: &T) -> bool {
            let (root, removed) = remove(self.root.take(), val);
            self.root = root;
            removed
        }

        /// Returns the stored element equal to `val`, if any.
        pub fn find(&self, val: &T) -> Option<&T> {
            find(&self.root, val).map(|n| &n.val)
        }

        /// Splits into (elements less than `val`, `val` and greater elements),
        /// leaving `self` empty.
        pub fn split(&mut self, val: &T) -> (AVLTree<T>, AVLTree<T>) {
            let (left, right) = split(self.root.take(), val);
            (AVLTree { root: left }, AVLTree { root: right })
        }

        /// Returns the number of elements strictly less than `val`.
        pub fn count_less(&self, val: &T) -> usize {
            count_less(&self.root, val)
        }
    }

    impl<T: PartialOrd + Debug> AVLTree<T> {
        /// Asserts the AVL balance, the cached size/height values and strict
        /// in-order ordering of every node. The assertions are `debug_assert!`s,
        /// so they only fire in debug builds.
        pub fn check(&self) {
            if let Some(ref root) = self.root {
                root.verify(None, None);
            }
        }
    }

    /// A tree node caching the size and height of the subtree rooted at it.
    #[derive(Debug)]
    pub struct Node<T> {
        left: Link<T>,
        val: T,
        right: Link<T>,
        /// Number of nodes in this subtree, including this one.
        size: usize,
        /// Height of this subtree: a leaf is 1, an empty link counts as 0.
        height: isize,
    }

    impl<T> Node<T> {
        /// Creates a single-node subtree holding `val`.
        fn new(val: T) -> Link<T> {
            Some(Box::new(Node {
                val,
                left: None,
                right: None,
                size: 1,
                height: 1,
            }))
        }

        /// Recomputes the cached size and height from the children.
        #[inline]
        fn update(&mut self) {
            self.size = size(&self.left) + size(&self.right) + 1;
            self.height = std::cmp::max(height(&self.left), height(&self.right)) + 1;
        }
    }

    impl<T: PartialOrd> Node<T> {
        /// Recursively asserts this subtree's invariants. Every value must lie
        /// strictly between `lo` and `hi` when those bounds are given.
        fn verify(&self, lo: Option<&T>, hi: Option<&T>) {
            debug_assert_eq!(self.size, size(&self.left) + size(&self.right) + 1);
            debug_assert_eq!(
                self.height,
                std::cmp::max(height(&self.left), height(&self.right)) + 1
            );
            // Compare child heights directly: `balance_factor` panics on an
            // empty child, and every leaf has empty children.
            debug_assert!((height(&self.right) - height(&self.left)).abs() < 2);
            if let Some(lo) = lo {
                debug_assert!(*lo < self.val);
            }
            if let Some(hi) = hi {
                debug_assert!(self.val < *hi);
            }
            if let Some(ref l) = self.left {
                l.verify(lo, Some(&self.val));
            }
            if let Some(ref r) = self.right {
                r.verify(Some(&self.val), hi);
            }
        }
    }

    /// Inserts `val` into subtree `n` (no-op when an equal value exists) and
    /// returns the rebalanced root.
    fn insert<T: PartialOrd>(n: Link<T>, val: T) -> Link<T> {
        match n {
            None => Node::new(val),
            Some(mut n) => {
                match val.partial_cmp(&n.val).unwrap() {
                    Less => n.left = insert(n.left.take(), val),
                    Equal => {} // set semantics: keep the existing element
                    Greater => n.right = insert(n.right.take(), val),
                }
                n.update();
                balance(Some(n))
            }
        }
    }

    /// Removes the value equal to `val` from subtree `n`.
    /// Returns (new root, whether a value was removed).
    fn remove<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, bool) {
        let mut n = match n {
            None => return (None, false),
            Some(n) => n,
        };
        match val.partial_cmp(&n.val).unwrap() {
            Less => {
                let (left, removed) = remove(n.left.take(), val);
                n.left = left;
                n.update();
                (balance(Some(n)), removed)
            }
            Equal => {
                let (joined, _) = remove_root(n);
                (joined, true)
            }
            Greater => {
                let (right, removed) = remove(n.right.take(), val);
                n.right = right;
                n.update();
                (balance(Some(n)), removed)
            }
        }
    }

    /// Deletes node `n` itself: its in-order successor (if any) takes its place.
    /// Returns (rebalanced subtree made of `n`'s children, `n`'s value).
    fn remove_root<T>(mut n: Box<Node<T>>) -> (Link<T>, T) {
        let (left, right) = (n.left.take(), n.right.take());
        match right {
            None => (left, n.val),
            Some(right) => {
                let (right, mut successor) = remove_leftmost(Some(right));
                successor.left = left;
                successor.right = right;
                successor.update();
                (balance(Some(successor)), n.val)
            }
        }
    }

    /// Deletes the leftmost (minimum) node from `n`.
    /// Returns (rebalanced remainder, detached single node).
    ///
    /// # Panics
    /// If `n` is `None`.
    fn remove_leftmost<T>(n: Link<T>) -> (Link<T>, Box<Node<T>>) {
        let mut n = n.unwrap();
        match n.left {
            None => {
                let right = n.right.take();
                n.update();
                (right, n)
            }
            _ => {
                let (left, min) = remove_leftmost(n.left.take());
                n.left = left;
                n.update();
                (balance(Some(n)), min)
            }
        }
    }

    /// Deletes the rightmost (maximum) node from `n`.
    /// Returns (rebalanced remainder, detached single node).
    ///
    /// # Panics
    /// If `n` is `None`.
    fn remove_rightmost<T>(n: Link<T>) -> (Link<T>, Box<Node<T>>) {
        let mut n = n.unwrap();
        match n.right {
            None => {
                let left = n.left.take();
                n.update();
                (left, n)
            }
            _ => {
                let (right, max) = remove_rightmost(n.right.take());
                n.right = right;
                n.update();
                (balance(Some(n)), max)
            }
        }
    }

    /// Concatenates two trees; every element of `left` must precede every
    /// element of `right`.
    fn merge<T>(left: Link<T>, right: Link<T>) -> Link<T> {
        match (left, right) {
            (None, right) => right,
            (left, None) => left,
            (left, right) => {
                // Borrow the pivot from the side that is not taller, so the
                // height precondition of `merge_into_*` still holds afterwards.
                if height(&left) > height(&right) {
                    let (right, pivot) = remove_leftmost(right);
                    merge_into_left(left, Some(pivot), right)
                } else {
                    let (left, pivot) = remove_rightmost(left);
                    merge_into_right(left, Some(pivot), right)
                }
            }
        }
    }

    /// Joins `left`, the single node `par`, and `right` (in that order).
    fn merge3<T>(left: Link<T>, par: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&par), 1);
        // Descend into the taller tree. Choosing by size is wrong: a smaller
        // AVL tree can be taller than a larger one, which breaks the
        // precondition of `merge_into_*` and can yield a balance factor of ±3.
        if height(&left) > height(&right) {
            merge_into_left(left, par, right)
        } else {
            merge_into_right(left, par, right)
        }
    }

    /// Joins `left`, single node `par`, `right` by walking down the right spine
    /// of `left` until the heights differ by at most one, then rebalancing on
    /// the way back up. Requires `height(left) >= height(right)`.
    fn merge_into_left<T>(left: Link<T>, par: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&par), 1);
        debug_assert!(height(&left) >= height(&right));
        if height(&left) > height(&right) + 1 {
            // `left` is balanced, so its right child is at least as tall as
            // `right`: the precondition holds for the recursive call.
            let mut left = left.unwrap();
            left.right = merge_into_left(left.right.take(), par, right);
            left.update();
            balance(Some(left))
        } else {
            let mut par = par.unwrap();
            par.left = left;
            par.right = right;
            par.update();
            balance(Some(par))
        }
    }

    /// Mirror image of `merge_into_left`: walks down the left spine of `right`.
    /// Requires `height(left) <= height(right)`.
    fn merge_into_right<T>(left: Link<T>, par: Link<T>, right: Link<T>) -> Link<T> {
        debug_assert_eq!(size(&par), 1);
        debug_assert!(height(&left) <= height(&right));
        if height(&left) + 1 < height(&right) {
            let mut right = right.unwrap();
            right.left = merge_into_right(left, par, right.left.take());
            right.update();
            balance(Some(right))
        } else {
            let mut par = par.unwrap();
            par.left = left;
            par.right = right;
            par.update();
            balance(Some(par))
        }
    }

    /// Splits subtree `n` into (values ordered before `val`, the node equal to
    /// `val` if found plus everything after it).
    fn split<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, Link<T>) {
        match n {
            None => (None, None),
            Some(mut n) => {
                let (l, r) = (n.left.take(), n.right.take());
                n.update();
                match val.partial_cmp(&n.val).unwrap() {
                    Less => {
                        let (ll, lr) = split(l, val);
                        (ll, merge3(lr, Some(n), r))
                    }
                    Equal => (l, merge3(None, Some(n), r)),
                    Greater => {
                        let (rl, rr) = split(r, val);
                        (merge3(l, Some(n), rl), rr)
                    }
                }
            }
        }
    }

    /// Splits subtree `n` into (first `k` nodes in order, the rest).
    fn split_at<T>(n: Link<T>, k: usize) -> (Link<T>, Link<T>) {
        debug_assert!(k <= size(&n));
        match n {
            None => (None, None),
            Some(mut n) => {
                let (l, r) = (n.left.take(), n.right.take());
                n.update();
                let sl = size(&l);
                match k.cmp(&sl) {
                    Less => {
                        let (ll, lr) = split_at(l, k);
                        (ll, merge3(lr, Some(n), r))
                    }
                    Equal => (l, merge3(None, Some(n), r)),
                    Greater => {
                        let (rl, rr) = split_at(r, k - sl - 1);
                        (merge3(l, Some(n), rl), rr)
                    }
                }
            }
        }
    }

    /// Returns the node whose value compares equal to `val`, if any.
    fn find<'a, T: PartialOrd>(n: &'a Link<T>, val: &T) -> Option<&'a Node<T>> {
        match *n {
            None => None,
            Some(ref n) => match val.partial_cmp(&n.val).unwrap() {
                Less => find(&n.left, val),
                Equal => Some(n),
                Greater => find(&n.right, val),
            },
        }
    }

    /// Returns the `k`-th node in order (0-indexed), or `None` if out of range.
    fn at<T>(n: &Link<T>, k: usize) -> Option<&Node<T>> {
        match *n {
            None => None,
            Some(ref n) => {
                let ls = size(&n.left);
                match k.cmp(&ls) {
                    Less => at(&n.left, k),
                    Equal => Some(n),
                    Greater => at(&n.right, k - ls - 1),
                }
            }
        }
    }

    /// Removes the `k`-th node in order (0-indexed).
    /// Returns (new root, removed value or `None` if `k` is out of range).
    fn remove_at<T>(n: Link<T>, k: usize) -> (Link<T>, Option<T>) {
        let mut n = match n {
            None => return (None, None),
            Some(n) => n,
        };
        let sl = size(&n.left);
        match k.cmp(&sl) {
            Less => {
                let (left, kth) = remove_at(n.left.take(), k);
                n.left = left;
                n.update();
                (balance(Some(n)), kth)
            }
            Equal => {
                let (joined, val) = remove_root(n);
                (joined, Some(val))
            }
            Greater => {
                let (right, kth) = remove_at(n.right.take(), k - sl - 1);
                n.right = right;
                n.update();
                (balance(Some(n)), kth)
            }
        }
    }

    /// Counts the values in subtree `n` that are strictly less than `val`.
    fn count_less<T: PartialOrd>(n: &Link<T>, val: &T) -> usize {
        match *n {
            None => 0,
            Some(ref n) => {
                let sl = size(&n.left);
                match val.partial_cmp(&n.val).unwrap() {
                    Less => count_less(&n.left, val),
                    Equal => sl,
                    Greater => sl + 1 + count_less(&n.right, val),
                }
            }
        }
    }

    /// Restores the AVL invariant at `n`. Its cached size/height must be up
    /// to date, its children must be valid AVL trees, and its balance factor
    /// must be within -2..=2.
    fn balance<T>(n: Link<T>) -> Link<T> {
        match balance_factor(&n) {
            -1 | 0 | 1 => n,
            2 => {
                let mut n = n.unwrap();
                // Right-left case: straighten the right child first.
                if balance_factor(&n.right) < 0 {
                    n.right = rotate_right(n.right.take());
                }
                rotate_left(Some(n))
            }
            -2 => {
                let mut n = n.unwrap();
                // Left-right case: straighten the left child first.
                if balance_factor(&n.left) > 0 {
                    n.left = rotate_left(n.left.take());
                }
                rotate_right(Some(n))
            }
            bf => unreachable!("AVL balance factor {} out of range", bf),
        }
    }

    /// Rotates left: `q`'s right child becomes the new subtree root.
    #[inline]
    fn rotate_left<T>(q: Link<T>) -> Link<T> {
        let mut q = q.unwrap();
        let mut p = q.right.take().unwrap();
        q.right = p.left.take();
        q.update();
        p.left = Some(q);
        p.update();
        Some(p)
    }

    /// Rotates right: `p`'s left child becomes the new subtree root.
    #[inline]
    fn rotate_right<T>(p: Link<T>) -> Link<T> {
        let mut p = p.unwrap();
        let mut q = p.left.take().unwrap();
        p.left = q.right.take();
        p.update();
        q.right = Some(p);
        q.update();
        Some(q)
    }

    /// Number of nodes in `n` (0 for the empty tree).
    #[inline]
    fn size<T>(n: &Link<T>) -> usize {
        match *n {
            None => 0,
            Some(ref n) => n.size,
        }
    }

    /// Height of `n` (0 for the empty tree).
    #[inline]
    fn height<T>(n: &Link<T>) -> isize {
        match *n {
            None => 0,
            Some(ref n) => n.height,
        }
    }

    /// `height(right) - height(left)` of the root of `n`.
    ///
    /// # Panics
    /// If `n` is `None`.
    #[inline]
    fn balance_factor<T>(n: &Link<T>) -> isize {
        match *n {
            None => panic!(),
            Some(ref n) => height(&n.right) - height(&n.left),
        }
    }

    /// Multiset insert: places `x` so the in-order sequence stays sorted,
    /// keeping any existing equal values. Returns the new root.
    pub fn insert_ms<T: PartialOrd>(n: Link<T>, x: T) -> Link<T> {
        let (l, r) = split(n, &x);
        merge3(l, Node::new(x), r)
    }

    /// Multiset `remove_at`: removes the `k`-th value in order.
    /// Returns (new root, removed value or `None` if `k` is out of range).
    pub fn remove_at_ms<T>(n: Link<T>, k: usize) -> (Link<T>, Option<T>) {
        if k < size(&n) {
            let (l, r) = split_at(n, k);
            let (right, kth) = remove_leftmost(r);
            (merge(l, right), Some(kth.val))
        } else {
            (n, None)
        }
    }

    /// Multiset remove: deletes one value equal to `val`, if any.
    /// Returns (new root, whether a value was removed).
    pub fn remove_ms<T: PartialOrd>(n: Link<T>, val: &T) -> (Link<T>, bool) {
        if find(&n, val).is_some() {
            // The right half of the split starts with a node equal to `val`.
            let (l, r) = split(n, val);
            let (right, _) = remove_leftmost(r);
            (merge(l, right), true)
        } else {
            (n, false)
        }
    }
}

検証

http://arc033.contest.atcoder.jp/submissions/1292412