Home > libalgo > Randomized binary search trees (RBST)

Randomized binary search trees (RBST)

実装

// Using PartialOrd instead of Ord because floating-point types do not
// implement Ord.

#[allow(dead_code)]
struct Node<T> {
    val: T,
    size: usize,
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
}

/// Owning, nullable handle to a subtree; `None` represents the empty tree.
#[allow(dead_code)]
type Link<T> = Option<Box<Node<T>>>;

#[allow(dead_code)]
impl<T: std::cmp::PartialOrd + Copy> Node<T> {
    fn singleton(val: T) -> Node<T> {
        Node {
            val: val,
            size: 1,
            left: None,
            right: None,
        }
    }

    fn update(&mut self) {
        self.size = Self::size(&self.left) + Self::size(&self.right) + 1;
    }

    fn insert(n: Link<T>, new_val: T) -> Link<T> {
        let (left, right) = Self::split_less(n, new_val);
        let single = Some(Box::new(Node::singleton(new_val)));
        let n = Self::merge(left, single);
        Self::merge(n, right)
    }

    fn erase_all(n: Link<T>, val: T) -> Link<T> {
        let (left, right) = Self::split_less(n, val);
        let (_, right) = Self::split_leq(right, val);
        Self::merge(left, right)
    }

    fn erase_one(n: Link<T>, val: T) -> Link<T> {
        let (left, right) = Self::split_less(n, val);
        let (_, right) = Self::split_at(right, 1);
        Self::merge(left, right)
    }

    fn merge(left: Link<T>, right: Link<T>) -> Link<T> {
        match (left, right) {
            (None, r) => r,
            (l, None) => l,
            (mut l, mut r) => {
                let mut l = l.take().unwrap();
                let mut r = r.take().unwrap();
                if Self::rand() as usize % (l.size + r.size) < l.size {
                    let lr = (*l).right.take();
                    l.right = Self::merge(lr, Some(r));
                    l.update();
                    Some(l)
                } else {
                    let rl = (*r).left.take();
                    r.left = Self::merge(Some(l), rl);
                    r.update();
                    Some(r)
                }
            }
        }
    }

    fn split_less(n: Link<T>, val: T) -> (Link<T>, Link<T>) {
        Self::split_cmp_impl(n, &|n: &Box<Node<T>>| n.val < val)
    }

    fn split_leq(n: Link<T>, val: T) -> (Link<T>, Link<T>) {
        Self::split_cmp_impl(n, &|n: &Box<Node<T>>| n.val <= val)
    }

    fn split_cmp_impl(n: Link<T>, f: &Fn(&Box<Node<T>>) -> bool) -> (Link<T>, Link<T>) {
        match n {
            None => (None, None),
            Some(mut n) => {
                if f(&n) {
                    let (l, r) = Self::split_cmp_impl(n.right.take(), f);
                    n.right = l;
                    n.update();
                    (Some(n), r)
                } else {
                    let (l, r) = Self::split_cmp_impl(n.left.take(), f);
                    n.left = r;
                    n.update();
                    (l, Some(n))
                }
            }
        }
    }

    fn split_at(n: Link<T>, k: usize) -> (Link<T>, Link<T>) {
        match n {
            None => (None, None),
            Some(mut n) => {
                let ls = Self::size(&n.left);
                if k <= ls {
                    let nl = n.left.take();
                    let (l, r) = Self::split_at(nl, k);
                    n.left = r;
                    n.update();
                    (l, Some(n))
                } else {
                    let nr = n.right.take();
                    let (l, r) = Self::split_at(nr, k - ls - 1);
                    n.right = l;
                    n.update();
                    (Some(n), r)
                }
            }
        }
    }

    fn size(n: &Link<T>) -> usize {
        match n {
            &None => 0,
            &Some(ref n) => n.size,
        }
    }

    fn at(n: &Link<T>, k: usize) -> Option<T> {
        match n {
            &None => None,
            &Some(ref n) => {
                let ls = Node::size(&n.left);
                if k < ls {
                    Self::at(&n.left, k)
                } else if k == ls {
                    Some(n.val)
                } else {
                    Self::at(&n.right, k - ls - 1)
                }
            }
        }
    }

    fn rand() -> u32 {
        static mut RAND_X: u32 = 123456789;
        static mut RAND_Y: u32 = 987654321;
        static mut RAND_Z: u32 = 1000000007;
        static mut RAND_W: u32 = 1145141919;
        unsafe {
            let t = RAND_X ^ (RAND_X << 11);
            RAND_X = RAND_Y;
            RAND_Y = RAND_Z;
            RAND_Z = RAND_W;
            RAND_W = (RAND_W ^ (RAND_W >> 19)) ^ (t ^ (t >> 8));
            RAND_W
        }
    }
}

検証

http://arc033.contest.atcoder.jp/submissions/1212481