spacetimedb_sats/
algebraic_value_hash.rs

1//! Defines hash functions for `AlgebraicValue` and friends.
2
3use crate::{
4    bsatn::Deserializer,
5    buffer::{BufReader, DecodeError},
6    de::{Deserialize, Deserializer as _},
7    i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductType, ProductValue, SumType, F32, F64,
8};
9use bytemuck::{must_cast_slice, NoUninit};
10use core::hash::{Hash, Hasher};
11use core::{mem, slice};
12
13// We only manually implement those hash functions that cannot be `#[derive(Hash)]`ed.
14// Those that can be are:
15//
16// - `sum: SumValue`: The generated impl will first write the `sum.tag: u8`,
17//   and then proceed to write the `sum.value`, which delegates to our custom impl below.
18//   The tag is hashed so that e.g., `Result<u32, u32>` represented as an AV
19//   results in different hashes for `Ok(x)` and `Err(x)`.
20//
21// - `map: MapValue`: Uses the hash function for `BTreeMap<AV, AV>`,
22//   which is length prefixed and then writes each `(key, value)` sequentially.
23//   Eventually, this delegates to our custom impl below.
24//
25// - `str: Box<str>`: Uses the standard hash function for `str`.
26//
27// - Primitive types: Trivially what we want,
28//   except for `U256` and `I256` which hash like `[u/i128; 2]` do when outside arrays.
29
30/// The hash function for an [`AlgebraicValue`] only hashes its domain types
31/// and avoids length prefixing for product values.
32/// This avoids the hashing `Discriminant<AlgebraicValue>`
33/// which is OK as a table column will only ever have the same type (and so the same discriminant).
34impl Hash for AlgebraicValue {
35    fn hash<H: Hasher>(&self, state: &mut H) {
36        match self {
37            AlgebraicValue::Sum(x) => x.hash(state),
38            AlgebraicValue::Product(x) => x.hash(state),
39            AlgebraicValue::Array(x) => x.hash(state),
40            AlgebraicValue::Bool(x) => x.hash(state),
41            AlgebraicValue::I8(x) => x.hash(state),
42            AlgebraicValue::U8(x) => x.hash(state),
43            AlgebraicValue::I16(x) => x.hash(state),
44            AlgebraicValue::U16(x) => x.hash(state),
45            AlgebraicValue::I32(x) => x.hash(state),
46            AlgebraicValue::U32(x) => x.hash(state),
47            AlgebraicValue::I64(x) => x.hash(state),
48            AlgebraicValue::U64(x) => x.hash(state),
49            AlgebraicValue::I128(x) => x.hash(state),
50            AlgebraicValue::U128(x) => x.hash(state),
51            AlgebraicValue::I256(x) => x.hash(state),
52            AlgebraicValue::U256(x) => x.hash(state),
53            AlgebraicValue::F32(x) => x.hash(state),
54            AlgebraicValue::F64(x) => x.hash(state),
55            AlgebraicValue::String(s) => s.hash(state),
56            AlgebraicValue::Min | AlgebraicValue::Max => panic!("not defined for Min/Max"),
57        }
58    }
59}
60
61/// The hash function for `ProductValue` does *not* length prefix.
62impl Hash for ProductValue {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        for field in self.elements.iter() {
65            field.hash(state);
66        }
67    }
68}
69
70/// Hashes `slice` by converting to bytes first,
71/// as done in the standard library.
72fn hash_bytes_of(state: &mut impl Hasher, slice: &[impl NoUninit]) {
73    hash_len_and_bytes(state, slice.len(), must_cast_slice(slice))
74}
75
76/// Hashes `slice` by converting to bytes first,
77/// as done in the standard library.
78///
79/// SAFETY: The type `T` must have no padding.
80unsafe fn unchecked_hash_bytes_of<T>(state: &mut impl Hasher, slice: &[T]) {
81    let newlen = mem::size_of_val(slice);
82    let ptr = slice.as_ptr() as *const u8;
83    // SAFETY: `ptr` is valid and aligned, as `T` has no padding.
84    // The new slice only spans across `data` and is never mutated,
85    // and its total size is the same as the original `data` so it can't be over `isize::MAX`.
86    let bytes = unsafe { slice::from_raw_parts(ptr, newlen) };
87
88    hash_len_and_bytes(state, slice.len(), bytes)
89}
90
91/// The hash function for an [`ArrayValue`] only hashes its domain types.
92/// This avoids the hashing `Discriminant<ArrayValue>`
93/// which is OK as a table column will only ever have the same type (and so the same discriminant).
94/// The hash function will however length-prefix as the value is of variable length.
95impl Hash for ArrayValue {
96    fn hash<H: Hasher>(&self, state: &mut H) {
97        match self {
98            ArrayValue::Sum(es) => es.hash(state),
99            ArrayValue::Product(es) => es.hash(state),
100            ArrayValue::Bool(es) => es.hash(state),
101            ArrayValue::I8(es) => hash_bytes_of(state, es),
102            ArrayValue::U8(es) => hash_bytes_of(state, es),
103            ArrayValue::I16(es) => hash_bytes_of(state, es),
104            ArrayValue::U16(es) => hash_bytes_of(state, es),
105            ArrayValue::I32(es) => hash_bytes_of(state, es),
106            ArrayValue::U32(es) => hash_bytes_of(state, es),
107            ArrayValue::I64(es) => hash_bytes_of(state, es),
108            ArrayValue::U64(es) => hash_bytes_of(state, es),
109            ArrayValue::I128(es) => hash_bytes_of(state, es),
110            ArrayValue::U128(es) => hash_bytes_of(state, es),
111            // SAFETY: The following two types are `repr(transparent)`
112            // over `[u/i128; 2]` which have no padding.
113            ArrayValue::I256(es) => unsafe { unchecked_hash_bytes_of(state, es) },
114            ArrayValue::U256(es) => unsafe { unchecked_hash_bytes_of(state, es) },
115            ArrayValue::F32(es) => es.hash(state),
116            ArrayValue::F64(es) => es.hash(state),
117            ArrayValue::String(es) => es.hash(state),
118            ArrayValue::Array(es) => es.hash(state),
119        }
120    }
121}
122
123type HR = Result<(), DecodeError>;
124
125fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
126    match ty {
127        AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
128        AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de),
129        AlgebraicType::Product(ty) => hash_bsatn_prod(state, ty, de),
130        AlgebraicType::Array(ty) => hash_bsatn_array(state, &ty.elem_ty, de),
131        AlgebraicType::Bool => hash_bsatn_de::<bool>(state, de),
132        AlgebraicType::I8 => hash_bsatn_de::<i8>(state, de),
133        AlgebraicType::U8 => hash_bsatn_de::<u8>(state, de),
134        AlgebraicType::I16 => hash_bsatn_de::<i16>(state, de),
135        AlgebraicType::U16 => hash_bsatn_de::<u16>(state, de),
136        AlgebraicType::I32 => hash_bsatn_de::<i32>(state, de),
137        AlgebraicType::U32 => hash_bsatn_de::<u32>(state, de),
138        AlgebraicType::I64 => hash_bsatn_de::<i64>(state, de),
139        AlgebraicType::U64 => hash_bsatn_de::<u64>(state, de),
140        AlgebraicType::I128 => hash_bsatn_de::<i128>(state, de),
141        AlgebraicType::U128 => hash_bsatn_de::<u128>(state, de),
142        AlgebraicType::I256 => hash_bsatn_de::<i256>(state, de),
143        AlgebraicType::U256 => hash_bsatn_de::<u256>(state, de),
144        AlgebraicType::F32 => hash_bsatn_de::<F32>(state, de),
145        AlgebraicType::F64 => hash_bsatn_de::<F64>(state, de),
146        AlgebraicType::String => hash_bsatn_de::<&str>(state, de),
147    }
148}
149
150/// Hashes the tag and payload of the BSATN-encoded sum value.
151fn hash_bsatn_sum<'a>(state: &mut impl Hasher, ty: &SumType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
152    // Read + hash the tag.
153    let tag = de.reborrow().deserialize_u8()?;
154    tag.hash(state);
155
156    // Hash the payload.
157    let data_ty = &ty.variants[tag as usize].algebraic_type;
158    hash_bsatn(state, data_ty, de)
159}
160
161/// Hashes every field in the BSATN-encoded product value.
162fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
163    ty.elements
164        .iter()
165        .try_for_each(|f| hash_bsatn(state, &f.algebraic_type, de.reborrow()))
166}
167
168/// Hashes every elem in the BSATN-encoded array value.
169pub fn hash_bsatn_array<'a>(
170    state: &mut impl Hasher,
171    ty: &AlgebraicType,
172    de: Deserializer<'_, impl BufReader<'a>>,
173) -> HR {
174    // The BSATN is length-prefixed.
175    // `Hash for &[T]` also does length-prefixing.
176    match ty {
177        AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
178        AlgebraicType::Sum(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_sum(s, ty, d)),
179        AlgebraicType::Product(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_prod(s, ty, d)),
180        AlgebraicType::Array(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_array(s, &ty.elem_ty, d)),
181        AlgebraicType::Bool => hash_bsatn_seq(state, de, hash_bsatn_de::<bool>),
182        AlgebraicType::I8 | AlgebraicType::U8 => hash_bsatn_int_seq(state, de, 1),
183        AlgebraicType::I16 | AlgebraicType::U16 => hash_bsatn_int_seq(state, de, 2),
184        AlgebraicType::I32 | AlgebraicType::U32 => hash_bsatn_int_seq(state, de, 4),
185        AlgebraicType::I64 | AlgebraicType::U64 => hash_bsatn_int_seq(state, de, 8),
186        AlgebraicType::I128 | AlgebraicType::U128 => hash_bsatn_int_seq(state, de, 16),
187        AlgebraicType::I256 | AlgebraicType::U256 => hash_bsatn_int_seq(state, de, 32),
188        AlgebraicType::F32 => hash_bsatn_seq(state, de, hash_bsatn_de::<F32>),
189        AlgebraicType::F64 => hash_bsatn_seq(state, de, hash_bsatn_de::<F64>),
190        AlgebraicType::String => hash_bsatn_seq(state, de, hash_bsatn_de::<&str>),
191    }
192}
193
194/// Hashes elements in the BSATN-encoded element sequence.
195/// The sequence is prefixed with its length and the hash will as well.
196fn hash_bsatn_seq<'a, H: Hasher, R: BufReader<'a>>(
197    state: &mut H,
198    mut de: Deserializer<'_, R>,
199    mut elem_hash: impl FnMut(&mut H, Deserializer<'_, R>) -> Result<(), DecodeError>,
200) -> HR {
201    // The BSATN is length-prefixed.
202    // The Hash also needs to be length-prefixed.
203    let len = de.reborrow().deserialize_len()?;
204    state.write_usize(len);
205
206    // Hash each element.
207    (0..len).try_for_each(|_| elem_hash(state, de.reborrow()))
208}
209
210/// Hashes the BSATN-encoded integer sequence where each integer is `width` bytes wide.
211/// The sequence is prefixed with its length and the hash will as well.
212fn hash_bsatn_int_seq<'a, H: Hasher, R: BufReader<'a>>(state: &mut H, mut de: Deserializer<'_, R>, width: usize) -> HR {
213    // The BSATN is length-prefixed.
214    // The Hash also needs to be length-prefixed.
215    let len = de.reborrow().deserialize_len()?;
216
217    // Extract and hash the bytes.
218    // This is consistent with what `<$int_primitive>::hash_slice` will do
219    // and for `U/I256` we provide special logic in `impl Hash for ArrayValue` above
220    // and handle it the same way for `spacetimedb_table::table::RowRef`s.
221    let bytes = de.get_slice(len * width)?;
222
223    hash_len_and_bytes(state, len, bytes);
224    Ok(())
225}
226
/// Feeds `len` to `state` as a length prefix and then feeds it `bytes`.
///
/// `len` is written as given and is not derived from `bytes.len()`.
fn hash_len_and_bytes(state: &mut impl Hasher, len: usize, bytes: &[u8]) {
    Hasher::write_usize(state, len);
    Hasher::write(state, bytes);
}
232
233/// Deserializes from `de` an `x: T` and then proceeds to hash `x`.
234fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>(
235    state: &mut impl Hasher,
236    de: Deserializer<'_, impl BufReader<'a>>,
237) -> HR {
238    T::deserialize(de).map(|x| x.hash(state))
239}
240
#[cfg(test)]
mod tests {
    use super::hash_bsatn;
    use crate::{
        bsatn::{to_vec, Deserializer},
        proptest::generate_typed_value,
        AlgebraicType, AlgebraicValue,
    };
    use proptest::prelude::*;
    use std::hash::{BuildHasher, Hasher as _};

    /// Serializes `val` to BSATN and hashes that encoding as a value of type `ty`,
    /// using a fresh hasher built from `bh`.
    fn hash_one_bsatn_av(bh: &impl BuildHasher, ty: &AlgebraicType, val: &AlgebraicValue) -> u64 {
        let encoded = to_vec(&val).unwrap();
        let mut reader: &[u8] = &encoded;
        let mut hasher = bh.build_hasher();
        hash_bsatn(&mut hasher, ty, Deserializer::new(&mut reader)).unwrap();
        hasher.finish()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(2048))]
        // Hashing a value directly and hashing its BSATN encoding
        // must agree under the standard library's `RandomState`.
        #[test]
        fn av_bsatn_hash_same_std_random_state((ty, val) in generate_typed_value()) {
            let build_hasher = std::hash::RandomState::new();
            let direct = build_hasher.hash_one(&val);
            let via_bsatn = hash_one_bsatn_av(&build_hasher, &ty, &val);
            prop_assert_eq!(direct, via_bsatn);
        }

        // The same property, checked under `ahash`'s `RandomState`.
        #[test]
        fn av_bsatn_hash_same_ahash((ty, val) in generate_typed_value()) {
            let build_hasher = ahash::RandomState::new();
            let direct = build_hasher.hash_one(&val);
            let via_bsatn = hash_one_bsatn_av(&build_hasher, &ty, &val);
            prop_assert_eq!(direct, via_bsatn);
        }
    }
}