Home > libalgo > Scapegoat tree

Scapegoat tree

実装

// Uses PartialOrd instead of Ord because floating-point types (f32, f64) do
// not implement Ord.

/// Scapegoat tree storing an ordered *set* of values.
///
/// Removal is lazy: a removed value's node is only flagged as deleted and is
/// physically dropped the next time its subtree (or the whole tree) is
/// rebuilt.
// `dead_code` is allowed because this is a copy-paste library module; a given
// program usually calls only part of its API.
#[allow(dead_code)]
mod scapegoat_tree {
    use std;
    use std::cmp::Ordering::*;

    type Link<T> = Option<Box<Node<T>>>;

    /// Ordered set backed by a scapegoat tree with lazy deletion.
    pub struct ScapegoatTree<T> {
        /// Root node; `None` when the tree holds no nodes.
        pub root: Link<T>,
    }

    impl<T> ScapegoatTree<T> {
        /// Creates an empty tree.
        pub fn new() -> Self {
            ScapegoatTree { root: None }
        }

        /// Number of live (not removed) values in the tree.
        pub fn len(&self) -> usize {
            size(&self.root)
        }

        /// Returns `true` when the tree holds no live values.
        pub fn is_empty(&self) -> bool {
            self.len() == 0
        }

        /// Returns the `k`-th smallest live value (0-indexed), or `None` if
        /// `k >= self.len()`.
        pub fn at(&self, k: usize) -> Option<&T> {
            at(&self.root, k).map(|n| &n.val)
        }

        /// Removes the `k`-th smallest live value (0-indexed).
        ///
        /// Returns `false`, leaving the values unchanged, if `k >= self.len()`.
        pub fn remove_at(&mut self, k: usize) -> bool {
            let removed = remove_at(&mut self.root, k);
            if removed {
                self.rebuild_if_sparse();
            }
            removed
        }

        /// Rebuilds the whole tree, dropping every deleted node, once deleted
        /// nodes outnumber live ones.
        fn rebuild_if_sparse(&mut self) {
            if self.should_rebuild() {
                self.root = rebuild(self.root.take());
            }
        }

        /// `true` when `total_size > 2 * size`, i.e. deleted nodes outnumber
        /// live ones.
        fn should_rebuild(&self) -> bool {
            total_size(&self.root) > size(&self.root) * 2
        }
    }

    impl<T> Default for ScapegoatTree<T> {
        fn default() -> Self {
            ScapegoatTree::new()
        }
    }

    impl<T: PartialOrd> ScapegoatTree<T> {
        /// Inserts `val`.
        ///
        /// - If an equal live value is already stored, no node is added and
        ///   `val` is dropped.
        /// - If an equal value was removed earlier, its node is revived and
        ///   keeps the originally stored value.
        ///
        /// # Panics
        /// Panics if `val` is incomparable with a stored value it meets
        /// (e.g. NaN).
        pub fn insert(&mut self, val: T) {
            self.root = insert(self.root.take(), val);
        }

        /// Removes the value equal to `val`. Returns whether a value was
        /// removed.
        ///
        /// # Panics
        /// Panics if `val` is incomparable with a stored value it meets
        /// (e.g. NaN).
        pub fn remove(&mut self, val: &T) -> bool {
            let removed = remove(&mut self.root, val);
            if removed {
                self.rebuild_if_sparse();
            }
            removed
        }

        /// Returns the stored live value equal to `val`, if any.
        ///
        /// # Panics
        /// Panics if `val` is incomparable with a stored value it meets
        /// (e.g. NaN).
        pub fn find(&self, val: &T) -> Option<&T> {
            find(&self.root, val).map(|n| &n.val)
        }
    }

    /// Tree node with cached subtree statistics:
    /// - `size` counts live values in the subtree.
    /// - `total_size` counts all nodes, including lazily deleted ones.
    /// - `height` counts the nodes on the longest downward path.
    pub struct Node<T> {
        val: T,
        left: Link<T>,
        right: Link<T>,
        size: usize,
        total_size: usize,
        height: usize,
        deleted: bool,
    }

    impl<T> Node<T> {
        /// Allocates a live leaf holding `val`.
        fn new(val: T) -> Link<T> {
            Some(Box::new(Node {
                val: val,
                left: None,
                right: None,
                size: 1,
                total_size: 1,
                height: 1,
                deleted: false,
            }))
        }

        /// Recomputes `size`, `total_size` and `height` from the children.
        fn update(&mut self) {
            let own = if self.deleted { 0 } else { 1 };
            self.size = size(&self.left) + size(&self.right) + own;
            self.total_size = total_size(&self.left) + total_size(&self.right) + 1;
            self.height = std::cmp::max(height(&self.left), height(&self.right)) + 1;
        }
    }

