Skip to the content.

:heavy_check_mark: Dynamic Segment Tree (data_structure/dynamic_segment_tree.hpp)

広大な範囲に対する要素の更新や区間に対する集約された値の取得を、実際に存在する要素のみをノードとして持つことで効率的に行うことができるデータ構造です。

テンプレートパラメータ:

使い方

コンストラクタ

DynamicSegmentTree()

説明

空の動的セグメントツリーを構築します。

計算量

$O(1)$

set

void set(int64_t p, T v)

説明

インデックス p の要素の値を v に更新します。インデックス p に対応するノードが存在しない場合は新しく作成されます。

制約

$0 \le p < n$

計算量

$O(\log n)$

operator[]

T operator[](int64_t p)

説明

インデックス p の要素の値を返します。インデックス p に対応するノードが存在しない場合は単位元 e() を返します。

制約

$0 \le p < n$

計算量

$O(\log n)$

operator()

T operator()(int64_t L, int64_t R)

説明

区間 [L, R) (半開区間) の要素に対する二項演算 op の結果を返します。

制約

$0 \le L \le R \le n$

計算量

$O(\log n)$

max_right

int64_t max_right(int64_t L, const auto& f)

説明

インデックス L から開始して、条件 f(op(v[L], v[L+1], ..., v[i])) を満たす最大のインデックス i を見つけます。つまり、区間 [L, i] の集約値に対して ftrue となる最大の i+1 を返します。f は単調である必要があります。

制約

$0 \le L \le n$

計算量

$O(\log n)$

min_left

int64_t min_left(int64_t R, const auto& f)

説明

インデックス R で終了して、条件 f(op(v[i], ..., v[R-2], v[R-1])) を満たす最小のインデックス i を見つけます。つまり、区間 [i, R) の集約値に対して ftrue となる最小の i を返します。f は単調である必要があります。

制約

$0 \le R \le n$

計算量

$O(\log n)$

Verified with

Code

#pragma once
#include <memory>
#include <stack>

template <typename T, auto op, auto e, typename S = int64_t, S n = 1000000000000000001LL>
struct DynamicSegmentTree {
    struct node;
    using nptr = std::unique_ptr<node>;
    struct node {
        S p;
        T v, prod;
        nptr left, right;
        node(S _p, T _v) : p(_p), v(_v), prod(_v), left(nullptr), right(nullptr) {}
        void update() {
            prod = op(op(left ? left->prod : e(), v), right ? right->prod : e());
        }
    };
    nptr root{nullptr};
    DynamicSegmentTree() {}

    void set(S p, T v) {
        std::stack<nptr*> st;
        nptr* ptr = &root;
        S l = 0, r = n;
        bool flg = true;
        while (*ptr) {
            st.emplace(ptr);
            nptr& cur = *ptr;
            S mid = (l + r) >> 1;
            if (cur->p == p) {
                cur->v = v;
                flg = false;
                break;
            }
            if (p < mid) {
                if (cur->p < p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->left;
                r = mid;
            } else {
                if (p < cur->p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->right;
                l = mid;
            }
        }
        if (flg) {
            *ptr = std::make_unique<node>(p, v);
        }
        while (!st.empty()) {
            st.top()->get()->update();
            st.pop();
        }
    }

    void add(S p, T v) {
        std::stack<nptr*> st;
        nptr* ptr = &root;
        S l = 0, r = n;
        bool flg = true;
        while (*ptr) {
            st.emplace(ptr);
            nptr& cur = *ptr;
            S mid = (l + r) >> 1;
            if (cur->p == p) {
                cur->v += v;
                flg = false;
                break;
            }
            if (p < mid) {
                if (cur->p < p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->left;
                r = mid;
            } else {
                if (p < cur->p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->right;
                l = mid;
            }
        }
        if (flg) {
            *ptr = std::make_unique<node>(p, v);
        }
        while (!st.empty()) {
            st.top()->get()->update();
            st.pop();
        }
    }

    T operator[](S p) {
        nptr* ptr = &root;
        S l = 0, r = n;
        while (*ptr) {
            nptr& cur = *ptr;
            if (cur->p == p) {
                return cur->v;
            }
            S mid = (l + r) >> 1;
            if (p < mid) {
                ptr = &cur->left;
                r = mid;
            } else {
                ptr = &cur->right;
                l = mid;
            }
        }
        return e();
    }

    T operator()(S L, S R) {
        if (!root) return e();
        T res = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, flg] = st.top();
            st.pop();
            if (flg) {
                if (!ptr || r <= L || R <= l) continue;
                if (L <= l && r <= R) {
                    res = op(res, ptr->prod);
                    continue;
                }
                S mid = (l + r) >> 1;
                st.emplace(ptr->right, mid, r, true);
                if (ptr->p >= L && ptr->p < R) {
                    st.emplace(ptr, l, r, false);
                }
                st.emplace(ptr->left, l, mid, true);
            } else {
                res = op(res, ptr->v);
            }
        }
        return res;
    }

    S max_right(S L, const auto& f) {
        T sum = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, flg] = st.top();
            st.pop();
            S mid = (l + r) >> 1;
            if (flg) {
                if (!ptr || r <= L) continue;
                if (L <= l && f(op(sum, ptr->prod))) {
                    sum = op(sum, ptr->prod);
                    continue;
                }
                if (ptr->p < L) {
                    st.emplace(ptr->right, mid, r, true);
                } else {
                    st.emplace(ptr, l, r, false);
                    st.emplace(ptr->left, l, mid, true);
                }
            } else {
                sum = op(sum, ptr->v);
                if (!f(sum)) {
                    return ptr->p;
                }
                st.emplace(ptr->right, mid, r, true);
            }
        }
        return n;
    }

