Skip to the content.

:heavy_check_mark: data_structure/wavelet_matrix.hpp

Required by

Verified with

Code

#pragma once
#include <immintrin.h>

#include <algorithm>
#include <cstdint>
#include <vector>

// Bit vector with rank support.
//
// Bits are packed 64 per word in `block`; `count[k]` caches the number of set
// bits in the words before index k. Intended usage: construct (or init()),
// set() the one-bits, call build() once, then query rank0()/rank1().
struct BitVector {
    static constexpr uint32_t w = 64;  // bits per storage word
    std::vector<uint64_t> block;       // packed bits, n / w + 1 words
    std::vector<uint32_t> count;       // count[k] = set bits in block[0..k)
    uint64_t n = 0;                    // number of addressable bits
    uint64_t zeros = 0;                // number of zero bits, refreshed by build()

    // Returns bit i as 0 or 1. No bounds checking is performed.
    inline uint64_t get(uint64_t i) const { return uint64_t(block[i / w] >> (i % w)) & 1u; }

    // Sets bit i to 1. The mask is built from an unsigned 64-bit one so that
    // setting bit 63 of a word never shifts into a signed sign bit.
    inline void set(uint64_t i) { block[i / w] |= uint64_t{1} << (i % w); }

    BitVector() = default;
    BitVector(int _n) { init(_n); }

    // (Re)initialises the vector to _n zero bits. _n must be non-negative.
    // assign() is used instead of resize() so that a second init() on an
    // already populated object really clears every word and prefix count.
    void init(int _n) {
        n = zeros = _n;
        block.assign(n / w + 1, 0);
        count.assign(block.size(), 0);
    }

    // Fills the per-word prefix popcounts and refreshes `zeros`.
    // Must be called after the last set() and before any rank query.
    __attribute__((target("popcnt"))) void build() {
        for (uint32_t i = 1; i < block.size(); ++i) {
            count[i] = count[i - 1] + _mm_popcnt_u64(block[i - 1]);
        }
        // rank queries take 32-bit positions; make the narrowing explicit.
        zeros = rank0(static_cast<uint32_t>(n));
    }

    // Number of zero bits in positions [0, i).
    inline uint32_t rank0(uint32_t i) const { return i - rank1(i); }

    // Number of one bits in positions [0, i): the cached count of all full
    // words before i plus the popcount of the low (i % w) bits of its word.
    __attribute__((target("bmi2,popcnt"))) inline uint32_t rank1(uint32_t i) const {
        return count[i / w] + _mm_popcnt_u64(_bzhi_u64(block[i / w], i % w));
    }
};

template <typename T = int>
struct WaveletMatrix {
    std::vector<T> v;
    int n, m;
    std::vector<BitVector> b;
    WaveletMatrix(const std::vector<T>& _v) : v(_v), n(v.size()), m(std::__lg(std::max(1, *std::max_element(v.begin(), v.end()))) + 1), b(m, BitVector(n)) {
        std::vector<T> nxt(n);
        for (int i = m - 1; 0 <= i; --i) {
            for (int j = 0; j < n; ++j) {
                if ((v[j] >> i) & 1) {
                    b[m - i - 1].set(j);
                }
            }
            b[m - i - 1].build();
            int cnt[2] = {-1, b[m - i - 1].rank0(n) - 1};
            for (int j = 0; j < n; ++j) {
                nxt[++cnt[(v[j] >> i) & 1]] = v[j];
            }
            std::swap(v, nxt);
        }
    }

    T operator[](int p) {
        for (int i = 0; i < m; ++i) {
            if (b[i].get(p)) {
                p = b[i].zeros + b[i].rank1(p);
            } else {
                p = b[i].rank0(p);
            }
        }
        return v[p];
    }

    T small(int l, int r, int p);
};
#line 2 "data_structure/wavelet_matrix.hpp"
#include <immintrin.h>

#include <algorithm>
#include <cstdint>
#include <vector>

// Bit vector with rank support.
//
// Bits are packed 64 per word in `block`; `count[k]` caches the number of set
// bits in the words before index k. Intended usage: construct (or init()),
// set() the one-bits, call build() once, then query rank0()/rank1().
struct BitVector {
    static constexpr uint32_t w = 64;  // bits per storage word
    std::vector<uint64_t> block;       // packed bits, n / w + 1 words
    std::vector<uint32_t> count;       // count[k] = set bits in block[0..k)
    uint64_t n = 0;                    // number of addressable bits
    uint64_t zeros = 0;                // number of zero bits, refreshed by build()

    // Returns bit i as 0 or 1. No bounds checking is performed.
    inline uint64_t get(uint64_t i) const { return uint64_t(block[i / w] >> (i % w)) & 1u; }

    // Sets bit i to 1. The mask is built from an unsigned 64-bit one so that
    // setting bit 63 of a word never shifts into a signed sign bit.
    inline void set(uint64_t i) { block[i / w] |= uint64_t{1} << (i % w); }

    BitVector() = default;
    BitVector(int _n) { init(_n); }

    // (Re)initialises the vector to _n zero bits. _n must be non-negative.
    // assign() is used instead of resize() so that a second init() on an
    // already populated object really clears every word and prefix count.
    void init(int _n) {
        n = zeros = _n;
        block.assign(n / w + 1, 0);
        count.assign(block.size(), 0);
    }

    // Fills the per-word prefix popcounts and refreshes `zeros`.
    // Must be called after the last set() and before any rank query.
    __attribute__((target("popcnt"))) void build() {
        for (uint32_t i = 1; i < block.size(); ++i) {
            count[i] = count[i - 1] + _mm_popcnt_u64(block[i - 1]);
        }
        // rank queries take 32-bit positions; make the narrowing explicit.
        zeros = rank0(static_cast<uint32_t>(n));
    }

    // Number of zero bits in positions [0, i).
    inline uint32_t rank0(uint32_t i) const { return i - rank1(i); }

    // Number of one bits in positions [0, i): the cached count of all full
    // words before i plus the popcount of the low (i % w) bits of its word.
    __attribute__((target("bmi2,popcnt"))) inline uint32_t rank1(uint32_t i) const {
        return count[i / w] + _mm_popcnt_u64(_bzhi_u64(block[i / w], i % w));
    }
};

template <typename T = int>
struct WaveletMatrix {
    std::vector<T> v;
    int n, m;
    std::vector<BitVector> b;
    WaveletMatrix(const std::vector<T>& _v) : v(_v), n(v.size()), m(std::__lg(std::max(1, *std::max_element(v.begin(), v.end()))) + 1), b(m, BitVector(n)) {
        std::vector<T> nxt(n);
        for (int i = m - 1; 0 <= i; --i) {
            for (int j = 0; j < n; ++j) {
                if ((v[j] >> i) & 1) {
                    b[m - i - 1].set(j);
                }
            }
            b[m - i - 1].build();
            int cnt[2] = {-1, b[m - i - 1].rank0(n) - 1};
            for (int j = 0; j < n; ++j) {
                nxt[++cnt[(v[j] >> i) & 1]] = v[j];
            }
            std::swap(v, nxt);
        }
    }

    T operator[](int p) {
        for (int i = 0; i < m; ++i) {
            if (b[i].get(p)) {
                p = b[i].zeros + b[i].rank1(p);
            } else {
                p = b[i].rank0(p);
            }
        }
        return v[p];
    }

    T small(int l, int r, int p);
};
Back to top page