AOJ2667 Tree

2015/10/28 (Wed) AOJ データ構造

問題

問題文

方針

HL 分解で解いた.

add クエリでは $v$ が乗っている chain の自分以下の頂点の区間に加算する.

dist クエリでは,頂点 $a$ と根との距離を $dist(a)$ とおくと,まず $l = lca(u,v)$ を求め $dist(u) + dist(v) - 2 \times dist(l)$ として計算する. $dist(x)$ を求めるには,まず $y = x$,$y$ が乗っている chain の付け根 (chain の先頭の親) を $p$ とし, segtree の $p$ の値と,$depth(x) - depth(p)$ の積を答えに加える.さらに segtree の先頭から $y$ までの区間和を答えに加える. $y = p$ と更新して繰り返す. $y$ が根を含む chain に乗ったらループを抜け,根から $y$ までの区間和を答えに加える. これで答えがでる.

実はこの問題のライターは自分で,出題当時は HL 分解なんぞ知らなくてオイラーツアーが想定でした. それだと segtree を 4 つ使ってややこしいインデックスと戦う必要があったのですが, HL 分解だとそんなこともなくあっさり解けてしまいました. HL 分解は汎用性が高いですね…

実装

#define _CRT_SECURE_NO_WARNINGS
// #define _GLIBCXX_DEBUG
using namespace std;
#include <iostream>
#include <vector>
#include <utility>
#include <map>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <tuple>
#define DEBUG
typedef long long ll;
// #define int ll
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef pair<int, int> pii;
#define all(c) begin(c), end(c)
#define range(i,a,b) for(int i = a; i < (int)(b); i++)
#define rep(i,b) range(i,0,b)
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define mt make_tuple
 
void fastios() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    // #define endl '\n'
}
 
ll const mod = 1000000007;
auto const inf = numeric_limits<int>::max() / 8;
 
template <class T>
struct FenwickTree {
    int n;
    vector<T> x;
    FenwickTree(int n_) : n(n_), x(n) {}
    T sum(int i, int j) {
        if (i == 0) {
            T S = 0;
            for (; j >= 0; j = (j&(j + 1)) - 1) S += x[j];
            return S;
        }
        else return sum(0, j) - sum(0, i - 1);
    }
    void add(int k, T a) {
        for (; k < n; k |= k + 1) x[k] += a;
    }
    void set(int k, T a) {
        int x = sum(k, k + 1);
        add(k, a - x);
    }
};
 
template <class T>
struct FenwickTreeImos {
    FenwickTree<T> a, b;
    FenwickTreeImos(int n) : a(n + 1), b(n + 1) {}
    T sum(int l, int r) {
        T res = 0;
        res += a.sum(0, r) + b.sum(0, r)*r;
        res -= a.sum(0, l) + b.sum(0, l)*l;
        return res;
    }
    void add(int l, int r, T x) {
        a.add(l, -x*l);
        b.add(l, x);
        a.add(r + 1, x*r);
        b.add(r + 1, -x);
    }
};
typedef FenwickTreeImos<int> BITr;
 
struct HeavyLightDecomposition {
    int size;
    vector<vector<int>> g;
    vector<int> parent, subtreeSize, depth;
    vector<int> head, next, chain, at;
    vector<vector<int>> chains;
 
    template<class Graph>
    HeavyLightDecomposition(const Graph &g_) {
        size = g_.size();
        g.resize(size);
        vector<int64_t> es;
        for (size_t i = 0; i < g_.size(); i++) {
            for (auto &e : g_[i]) {
                int a = e.src, b = e.dst;
                if (a > b) swap(a, b);
                es.push_back((int64_t)a << 32 | b);
            }
        }
        sort(es.begin(), es.end());
        es.erase(unique(es.begin(), es.end()), es.end());
        for (auto & e : es) addEdge(e >> 32, e & 0xFFFFFFFF);
    }
 
    HeavyLightDecomposition(int n_) : size(n_), g(n_) {}
 