    S min_left(S R, const auto& f) {
        T sum = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, first] = st.top();
            st.pop();
            if (!ptr || R <= l) continue;
            S mid = (l + r) >> 1;
            if (first) {
                if (r <= R && f(op(ptr->prod, sum))) {
                    sum = op(ptr->prod, sum);
                    continue;
                }
                if (R <= ptr->p) {
                    st.emplace(ptr->left, l, mid, true);
                } else {
                    st.emplace(ptr, l, r, false);
                    st.emplace(ptr->right, mid, r, true);
                }
            } else {
                sum = op(ptr->v, sum);
                if (!f(sum)) {
                    return ptr->p + 1;
                }
                st.emplace(ptr->left, l, mid, true);
            }
        }
        return 0;
    }
};
#line 2 "data_structure/dynamic_segment_tree.hpp"
#include <memory>
#include <stack>

template <typename T, auto op, auto e, typename S = int64_t, S n = 1000000000000000001LL>
struct DynamicSegmentTree {
    struct node;
    using nptr = std::unique_ptr<node>;
    struct node {
        S p;
        T v, prod;
        nptr left, right;
        node(S _p, T _v) : p(_p), v(_v), prod(_v), left(nullptr), right(nullptr) {}
        void update() {
            prod = op(op(left ? left->prod : e(), v), right ? right->prod : e());
        }
    };
    nptr root{nullptr};
    DynamicSegmentTree() {}

    void set(S p, T v) {
        std::stack<nptr*> st;
        nptr* ptr = &root;
        S l = 0, r = n;
        bool flg = true;
        while (*ptr) {
            st.emplace(ptr);
            nptr& cur = *ptr;
            S mid = (l + r) >> 1;
            if (cur->p == p) {
                cur->v = v;
                flg = false;
                break;
            }
            if (p < mid) {
                if (cur->p < p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->left;
                r = mid;
            } else {
                if (p < cur->p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->right;
                l = mid;
            }
        }
        if (flg) {
            *ptr = std::make_unique<node>(p, v);
        }
        while (!st.empty()) {
            st.top()->get()->update();
            st.pop();
        }
    }

    void add(S p, T v) {
        std::stack<nptr*> st;
        nptr* ptr = &root;
        S l = 0, r = n;
        bool flg = true;
        while (*ptr) {
            st.emplace(ptr);
            nptr& cur = *ptr;
            S mid = (l + r) >> 1;
            if (cur->p == p) {
                cur->v += v;
                flg = false;
                break;
            }
            if (p < mid) {
                if (cur->p < p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->left;
                r = mid;
            } else {
                if (p < cur->p) {
                    std::swap(cur->p, p);
                    std::swap(cur->v, v);
                }
                ptr = &cur->right;
                l = mid;
            }
        }
        if (flg) {
            *ptr = std::make_unique<node>(p, v);
        }
        while (!st.empty()) {
            st.top()->get()->update();
            st.pop();
        }
    }

    T operator[](S p) {
        nptr* ptr = &root;
        S l = 0, r = n;
        while (*ptr) {
            nptr& cur = *ptr;
            if (cur->p == p) {
                return cur->v;
            }
            S mid = (l + r) >> 1;
            if (p < mid) {
                ptr = &cur->left;
                r = mid;
            } else {
                ptr = &cur->right;
                l = mid;
            }
        }
        return e();
    }

    T operator()(S L, S R) {
        if (!root) return e();
        T res = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, flg] = st.top();
            st.pop();
            if (flg) {
                if (!ptr || r <= L || R <= l) continue;
                if (L <= l && r <= R) {
                    res = op(res, ptr->prod);
                    continue;
                }
                S mid = (l + r) >> 1;
                st.emplace(ptr->right, mid, r, true);
                if (ptr->p >= L && ptr->p < R) {
                    st.emplace(ptr, l, r, false);
                }
                st.emplace(ptr->left, l, mid, true);
            } else {
                res = op(res, ptr->v);
            }
        }
        return res;
    }

    S max_right(S L, const auto& f) {
        T sum = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, flg] = st.top();
            st.pop();
            S mid = (l + r) >> 1;
            if (flg) {
                if (!ptr || r <= L) continue;
                if (L <= l && f(op(sum, ptr->prod))) {
                    sum = op(sum, ptr->prod);
                    continue;
                }
                if (ptr->p < L) {
                    st.emplace(ptr->right, mid, r, true);
                } else {
                    st.emplace(ptr, l, r, false);
                    st.emplace(ptr->left, l, mid, true);
                }
            } else {
                sum = op(sum, ptr->v);
                if (!f(sum)) {
                    return ptr->p;
                }
                st.emplace(ptr->right, mid, r, true);
            }
        }
        return n;
    }

    S min_left(S R, const auto& f) {
        T sum = e();
        std::stack<std::tuple<const nptr&, S, S, bool>> st;
        st.emplace(root, 0, n, true);
        while (!st.empty()) {
            auto [ptr, l, r, first] = st.top();
            st.pop();
            if (!ptr || R <= l) continue;
            S mid = (l + r) >> 1;
            if (first) {
                if (r <= R && f(op(ptr->prod, sum))) {
                    sum = op(ptr->prod, sum);
                    continue;
                }
                if (R <= ptr->p) {
                    st.emplace(ptr->left, l, mid, true);
                } else {
                    st.emplace(ptr, l, r, false);
                    st.emplace(ptr->right, mid, r, true);
                }
            } else {
                sum = op(ptr->v, sum);
                if (!f(sum)) {
                    return ptr->p + 1;
                }
                st.emplace(ptr->left, l, mid, true);
            }
        }
        return 0;
    }
};
Back to top page