Skip to the content.

:heavy_check_mark: Segment Tree (data_structure/segment_tree.hpp)

要素の更新や区間に対する集約された値の取得を高速に行うことができるデータ構造です。

テンプレートパラメータ:

使い方

コンストラクタ

SegmentTree(int n)

説明

サイズ n のセグメントツリーを構築します。初期値は全て単位元 e() です。

計算量

$O(N)$

SegmentTree(const std::vector<T>& _v)

説明

与えられた配列 _v を元にセグメントツリーを構築します。

計算量

$O(N)$

operator[]

T operator[](int p) const

説明

インデックス p の要素の値を返します。

制約

$0 \le p < N$ (構築時の要素数)

計算量

$O(1)$

set

void set(int p, T x)

説明

インデックス p の要素の値を x に更新します。

制約

$0 \le p < N$ (構築時の要素数)

計算量

$O(\log N)$

add

void add(int p, T x)

説明

インデックス p の要素に x を加算します。これは特に加算モノイドの場合に便利ですが、一般のモノイドに対しても v[p] = op(v[p], x) のように動作します。

制約

$0 \le p < N$ (構築時の要素数)

計算量

$O(\log N)$

operator()

T operator()(int l, int r) const

説明

区間 [l, r) (半開区間) の要素に対する二項演算 op の結果を返します。

制約

$0 \le l \le r \le N$ (構築時の要素数)

計算量

$O(\log N)$

max_right

int max_right(int l, auto f) const

説明

インデックス l から開始して、条件 f(op(v[l], v[l+1], ..., v[i])) を満たす最大のインデックス i を見つけます。つまり、区間 [l, i] の集約値に対して ftrue となる最大の i+1 を返します。f は単調である必要があります。

制約

$0 \le l \le N$ (構築時の要素数)

計算量

$O(\log N)$

min_left

int min_left(int r, auto f) const

説明

インデックス r で終了して、条件 f(op(v[i], ..., v[r-2], v[r-1])) を満たす最小のインデックス i を見つけます。つまり、区間 [i, r) の集約値に対して ftrue となる最小の i を返します。f は単調である必要があります。

制約

$0 \le r \le N$ (構築時の要素数)

計算量

$O(\log N)$

Required by

Verified with

Code

#pragma once
#include <bit>
#include <cstdint>
#include <vector>

template <typename T, auto op, auto e>
struct SegmentTree {
    using _T = T;
    using _F = T;
    static constexpr auto _op = op;
    static constexpr auto _e = e;
    int n, _n;
    std::vector<T> v;
    explicit SegmentTree() : SegmentTree(0) {}
    explicit SegmentTree(int n) : SegmentTree(std::vector<T>(n, e())) {}
    explicit SegmentTree(const std::vector<T>& _v) : n(std::bit_ceil(_v.size()) << 1), _n(_v.size()), v(n * 2, e()) {
        for (uint32_t i = 0; i < _v.size(); ++i) {
            v[i + n] = _v[i];
        }
        for (int i = n - 1; i; --i) {
            v[i] = op(v[i << 1], v[(i << 1) | 1]);
        }
    }

    T operator[](int p) const {
        return v[p + n];
    }

    void set(int p, T x) {
        p += n;
        v[p] = x;
        for (p >>= 1; p; p >>= 1) {
            v[p] = op(v[p << 1], v[(p << 1) | 1]);
        }
    }

    void add(int p, T x) {
        p += n;
        v[p] += x;
        for (p >>= 1; p; p >>= 1) {
            v[p] = op(v[p << 1], v[(p << 1) | 1]);
        }
    }

    T operator()(int l, int r) const {
        T L{e()}, R{e()};
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) L = op(L, v[l++]);
            if (r & 1) R = op(v[--r], R);
        }
        return op(L, R);
    }

    int max_right(int l, auto f) const {
        if (l == _n) return _n;
        l += n;
        T sm = e();
        do {
            while ((l & 1) == 0) l >>= 1;
            if (!f(op(sm, v[l]))) {
                while (l < n) {
                    l <<= 1;
                    if (f(op(sm, v[l]))) {
                        sm = op(sm, v[l]);
                        ++l;
                    }
                }
                return l - n;
            }
            sm = op(sm, v[l]);
            ++l;
        } while ((l & -l) != l);
        return _n;
    }

    int min_left(int r, auto f) const {
        if (r == 0) return 0;
        r += n;
        T sm = e();
        do {
            --r;
            while (r > 1 && (r & 1)) r >>= 1;
            if (!f(op(v[r], sm))) {
                while (r < n) {
                    r = ((r << 1) + 1);
                    if (f(op(v[r], sm))) {
                        sm = op(v[r], sm);
                        r--;
                    }
                }
                return r + 1 - n;
            }
            sm = op(v[r], sm);
        } while ((r & -r) != r);
        return 0;
    }
};
#line 2 "data_structure/segment_tree.hpp"
#include <bit>
#include <cstdint>
#include <vector>

template <typename T, auto op, auto e>
struct SegmentTree {
    using _T = T;
    using _F = T;
    static constexpr auto _op = op;
    static constexpr auto _e = e;
    int n, _n;
    std::vector<T> v;
    explicit SegmentTree() : SegmentTree(0) {}
    explicit SegmentTree(int n) : SegmentTree(std::vector<T>(n, e())) {}
    explicit SegmentTree(const std::vector<T>& _v) : n(std::bit_ceil(_v.size()) << 1), _n(_v.size()), v(n * 2, e()) {
        for (uint32_t i = 0; i < _v.size(); ++i) {
            v[i + n] = _v[i];
        }
        for (int i = n - 1; i; --i) {
            v[i] = op(v[i << 1], v[(i << 1) | 1]);
        }
    }

    T operator[](int p) const {
        return v[p + n];
    }

    void set(int p, T x) {
        p += n;
        v[p] = x;
        for (p >>= 1; p; p >>= 1) {
            v[p] = op(v[p << 1], v[(p << 1) | 1]);
        }
    }

    void add(int p, T x) {
        p += n;
        v[p] += x;
        for (p >>= 1; p; p >>= 1) {
            v[p] = op(v[p << 1], v[(p << 1) | 1]);
        }
    }

    T operator()(int l, int r) const {
        T L{e()}, R{e()};
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) L = op(L, v[l++]);
            if (r & 1) R = op(v[--r], R);
        }
        return op(L, R);
    }

    int max_right(int l, auto f) const {
        if (l == _n) return _n;
        l += n;
        T sm = e();
        do {
            while ((l & 1) == 0) l >>= 1;
            if (!f(op(sm, v[l]))) {
                while (l < n) {
                    l <<= 1;
                    if (f(op(sm, v[l]))) {
                        sm = op(sm, v[l]);
                        ++l;
                    }
                }
                return l - n;
            }
            sm = op(sm, v[l]);
            ++l;
        } while ((l & -l) != l);
        return _n;
    }

    int min_left(int r, auto f) const {
        if (r == 0) return 0;
        r += n;
        T sm = e();
        do {
            --r;
            while (r > 1 && (r & 1)) r >>= 1;
            if (!f(op(v[r], sm))) {
                while (r < n) {
                    r = ((r << 1) + 1);
                    if (f(op(v[r], sm))) {
                        sm = op(v[r], sm);
                        r--;
                    }
                }
                return r + 1 - n;
            }
            sm = op(v[r], sm);
        } while ((r & -r) != r);
        return 0;
    }
};
Back to top page