Skip to the content.

:heavy_check_mark: Li Chao Tree (data_structure/li_chao_tree.hpp)

Li Chao Tree は、区間内の直線群に対する最小値(または最大値)クエリを高速に処理するためのデータ構造です。この実装は、コンストラクタで与えられた離散的な座標集合に対して動作する静的なバージョンです。座標圧縮を内部的に行います。

テンプレートパラメータ:

使い方

コンストラクタ

LiChaoTree(const std::vector<T>& vec)

説明

管理する座標集合 vec を元に Li Chao Tree を構築します。vec は内部でソートされ、重複が排除されて座標圧縮に利用されます。

計算量

$O(M \log M)$ (ただし $M$ は vec のサイズ)

add (直線追加)

void add(T a, T b, T L, T R)

説明

元の座標系における区間 [L, R) に直線 $y = ax + b$ を追加します。内部的には、コンストラクタで与えられた座標集合のうち、区間 [L, R) に含まれる座標に対応する区間に直線を追加します。

制約

L < R である必要があります。また、L および R はコンストラクタで与えられた座標集合の範囲内にあることが想定されますが、実装はそれより広い範囲も受け入れます(ただし、実際に直線が追加されるのは座標集合内の点に対応する区間のみです)。

計算量

$O(\log^2 N)$ (ただし $N$ は座標圧縮後の座標数)

operator() (クエリ)

T operator()(T x)

説明

元の座標系における座標 x における、追加された全ての直線の最小値(または最大値)を返します。座標 x はコンストラクタで与えられた座標集合に含まれている必要があります。

制約

座標 x はコンストラクタで与えられた座標集合に含まれている必要があります。

計算量

$O(\log N)$ (ただし $N$ は座標圧縮後の座標数)

Verified with

Code

#pragma once
#include <stack>
#include <vector>

template <typename T, auto e>
struct LiChaoTree {
    int n;
    int siz;
    std::vector<T> cmp;
    struct node {
        T a, b;
        node(T _a, T _b) : a(_a), b(_b) {}
        node() {}
        T get(T x) {
            return x * a + b;
        }
    };
    std::vector<node> v;
    LiChaoTree(const std::vector<T>& vec) : cmp(vec) {
        sort(cmp.begin(), cmp.end());
        cmp.erase(unique(cmp.begin(), cmp.end()), cmp.end());
        n = cmp.size();
        siz = (1 << (std::__lg(std::max(1, n)) + 1)) << 1;
        v = std::vector<node>(siz, node{0, e()});
    }
    void add(T a, T b, int c = 1, T l = 0, T r = -1) {
        if (r == -1) {
            r = n;
        }
        while (c < siz) {
            if (r <= l + 1) break;
            if (b == e()) return;
            node& cur = v[c];
            T mid = (l + r) >> 1;
            T L = a * cmp[l] + b;
            T M = a * cmp[mid] + b;
            T R = a * cmp[(r - 1)] + b;
            T cL = cur.get(cmp[l]);
            T cM = cur.get(cmp[mid]);
            T cR = cur.get(cmp[r - 1]);
            if (cL <= L && cR <= R) {
                return;
            }
            if (L <= cL && R <= cR) {
                std::swap(cur.a, a);
                std::swap(cur.b, b);
                return;
            }
            if (L <= cL) {
                if (M <= cM) {
                    std::swap(cur.a, a);
                    std::swap(cur.b, b);
                    c = (c << 1) | 1;
                    l = mid;
                } else {
                    c = (c << 1);
                    r = mid;
                }
            } else {
                if (M <= cM) {
                    std::swap(cur.a, a);
                    std::swap(cur.b, b);
                    c = (c << 1);
                    r = mid;
                } else {
                    c = (c << 1) | 1;
                    l = mid;
                }
            }
        }
        if (c < siz) {
            if (a * cmp[l] + b < v[c].get(cmp[l])) {
                std::swap(a, v[c].a);
                std::swap(b, v[c].b);
            }
        }
    }

