Segment Tree 2D (data_structure/segment_tree_2d.hpp)
- View this file on GitHub
- View document part on GitHub
- Last update: 2025-06-22 11:04:54+09:00
- Include:
#include "data_structure/segment_tree_2d.hpp"
2次元平面上の点に対する要素の更新や矩形領域に対する集約された値の取得を高速に行うことができるデータ構造です。座標圧縮を内部で行います。
テンプレートパラメータ:
-
T
: 要素の型。 -
op
: 二項演算 (モノイド)。T op(T, T)
の形式である必要があります。 -
e
: モノイドの単位元を返す関数。T e()
の形式である必要があります。 -
S
: 座標の型 (デフォルトはint
)。
使い方
コンストラクタ
SegmentTree2D(std::vector<Point<S>> p)
説明
与えられた点の集合 p
を元に2次元セグメントツリーを構築します。点のx座標は自動的に座標圧縮されます。初期値は全て単位元 e()
です。
-
p
: 初期値となる点の集合。各点には座標x
,y
が含まれます。
計算量
$O(N \log N)$ (座標圧縮と構築にかかる時間)
set
void set(Point<S> p, T v)
説明
点 p
の要素の値を v
に更新します。
-
p
: 更新する点の座標。 -
v
: 更新後の値。
計算量
$O(\log^2 N)$
void set(S x, S y, T v)
説明
座標 (x, y)
の要素の値を v
に更新します。
-
x
: 更新する点のx座標。 -
y
: 更新する点のy座標。 -
v
: 更新後の値。
計算量
$O(\log^2 N)$
add
void add(Point<S> p, T v)
説明
点 p
の要素に v
を加算します。これは特に加算モノイドの場合に便利ですが、一般のモノイドに対しても v[p] = op(v[p], x)
のように動作します。
-
p
: 加算する点の座標。 -
v
: 加算する値。
計算量
$O(\log^2 N)$
void add(S x, S y, T v)
説明
座標 (x, y)
の要素に v
を加算します。
-
x
: 加算する点のx座標。 -
y
: 加算する点のy座標。 -
v
: 加算する値。
計算量
$O(\log^2 N)$
get
T get(S x, S y)
説明
座標 (x, y)
の要素の値を返します。
-
x
: 取得する点のx座標。 -
y
: 取得する点のy座標。 - 戻り値: 座標
(x, y)
の要素の値。
計算量
$O(\log^2 N)$
operator()
T operator()(S l, S r, S t, S b)
説明
矩形領域 [l, r)
(x座標の半開区間) $\times$ [t, b)
(y座標の半開区間) 内の要素に対する二項演算 op
の結果を返します。
-
l
: 矩形領域の左端x座標 (含む)。 -
r
: 矩形領域の右端x座標 (含まない)。 -
t
: 矩形領域の上端y座標 (含む)。 -
b
: 矩形領域の下端y座標 (含まない)。 - 戻り値: 矩形領域内の要素に対する
op
の結果。
計算量
$O(\log^2 N)$
Depends on
Verified with
Code
#pragma once
#include <algorithm>
#include <data_structure/segment_tree.hpp>
#include <math/point.hpp>
template <typename T, auto op, auto e, typename S = int>
struct SegmentTree2D {
int n;
std::vector<SegmentTree<T, op, e>> seg;
std::vector<S> X;
std::vector<std::vector<S>> pos;
explicit SegmentTree2D(std::vector<Point<S>> p) {
for (auto [_x, _] : p) {
X.emplace_back(_x);
}
std::sort(X.begin(), X.end());
X.erase(std::unique(X.begin(), X.end()), X.end());
n = (1 << (std::__lg(std::max(1, static_cast<int>(X.size()))) + 1)) << 1;
pos.assign(n * 2, {});
for (auto [_x, _y] : p) {
_x = std::distance(X.begin(), std::lower_bound(X.begin(), X.end(), _x));
_x += n;
for (; _x; _x >>= 1) {
pos[_x].emplace_back(_y);
}
}
for (int i = n * 2 - 1; i; --i) {
std::sort(pos[i].begin(), pos[i].end());
pos[i].erase(unique(pos[i].begin(), pos[i].end()), pos[i].end());
}
seg.emplace_back(SegmentTree<T, op, e>(0));
for (int i = 1; i < n * 2; ++i) {
seg.emplace_back(SegmentTree<T, op, e>(pos[i].size()));
}
}
void set(Point<S> p, T v) {
set(p.x, p.y, v);
}
void set(S x, S y, T v) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
{
int p = distance(pos[x].begin(), lower_bound(pos[x].begin(), pos[x].end(), y));
seg[x].set(p, v);
}
while (x >>= 1) {
int p = std::distance(pos[x].begin(), std::lower_bound(pos[x].begin(), pos[x].end(), y));
auto left = lower_bound(pos[x << 1].begin(), pos[x << 1].end(), y);
auto right = lower_bound(pos[(x << 1) + 1].begin(), pos[(x << 1) + 1].end(), y);
if (left == pos[x << 1].end() || *left != y) {
seg[x].set(p, seg[(x << 1) + 1][distance(pos[(x << 1) + 1].begin(), right)]);
} else if (right == pos[(x << 1) + 1].end() || *right != y) {
seg[x].set(p, seg[x << 1][distance(pos[x << 1].begin(), left)]);
} else {
seg[x].set(p, op(
seg[x << 1][distance(pos[x << 1].begin(), left)],
seg[(x << 1) + 1][distance(pos[(x << 1) + 1].begin(), right)]));
}
}
}
void add(Point<S> p, T v) {
add(p.x, p.y, v);
}
void add(S x, S y, T v) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
for (; x; x >>= 1) {
int p = std::distance(pos[x].begin(), std::lower_bound(pos[x].begin(), pos[x].end(), y));
seg[x].add(p, v);
}
}
T get(S x, S y) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
int p = distance(pos[x].begin(), lower_bound(pos[x].begin(), pos[x].end(), y));
return seg[x][p];
}
T get(Point<S> p) {
return get(p.x, p.y);
}
T operator[](Point<S> p) {
return get(p);
}
T operator()(S l, S r, S t, S b) {
l = distance(X.begin(), lower_bound(X.begin(), X.end(), l));
r = distance(X.begin(), lower_bound(X.begin(), X.end(), r));
T left = e(), right = e();
for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
if (l & 1) {
int L = distance(pos[l].begin(), lower_bound(pos[l].begin(), pos[l].end(), t));
int R = distance(pos[l].begin(), lower_bound(pos[l].begin(), pos[l].end(), b));
left = op(left, seg[l](L, R));
++l;
}
if (r & 1) {
--r;
int L = distance(pos[r].begin(), lower_bound(pos[r].begin(), pos[r].end(), t));
int R = distance(pos[r].begin(), lower_bound(pos[r].begin(), pos[r].end(), b));
right = op(seg[r](L, R), right);
}
}
return op(left, right);
}
};
#line 2 "data_structure/segment_tree_2d.hpp"
#include <algorithm>
#include <data_structure/segment_tree.hpp>
#include <math/point.hpp>
template <typename T, auto op, auto e, typename S = int>
struct SegmentTree2D {
int n;
std::vector<SegmentTree<T, op, e>> seg;
std::vector<S> X;
std::vector<std::vector<S>> pos;
explicit SegmentTree2D(std::vector<Point<S>> p) {
for (auto [_x, _] : p) {
X.emplace_back(_x);
}
std::sort(X.begin(), X.end());
X.erase(std::unique(X.begin(), X.end()), X.end());
n = (1 << (std::__lg(std::max(1, static_cast<int>(X.size()))) + 1)) << 1;
pos.assign(n * 2, {});
for (auto [_x, _y] : p) {
_x = std::distance(X.begin(), std::lower_bound(X.begin(), X.end(), _x));
_x += n;
for (; _x; _x >>= 1) {
pos[_x].emplace_back(_y);
}
}
for (int i = n * 2 - 1; i; --i) {
std::sort(pos[i].begin(), pos[i].end());
pos[i].erase(unique(pos[i].begin(), pos[i].end()), pos[i].end());
}
seg.emplace_back(SegmentTree<T, op, e>(0));
for (int i = 1; i < n * 2; ++i) {
seg.emplace_back(SegmentTree<T, op, e>(pos[i].size()));
}
}
void set(Point<S> p, T v) {
set(p.x, p.y, v);
}
void set(S x, S y, T v) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
{
int p = distance(pos[x].begin(), lower_bound(pos[x].begin(), pos[x].end(), y));
seg[x].set(p, v);
}
while (x >>= 1) {
int p = std::distance(pos[x].begin(), std::lower_bound(pos[x].begin(), pos[x].end(), y));
auto left = lower_bound(pos[x << 1].begin(), pos[x << 1].end(), y);
auto right = lower_bound(pos[(x << 1) + 1].begin(), pos[(x << 1) + 1].end(), y);
if (left == pos[x << 1].end() || *left != y) {
seg[x].set(p, seg[(x << 1) + 1][distance(pos[(x << 1) + 1].begin(), right)]);
} else if (right == pos[(x << 1) + 1].end() || *right != y) {
seg[x].set(p, seg[x << 1][distance(pos[x << 1].begin(), left)]);
} else {
seg[x].set(p, op(
seg[x << 1][distance(pos[x << 1].begin(), left)],
seg[(x << 1) + 1][distance(pos[(x << 1) + 1].begin(), right)]));
}
}
}
void add(Point<S> p, T v) {
add(p.x, p.y, v);
}
void add(S x, S y, T v) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
for (; x; x >>= 1) {
int p = std::distance(pos[x].begin(), std::lower_bound(pos[x].begin(), pos[x].end(), y));
seg[x].add(p, v);
}
}
T get(S x, S y) {
x = distance(X.begin(), lower_bound(X.begin(), X.end(), x));
x += n;
int p = distance(pos[x].begin(), lower_bound(pos[x].begin(), pos[x].end(), y));
return seg[x][p];
}
T get(Point<S> p) {
return get(p.x, p.y);
}
T operator[](Point<S> p) {
return get(p);
}
T operator()(S l, S r, S t, S b) {
l = distance(X.begin(), lower_bound(X.begin(), X.end(), l));
r = distance(X.begin(), lower_bound(X.begin(), X.end(), r));
T left = e(), right = e();
for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
if (l & 1) {
int L = distance(pos[l].begin(), lower_bound(pos[l].begin(), pos[l].end(), t));
int R = distance(pos[l].begin(), lower_bound(pos[l].begin(), pos[l].end(), b));
left = op(left, seg[l](L, R));
++l;
}
if (r & 1) {
--r;
int L = distance(pos[r].begin(), lower_bound(pos[r].begin(), pos[r].end(), t));
int R = distance(pos[r].begin(), lower_bound(pos[r].begin(), pos[r].end(), b));
right = op(seg[r](L, R), right);
}
}
return op(left, right);
}
};