Home > libalgo > セグメントツリー (区間 add, 区間 sum)

セグメントツリー (区間 add, 区間 sum)

使い方

未検証なのでバグっている可能性が高いです. 外からは t.add(l,r), t.sum(l,r) を呼ぶ.

実装

template <typename T>
class segment_tree {
public:
    segment_tree(int n_) : n(__lowest_power_of_2(n_)), dat(n * 2, 0), lazy(n * 2, 0) {}

    template <typename Iterator>
    segment_tree(Iterator left, Iterator right)
        : n(__lowest_power_of_2(std::distance(left, right))), dat(n * 2, 0), lazy(n * 2, 0) {
        std::copy(left + n, right + n, dat.begin());
        for (int i = n - 1; i >= 1; --i) {
            dat[i] = std::max(dat[i * 2], dat[i * 2 + 1]);
        }
    }

private:
    int n;
    std::vector<T> dat, lazy;

public:
    void add(int ql, int qr, const T &x) { __add(1, 0, n, ql, qr, x); }

    T sum(int ql, int qr) { return __sum(1, 0, n, ql, qr); }

private:
    void __flush_lazy(int v, int l, int r) {
        dat[v] += lazy[v] * (r - l);
        if (v < n) {
            lazy[v * 2] += lazy[v];
            lazy[v * 2 + 1] += lazy[v];
        }
        lazy[v] = 0;
    }

    void __update_dat(int v) { dat[v] = dat[v * 2] + dat[v * 2 + 1]; }

    void __add(int n, int l, int r, const int ql, const int qr, const T &x) {
        __flush_lazy(n, l, r);
        if (r <= ql || qr <= l) {
            return;
        } else if (ql <= l && r <= qr) {
            lazy[n] += x;
            __flush_lazy(n, l, r);
        } else {
            int m = (l + r) / 2;
            __add(n * 2, l, m, ql, qr, x);
            __add(n * 2 + 1, m, r, ql, qr, x);
            __update_dat(n);
        }
    }

    T __sum(int n, int l, int r, const int ql, const int qr) {
        __flush_lazy(n, l, r);
        if (r <= ql || qr <= l) {
            return 0;
        } else if (ql <= l && r <= qr) {
            return dat[n];
        } else {
            int m = (l + r) / 2;
            T res_left = __sum(n * 2, l, m, ql, qr);
            T res_right = __sum(n * 2 + 1, m, r, ql, qr);
            __update_dat(n);
            return res_left + res_right;
        }
    }

    int __lowest_power_of_2(int x) const {
        int res = 1;
        while (res < x) res <<= 1;
        return res;
    }
};

検証

参考文献

僕のセグメントツリーの使い方 - kyuridenamidaのチラ裏