1#ifndef DICE_HASH_LTHASH_HPP
2#define DICE_HASH_LTHASH_HPP
14#if __has_include(<sodium.h>)
24#include "dice/hash/blake/Blake3.hpp"
25#include "dice/hash/lthash/MathEngine.hpp"
27namespace dice::hash::lthash {
30 template <
size_t n_bits_per_element>
35 static constexpr bool needs_padding =
false;
36 static constexpr size_t bits_per_element = 16;
43 static constexpr uint64_t data_mask = ~0xC000020000100000ULL;
44 static constexpr bool needs_padding =
true;
45 static constexpr size_t bits_per_element = 20;
50 static constexpr bool needs_padding =
false;
51 static constexpr size_t bits_per_element = 32;
54 template<
size_t min_key_extent,
size_t max_key_extent>
56 std::array<std::byte, max_key_extent> key_{};
59 [[nodiscard]]
constexpr std::span<std::byte const> get() const noexcept {
60 return {key_.data(), key_len_};
63 constexpr void clear() noexcept {
64 if (!std::is_constant_evaluated()) {
65 sodium_memzero(key_.data(), key_.size());
67 std::fill(key_.begin(), key_.end(), std::byte{0});
73 template<
size_t supplied_key_len>
74 requires (supplied_key_len == std::dynamic_extent || (supplied_key_len >= min_key_extent
75 && supplied_key_len <= max_key_extent))
76 constexpr void set_unchecked(std::span<std::byte const, supplied_key_len> new_key)
noexcept {
77 assert(new_key.size() >= min_key_extent && new_key.size() <= max_key_extent);
80 std::copy(new_key.begin(), new_key.end(), key_.begin());
81 key_len_ = new_key.size();
85 template<
size_t KeyExtent>
86 struct Key<KeyExtent, KeyExtent> {
87 std::array<std::byte, KeyExtent> key_{};
89 [[nodiscard]]
constexpr std::span<std::byte const, KeyExtent> get() const noexcept {
93 constexpr void clear() noexcept {
94 if (!std::is_constant_evaluated()) {
95 sodium_memzero(key_.data(), key_.size());
97 std::fill(key_.begin(), key_.end(), std::byte{0});
101 constexpr void set_unchecked(std::span<std::byte const, KeyExtent> new_key)
noexcept {
103 std::copy(new_key.begin(), new_key.end(), key_.begin());
114 template<
size_t n_bits_per_elem,
size_t n_elems,
template<
size_t>
typename HashT = blake3::Blake3,
template<
typename>
typename MathEngineT = DefaultMathEngine>
116 static_assert((n_bits_per_elem == 16 && n_elems % 32 == 0)
117 || (n_bits_per_elem == 20 && n_elems % 24 == 0)
118 || (n_bits_per_elem == 32 && n_elems % 16 == 0));
120 static_assert(MathEngine<MathEngineT, detail::Bits<n_bits_per_elem>>);
123 template<
size_t,
size_t,
template<
size_t>
typename,
template<
typename>
typename>
124 friend struct LtHash;
126 using Bits = detail::Bits<n_bits_per_elem>;
127 using MathEngine = MathEngineT<Bits>;
130 static constexpr bool needs_padding = Bits::needs_padding;
132 static constexpr size_t element_bits = n_bits_per_elem;
133 static constexpr size_t element_count = n_elems;
135 static constexpr size_t elements_per_uint64 = needs_padding ? (
sizeof(uint64_t) * 8) / (element_bits + 1)
136 : (sizeof(uint64_t) * 8) / element_bits;
138 static constexpr size_t checksum_len = (element_count / elements_per_uint64) *
sizeof(uint64_t);
139 static constexpr size_t checksum_align = MathEngine::min_buffer_align;
141 static constexpr std::array<std::byte, checksum_len> default_checksum{};
144 using Hash = HashT<checksum_len>;
146 detail::Key<Hash::min_key_extent, Hash::max_key_extent> key_;
147 alignas(checksum_align) std::array<std::byte, checksum_len> checksum_;
150 constexpr void set_checksum_unchecked(std::span<std::byte const, checksum_len> new_checksum)
noexcept {
151 std::copy(new_checksum.begin(), new_checksum.end(), checksum_.begin());
154 [[nodiscard]]
constexpr std::span<std::byte, checksum_len> checksum_mut() noexcept {
158 void hash_object(std::span<std::byte, checksum_len> out, std::span<std::byte const> obj)
const noexcept {
159 Hash::hash_single(obj, out, key_.get());
161 if constexpr (needs_padding) {
162 MathEngine::clear_padding_bits(out);
170 explicit constexpr LtHash(std::span<std::byte const, checksum_len> initial_checksum = default_checksum)
noexcept {
171 set_checksum_unchecked(initial_checksum);
174 constexpr LtHash(LtHash
const &other)
noexcept =
default;
176 template<
template<
typename>
typename MathEngineT2>
177 constexpr LtHash(LtHash<n_bits_per_elem, n_elems, HashT, MathEngineT2>
const &other) noexcept : key_{other.key_},
178 checksum_{other.checksum_} {
181 constexpr LtHash(LtHash &&other) noexcept : key_{other.key_},
182 checksum_{other.checksum_} {
187 template<
template<
typename>
typename MathEngineT2>
188 constexpr LtHash(LtHash<n_bits_per_elem, n_elems, HashT, MathEngineT2> &&other) noexcept : key_{other.key_},
189 checksum_{other.checksum_} {
193 constexpr LtHash &operator=(LtHash
const &other)
noexcept {
194 if (
this == &other) [[unlikely]] {
200 checksum_ = other.checksum_;
204 constexpr LtHash &operator=(LtHash &&other)
noexcept {
205 assert(
this != &other);
210 checksum_ = other.checksum_;
214 constexpr ~LtHash() noexcept {
222 [[nodiscard]]
constexpr bool key_equal(std::span<std::byte const> other_key)
const noexcept {
223 auto const this_key = key_.get();
224 return std::equal(this_key.begin(), this_key.end(), other_key.begin(), other_key.end());
231 [[nodiscard]]
constexpr bool key_equal(LtHash
const &other)
const noexcept {
232 return key_equal(other.key_.get());
239 template<
size_t supplied_key_len>
240 requires (supplied_key_len == std::dynamic_extent || (supplied_key_len >= Hash::min_key_extent
241 && supplied_key_len <= Hash::max_key_extent))
242 constexpr void set_key(std::span<std::byte const, supplied_key_len> key)
noexcept(supplied_key_len != std::dynamic_extent) {
243 if constexpr (supplied_key_len == std::dynamic_extent) {
244 if (key.size() < Hash::min_key_extent || key.size() > Hash::max_key_extent) [[unlikely]] {
245 throw std::invalid_argument{
"Invalid key size for Blake2Xb"};
249 key_.set_unchecked(key);
255 constexpr void clear_key() noexcept {
259 [[nodiscard]]
constexpr std::span<std::byte const, checksum_len> checksum() const noexcept {
267 [[nodiscard]]
constexpr bool checksum_equal(std::span<std::byte const, checksum_len> other_checksum)
const noexcept {
268 return std::equal(checksum_.begin(), checksum_.end(), other_checksum.begin());
275 [[nodiscard]]
constexpr bool checksum_equal(LtHash
const &other)
const noexcept {
276 return checksum_equal(other.checksum());
283 [[nodiscard]]
bool checksum_equal_constant_time(std::span<std::byte const, checksum_len> other_checksum)
const noexcept {
284 return sodium_memcmp(checksum_, other_checksum.data(), checksum_len) == 0;
291 [[nodiscard]]
bool checksum_equal_constant_time(LtHash
const &other)
const noexcept {
292 return checksum_equal_constant_time(other.checksum());
299 constexpr void set_checksum(std::span<std::byte const, checksum_len> new_checksum)
noexcept(!needs_padding) {
300 set_checksum_unchecked(new_checksum);
301 if constexpr (needs_padding) {
302 if (!MathEngine::check_padding_bits(checksum())) [[unlikely]] {
303 throw std::invalid_argument{
"Invalid checksum: found non-zero padding bits"};
311 constexpr void clear_checksum() noexcept {
312 std::fill(checksum_.begin(), checksum_.end(), std::byte{0});
321 LtHash &combine_add(LtHash
const &other) {
322 if (!key_equal(other)) [[unlikely]] {
323 throw std::invalid_argument{
"Cannot combine hashes with different keys"};
326 MathEngine::add(checksum_mut(), other.checksum());
336 LtHash &combine_remove(LtHash
const &other) {
337 if (!key_equal(other)) [[unlikely]] {
338 throw std::invalid_argument{
"Cannot combine hashes with different keys"};
341 MathEngine::sub(checksum_mut(), other.checksum());
350 LtHash &add(std::span<std::byte const> obj)
noexcept {
351 alignas(MathEngine::min_buffer_align) std::array<std::byte, checksum_len> obj_hash;
352 hash_object(obj_hash, obj);
353 MathEngine::add(checksum_mut(), std::span<std::byte const, checksum_len>{obj_hash});
362 LtHash &remove(std::span<std::byte const> obj)
noexcept {
363 alignas(MathEngine::min_buffer_align) std::array<std::byte, checksum_len> obj_hash;
364 hash_object(obj_hash, obj);
365 MathEngine::sub(checksum_mut(), std::span<std::byte const, checksum_len>{obj_hash});
373 constexpr bool operator==(LtHash
const &other)
const noexcept {
374 return checksum_equal(other);
381 constexpr bool operator!=(LtHash
const &other)
const noexcept {
382 return !LtHash::operator==(other);
386 using LtHash16 = LtHash<16, 1024>;
387 using LtHash20 = LtHash<20, 1008>;
388 using LtHash32 = LtHash<32, 1024>;
393#error "Cannot include LtHash.hpp if sodium is not available"