    void addEdge(int a, int b) {
        g[a].push_back(b);
        g[b].push_back(a);
    }
 
    int goUp(int & v) const {
        return parent[head[v]];
    }
 
    pair<int, int> getIndex(int v) const {
        return make_pair(chain[v], at[v]);
    }
 
    void decompose(const int root = 0) {
        parent.assign(size, 0);
        subtreeSize.assign(size, 0);
        depth.assign(size, -1);
        head.assign(size, 0);
        next.assign(size, -1);
        chain.assign(size, -1);
        at.assign(size, 0);
        static int stk[600010], k = 0;
        stk[k++] = root;
        parent[root] = -1;
        depth[root] = 0;
        while (k) {
            const int v = stk[--k];
            if (v >= 0) {
                stk[k++] = ~v;
                for (const int ch : g[v]) {
                    if (depth[ch] != -1) continue;
                    depth[ch] = depth[v] + 1;
                    parent[ch] = v;
                    stk[k++] = ch;
                }
            }
            else {
                const int u = ~v;
                subtreeSize[u] = 1;
                int m = 0;
                for (const int ch : g[u]) {
                    if (parent[u] == ch) continue;
                    subtreeSize[u] += subtreeSize[ch];
                    if (m < subtreeSize[ch]) {
                        m = subtreeSize[ch];
                        next[u] = ch;
                    }
                }
            }
        }
 
        k = 0;
        stk[k++] = root;
        while (k) {
            const int head_ = stk[--k];
            for (const int ch : g[head_]) {
                if (parent[head_] == ch) continue;
                stk[k++] = ch;
            }
            if (chain[head_] != -1) continue;
            chains.push_back(vector<int>());
            vector<int> & path = chains.back();
            int cur = head_;
            while (cur != -1) {
                path.push_back(cur);
                cur = next[cur];
            }
            for (size_t i = 0; i < path.size(); i++) {
                const int v = path[i];
                head[v] = path.front();
                next[v] = i + 1 != path.size() ? path[i + 1] : -1;
                chain[v] = chains.size() - 1;
                at[v] = i;
            }
        }
    }
 
    void buildSegTree() {
        segtrees.clear();
        rep(i, chains.size()) segtrees.emplace_back(chains[i].size());
    }
 
    int lca(int u, int v) {
        while (chain[u] != chain[v]) {
            if (depth[head[u]] > depth[head[v]]) u = goUp(u);
            else v = goUp(v);
        }
        return depth[u] < depth[v] ? u : v;
    }
 
    vector<FenwickTreeImos<ll>> segtrees;
    ll solve(int u, int v) {
        int l = lca(u, v);
        return climb(u, 0) + climb(v, 0) - 2 * climb(l, 0);
    }
    ll climb(const int u, const int l) {
        ll res = 0;
        int v = u;
        while (chain[v] != chain[l]) {
            int c, k;
            tie(c, k) = getIndex(v);
            res += segtrees[c].sum(0, k);
            int p = goUp(v);
            if (p != -1) {
                res += segtrees[chain[p]].sum(at[p], at[p] + 1) * (depth[u] - depth[p]);
            }
            v = p;
        }
        res += segtrees[chain[l]].sum(at[l], at[v]);
        return res;
    }
    void add(int u, ll x) {
        int c, k;
        tie(c, k) = getIndex(u);
        segtrees[c].add(k, chains[c].size(), x);
    }
};
 
int main() {
    fastios();
    int n, q;
    cin >> n >> q;
    HeavyLightDecomposition hld(n);
    rep(i, n - 1) {
        int a, b;
        cin >> a >> b;
        hld.addEdge(a, b);
    }
    hld.decompose();
    hld.buildSegTree();
    rep(i, q) {
        int t;
        cin >> t;
        if (t == 0) {
            int u, v;
            cin >> u >> v;
            cout << hld.solve(u, v) << '\n';
        }
        else {
            int v, x;
            cin >> v >> x;
            hld.add(v, x);
        }
    }
}