    /// Compares `a` with `b`, panicking on incomparable values such as NaN.
    fn compare<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
        a.partial_cmp(b)
            .expect("scapegoat_tree: values must be comparable (NaN is not supported)")
    }

    // verified
    /// Inserts `val` into subtree `n` and returns the new subtree root.
    ///
    /// On the way back up, any subtree whose height exceeds
    /// `3 * ln(total_size)` is rebuilt into a balanced tree (the scapegoat
    /// step). The rebuild also discards that subtree's deleted nodes.
    fn insert<T: PartialOrd>(n: Link<T>, val: T) -> Link<T> {
        let mut n = match n {
            None => return Node::new(val),
            Some(n) => n,
        };
        match compare(&val, &n.val) {
            Less => n.left = insert(n.left.take(), val),
            // Set semantics: an equal live node stays as is; an equal deleted
            // node is revived in place (its stored value is kept).
            Equal => n.deleted = false,
            Greater => n.right = insert(n.right.take(), val),
        }
        n.update();
        if n.height as f64 > 3.0 * (n.total_size as f64).ln() {
            rebuild(Some(n))
        } else {
            Some(n)
        }
    }

    // verified
    /// Returns the live node in subtree `n` whose value equals `val`.
    fn find<'a, T: PartialOrd>(n: &'a Link<T>, val: &T) -> Option<&'a Node<T>> {
        match *n {
            None => None,
            Some(ref n) => match compare(val, &n.val) {
                Less => find(&n.left, val),
                Equal => if n.deleted { None } else { Some(n) },
                Greater => find(&n.right, val),
            },
        }
    }

    // verified
    /// Lazily deletes the live value equal to `val` from subtree `n`.
    ///
    /// Returns whether a live value was found; statistics on the path are
    /// refreshed only in that case.
    fn remove<T: PartialOrd>(n: &mut Link<T>, val: &T) -> bool {
        match *n {
            None => false,
            Some(ref mut n) => {
                let removed = match compare(val, &n.val) {
                    Less => remove(&mut n.left, val),
                    Equal => {
                        let was_live = !n.deleted;
                        n.deleted = true;
                        was_live
                    }
                    Greater => remove(&mut n.right, val),
                };
                if removed {
                    n.update();
                }
                removed
            }
        }
    }

    // verified
    /// Returns the node holding the `k`-th smallest live value (0-indexed)
    /// in subtree `n`.
    fn at<T>(n: &Link<T>, k: usize) -> Option<&Node<T>> {
        match *n {
            None => None,
            Some(ref n) => {
                let ls = size(&n.left);
                match k.cmp(&ls) {
                    Less => at(&n.left, k),
                    Equal if !n.deleted => Some(n),
                    // Skip the left subtree, plus this node when it is live.
                    _ => at(&n.right, k - if n.deleted { ls } else { ls + 1 }),
                }
            }
        }
    }

    // verified
    /// Lazily deletes the `k`-th smallest live value (0-indexed) in subtree
    /// `n`.
    ///
    /// Returns `false` if the subtree has at most `k` live values.
    fn remove_at<T>(n: &mut Link<T>, k: usize) -> bool {
        match *n {
            None => false,
            Some(ref mut n) => {
                let ls = size(&n.left);
                let del = n.deleted;
                let res = match k.cmp(&ls) {
                    Less => remove_at(&mut n.left, k),
                    Equal if !del => {
                        n.deleted = true;
                        true
                    }
                    // Skip the left subtree, plus this node when it is live.
                    _ => remove_at(&mut n.right, k - if del { ls } else { ls + 1 }),
                };
                if res {
                    n.update()
                }
                res
            }
        }
    }

    // verified
    /// Rebuilds subtree `n` into a balanced tree containing only its live
    /// nodes.
    fn rebuild<T>(n: Link<T>) -> Link<T> {
        let mut buf = Vec::with_capacity(size(&n));
        flatten(n, &mut buf);
        let len = buf.len();
        make_bst(&mut buf, 0, len)
    }

    // verified
    /// Appends the live nodes of subtree `n` to `buf` in sorted order.
    ///
    /// Each node is detached from its children first. Deleted nodes are
    /// dropped here.
    fn flatten<T>(n: Link<T>, buf: &mut Vec<Link<T>>) {
        if let Some(mut n) = n {
            let (l, r) = (n.left.take(), n.right.take());
            flatten(l, buf);
            if !n.deleted {
                buf.push(Some(n));
            }
            flatten(r, buf);
        }
    }

    // verified
    /// Builds a balanced tree from the sorted slots `buf[l..r]`.
    ///
    /// The middle slot becomes the root. Every slot in the range is taken
    /// exactly once.
    fn make_bst<T>(buf: &mut [Link<T>], l: usize, r: usize) -> Link<T> {
        if l == r {
            return None;
        }
        let m = l + (r - l) / 2;
        let mut res = buf[m]
            .take()
            .expect("scapegoat_tree: every slot is consumed exactly once");
        res.left = make_bst(buf, l, m);
        res.right = make_bst(buf, m + 1, r);
        res.update();
        Some(res)
    }

    // verified
    /// Number of live values in subtree `n` (0 for an empty link).
    #[inline(always)]
    fn size<T>(n: &Link<T>) -> usize {
        match *n {
            None => 0,
            Some(ref n) => n.size,
        }
    }

    /// Number of nodes, live or deleted, in subtree `n` (0 for an empty link).
    #[inline(always)]
    fn total_size<T>(n: &Link<T>) -> usize {
        match *n {
            None => 0,
            Some(ref n) => n.total_size,
        }
    }

    // verified
    /// Height of subtree `n` in nodes (0 for an empty link).
    #[inline(always)]
    fn height<T>(n: &Link<T>) -> usize {
        match *n {
            None => 0,
            Some(ref n) => n.height,
        }
    }
}

検証

https://arc033.contest.atcoder.jp/submissions/1290335