    void add(T a, T b, T L, T R) {
        std::stack<std::tuple<int, T, T>> st;
        st.emplace(1, 0, n);
        while (!st.empty()) {
            auto [c, l, r] = st.top();
            st.pop();
            if (siz <= c) continue;
            if (cmp[r - 1] < L || R <= cmp[l]) continue;
            if (L <= cmp[l] && cmp[r - 1] < R) {
                add(a, b, c, l, r);
                continue;
            }
            T mid = (l + r) >> 1;
            st.emplace((c << 1), l, mid);
            st.emplace((c << 1) | 1, mid, r);
        }
    }

    T operator()(T x) {
        int c = 1;
        T l = 0, r = n;
        T res{e()};
        while (c < siz) {
            res = std::min(res, v[c].get(x));
            T mid = (l + r) >> 1;
            if (x < cmp[mid]) {
                r = mid;
                c = (c << 1);
            } else {
                l = mid;
                c = (c << 1) | 1;
            }
        }
        return res;
    }
};
#line 2 "data_structure/li_chao_tree.hpp"
#include <stack>
#include <vector>

template <typename T, auto e>
struct LiChaoTree {
    int n;
    int siz;
    std::vector<T> cmp;
    struct node {
        T a, b;
        node(T _a, T _b) : a(_a), b(_b) {}
        node() {}
        T get(T x) {
            return x * a + b;
        }
    };
    std::vector<node> v;
    LiChaoTree(const std::vector<T>& vec) : cmp(vec) {
        sort(cmp.begin(), cmp.end());
        cmp.erase(unique(cmp.begin(), cmp.end()), cmp.end());
        n = cmp.size();
        siz = (1 << (std::__lg(std::max(1, n)) + 1)) << 1;
        v = std::vector<node>(siz, node{0, e()});
    }
    void add(T a, T b, int c = 1, T l = 0, T r = -1) {
        if (r == -1) {
            r = n;
        }
        while (c < siz) {
            if (r <= l + 1) break;
            if (b == e()) return;
            node& cur = v[c];
            T mid = (l + r) >> 1;
            T L = a * cmp[l] + b;
            T M = a * cmp[mid] + b;
            T R = a * cmp[(r - 1)] + b;
            T cL = cur.get(cmp[l]);
            T cM = cur.get(cmp[mid]);
            T cR = cur.get(cmp[r - 1]);
            if (cL <= L && cR <= R) {
                return;
            }
            if (L <= cL && R <= cR) {
                std::swap(cur.a, a);
                std::swap(cur.b, b);
                return;
            }
            if (L <= cL) {
                if (M <= cM) {
                    std::swap(cur.a, a);
                    std::swap(cur.b, b);
                    c = (c << 1) | 1;
                    l = mid;
                } else {
                    c = (c << 1);
                    r = mid;
                }
            } else {
                if (M <= cM) {
                    std::swap(cur.a, a);
                    std::swap(cur.b, b);
                    c = (c << 1);
                    r = mid;
                } else {
                    c = (c << 1) | 1;
                    l = mid;
                }
            }
        }
        if (c < siz) {
            if (a * cmp[l] + b < v[c].get(cmp[l])) {
                std::swap(a, v[c].a);
                std::swap(b, v[c].b);
            }
        }
    }

    void add(T a, T b, T L, T R) {
        std::stack<std::tuple<int, T, T>> st;
        st.emplace(1, 0, n);
        while (!st.empty()) {
            auto [c, l, r] = st.top();
            st.pop();
            if (siz <= c) continue;
            if (cmp[r - 1] < L || R <= cmp[l]) continue;
            if (L <= cmp[l] && cmp[r - 1] < R) {
                add(a, b, c, l, r);
                continue;
            }
            T mid = (l + r) >> 1;
            st.emplace((c << 1), l, mid);
            st.emplace((c << 1) | 1, mid, r);
        }
    }

    T operator()(T x) {
        int c = 1;
        T l = 0, r = n;
        T res{e()};
        while (c < siz) {
            res = std::min(res, v[c].get(x));
            T mid = (l + r) >> 1;
            if (x < cmp[mid]) {
                r = mid;
                c = (c << 1);
            } else {
                l = mid;
                c = (c << 1) | 1;
            }
        }
        return res;
    }
};
Back to top page