1use crate::{
4 bsatn::Deserializer,
5 buffer::{BufReader, DecodeError},
6 de::{Deserialize, Deserializer as _},
7 i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductType, ProductValue, SumType, F32, F64,
8};
9use bytemuck::{must_cast_slice, NoUninit};
10use core::hash::{Hash, Hasher};
11use core::{mem, slice};
12
13impl Hash for AlgebraicValue {
35 fn hash<H: Hasher>(&self, state: &mut H) {
36 match self {
37 AlgebraicValue::Sum(x) => x.hash(state),
38 AlgebraicValue::Product(x) => x.hash(state),
39 AlgebraicValue::Array(x) => x.hash(state),
40 AlgebraicValue::Bool(x) => x.hash(state),
41 AlgebraicValue::I8(x) => x.hash(state),
42 AlgebraicValue::U8(x) => x.hash(state),
43 AlgebraicValue::I16(x) => x.hash(state),
44 AlgebraicValue::U16(x) => x.hash(state),
45 AlgebraicValue::I32(x) => x.hash(state),
46 AlgebraicValue::U32(x) => x.hash(state),
47 AlgebraicValue::I64(x) => x.hash(state),
48 AlgebraicValue::U64(x) => x.hash(state),
49 AlgebraicValue::I128(x) => x.hash(state),
50 AlgebraicValue::U128(x) => x.hash(state),
51 AlgebraicValue::I256(x) => x.hash(state),
52 AlgebraicValue::U256(x) => x.hash(state),
53 AlgebraicValue::F32(x) => x.hash(state),
54 AlgebraicValue::F64(x) => x.hash(state),
55 AlgebraicValue::String(s) => s.hash(state),
56 AlgebraicValue::Min | AlgebraicValue::Max => panic!("not defined for Min/Max"),
57 }
58 }
59}
60
61impl Hash for ProductValue {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 for field in self.elements.iter() {
65 field.hash(state);
66 }
67 }
68}
69
70fn hash_bytes_of(state: &mut impl Hasher, slice: &[impl NoUninit]) {
73 hash_len_and_bytes(state, slice.len(), must_cast_slice(slice))
74}
75
76unsafe fn unchecked_hash_bytes_of<T>(state: &mut impl Hasher, slice: &[T]) {
81 let newlen = mem::size_of_val(slice);
82 let ptr = slice.as_ptr() as *const u8;
83 let bytes = unsafe { slice::from_raw_parts(ptr, newlen) };
87
88 hash_len_and_bytes(state, slice.len(), bytes)
89}
90
91impl Hash for ArrayValue {
96 fn hash<H: Hasher>(&self, state: &mut H) {
97 match self {
98 ArrayValue::Sum(es) => es.hash(state),
99 ArrayValue::Product(es) => es.hash(state),
100 ArrayValue::Bool(es) => es.hash(state),
101 ArrayValue::I8(es) => hash_bytes_of(state, es),
102 ArrayValue::U8(es) => hash_bytes_of(state, es),
103 ArrayValue::I16(es) => hash_bytes_of(state, es),
104 ArrayValue::U16(es) => hash_bytes_of(state, es),
105 ArrayValue::I32(es) => hash_bytes_of(state, es),
106 ArrayValue::U32(es) => hash_bytes_of(state, es),
107 ArrayValue::I64(es) => hash_bytes_of(state, es),
108 ArrayValue::U64(es) => hash_bytes_of(state, es),
109 ArrayValue::I128(es) => hash_bytes_of(state, es),
110 ArrayValue::U128(es) => hash_bytes_of(state, es),
111 ArrayValue::I256(es) => unsafe { unchecked_hash_bytes_of(state, es) },
114 ArrayValue::U256(es) => unsafe { unchecked_hash_bytes_of(state, es) },
115 ArrayValue::F32(es) => es.hash(state),
116 ArrayValue::F64(es) => es.hash(state),
117 ArrayValue::String(es) => es.hash(state),
118 ArrayValue::Array(es) => es.hash(state),
119 }
120 }
121}
122
123type HR = Result<(), DecodeError>;
124
125fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
126 match ty {
127 AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
128 AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de),
129 AlgebraicType::Product(ty) => hash_bsatn_prod(state, ty, de),
130 AlgebraicType::Array(ty) => hash_bsatn_array(state, &ty.elem_ty, de),
131 AlgebraicType::Bool => hash_bsatn_de::<bool>(state, de),
132 AlgebraicType::I8 => hash_bsatn_de::<i8>(state, de),
133 AlgebraicType::U8 => hash_bsatn_de::<u8>(state, de),
134 AlgebraicType::I16 => hash_bsatn_de::<i16>(state, de),
135 AlgebraicType::U16 => hash_bsatn_de::<u16>(state, de),
136 AlgebraicType::I32 => hash_bsatn_de::<i32>(state, de),
137 AlgebraicType::U32 => hash_bsatn_de::<u32>(state, de),
138 AlgebraicType::I64 => hash_bsatn_de::<i64>(state, de),
139 AlgebraicType::U64 => hash_bsatn_de::<u64>(state, de),
140 AlgebraicType::I128 => hash_bsatn_de::<i128>(state, de),
141 AlgebraicType::U128 => hash_bsatn_de::<u128>(state, de),
142 AlgebraicType::I256 => hash_bsatn_de::<i256>(state, de),
143 AlgebraicType::U256 => hash_bsatn_de::<u256>(state, de),
144 AlgebraicType::F32 => hash_bsatn_de::<F32>(state, de),
145 AlgebraicType::F64 => hash_bsatn_de::<F64>(state, de),
146 AlgebraicType::String => hash_bsatn_de::<&str>(state, de),
147 }
148}
149
150fn hash_bsatn_sum<'a>(state: &mut impl Hasher, ty: &SumType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
152 let tag = de.reborrow().deserialize_u8()?;
154 tag.hash(state);
155
156 let data_ty = &ty.variants[tag as usize].algebraic_type;
158 hash_bsatn(state, data_ty, de)
159}
160
161fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
163 ty.elements
164 .iter()
165 .try_for_each(|f| hash_bsatn(state, &f.algebraic_type, de.reborrow()))
166}
167
168pub fn hash_bsatn_array<'a>(
170 state: &mut impl Hasher,
171 ty: &AlgebraicType,
172 de: Deserializer<'_, impl BufReader<'a>>,
173) -> HR {
174 match ty {
177 AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
178 AlgebraicType::Sum(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_sum(s, ty, d)),
179 AlgebraicType::Product(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_prod(s, ty, d)),
180 AlgebraicType::Array(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_array(s, &ty.elem_ty, d)),
181 AlgebraicType::Bool => hash_bsatn_seq(state, de, hash_bsatn_de::<bool>),
182 AlgebraicType::I8 | AlgebraicType::U8 => hash_bsatn_int_seq(state, de, 1),
183 AlgebraicType::I16 | AlgebraicType::U16 => hash_bsatn_int_seq(state, de, 2),
184 AlgebraicType::I32 | AlgebraicType::U32 => hash_bsatn_int_seq(state, de, 4),
185 AlgebraicType::I64 | AlgebraicType::U64 => hash_bsatn_int_seq(state, de, 8),
186 AlgebraicType::I128 | AlgebraicType::U128 => hash_bsatn_int_seq(state, de, 16),
187 AlgebraicType::I256 | AlgebraicType::U256 => hash_bsatn_int_seq(state, de, 32),
188 AlgebraicType::F32 => hash_bsatn_seq(state, de, hash_bsatn_de::<F32>),
189 AlgebraicType::F64 => hash_bsatn_seq(state, de, hash_bsatn_de::<F64>),
190 AlgebraicType::String => hash_bsatn_seq(state, de, hash_bsatn_de::<&str>),
191 }
192}
193
194fn hash_bsatn_seq<'a, H: Hasher, R: BufReader<'a>>(
197 state: &mut H,
198 mut de: Deserializer<'_, R>,
199 mut elem_hash: impl FnMut(&mut H, Deserializer<'_, R>) -> Result<(), DecodeError>,
200) -> HR {
201 let len = de.reborrow().deserialize_len()?;
204 state.write_usize(len);
205
206 (0..len).try_for_each(|_| elem_hash(state, de.reborrow()))
208}
209
210fn hash_bsatn_int_seq<'a, H: Hasher, R: BufReader<'a>>(state: &mut H, mut de: Deserializer<'_, R>, width: usize) -> HR {
213 let len = de.reborrow().deserialize_len()?;
216
217 let bytes = de.get_slice(len * width)?;
222
223 hash_len_and_bytes(state, len, bytes);
224 Ok(())
225}
226
/// Feeds `len` to `state` via `write_usize`, then feeds `bytes` via a single `write`.
///
/// This is the shared tail of `hash_bytes_of`, `unchecked_hash_bytes_of`, and
/// `hash_bsatn_int_seq`. Routing all three through one function keeps the in-memory and
/// BSATN hashes of integer arrays identical.
fn hash_len_and_bytes(state: &mut impl Hasher, len: usize, bytes: &[u8]) {
    Hasher::write_usize(state, len);
    Hasher::write(state, bytes);
}
232
233fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>(
235 state: &mut impl Hasher,
236 de: Deserializer<'_, impl BufReader<'a>>,
237) -> HR {
238 T::deserialize(de).map(|x| x.hash(state))
239}
240
#[cfg(test)]
mod tests {
    use super::hash_bsatn;
    use crate::{
        bsatn::{to_vec, Deserializer},
        proptest::generate_typed_value,
        AlgebraicType, AlgebraicValue,
    };
    use proptest::prelude::*;
    use std::hash::{BuildHasher, Hasher as _};

    /// Encodes `val` as BSATN and hashes that encoding, read back as type `ty`,
    /// with a fresh hasher built by `bh`.
    fn hash_one_bsatn_av(bh: &impl BuildHasher, ty: &AlgebraicType, val: &AlgebraicValue) -> u64 {
        let encoded = to_vec(val).unwrap();
        let mut reader: &[u8] = &encoded;
        let mut hasher = bh.build_hasher();
        hash_bsatn(&mut hasher, ty, Deserializer::new(&mut reader)).unwrap();
        hasher.finish()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(2048))]

        // Hashing an `AlgebraicValue` directly must agree with hashing its BSATN encoding,
        // checked here with std's `RandomState`...
        #[test]
        fn av_bsatn_hash_same_std_random_state((ty, val) in generate_typed_value()) {
            let rs = std::hash::RandomState::new();
            prop_assert_eq!(rs.hash_one(&val), hash_one_bsatn_av(&rs, &ty, &val));
        }

        // ...and with `ahash`'s `RandomState`.
        #[test]
        fn av_bsatn_hash_same_ahash((ty, val) in generate_typed_value()) {
            let rs = ahash::RandomState::new();
            prop_assert_eq!(rs.hash_one(&val), hash_one_bsatn_av(&rs, &ty, &val));
        }
    }
}