1use core::fmt;
2use std::str::FromStr;
3
4use arbitrary::Arbitrary;
5use bfieldcodec_derive::BFieldCodec;
6use get_size2::GetSize;
7use itertools::Itertools;
8use num_bigint::BigUint;
9use num_traits::ConstZero;
10use num_traits::Zero;
11use rand::Rng;
12use rand::distr::Distribution;
13use rand::distr::StandardUniform;
14use serde::Deserialize;
15use serde::Deserializer;
16use serde::Serialize;
17use serde::Serializer;
18
19use crate::error::TryFromDigestError;
20use crate::error::TryFromHexDigestError;
21use crate::math::b_field_element::BFieldElement;
22use crate::prelude::Tip5;
23
/// A hash digest: a fixed-size array of [`Digest::LEN`] base-field elements.
///
/// The inner array is public, so a `Digest` can be built and destructured directly;
/// the `const fn`s [`Digest::new`] and [`Digest::values`] do the same in `const` contexts.
/// Ordering is not derived; it is implemented manually (see the `Ord` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, BFieldCodec, Arbitrary)]
pub struct Digest(pub [BFieldElement; Digest::LEN]);
29
30impl GetSize for Digest {
31 fn get_stack_size() -> usize {
32 std::mem::size_of::<Self>()
33 }
34
35 fn get_heap_size(&self) -> usize {
36 0
37 }
38}
39
40impl PartialOrd for Digest {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Ord for Digest {
47 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
48 let Digest(self_inner) = self;
49 let Digest(other_inner) = other;
50 let self_as_u64s = self_inner.iter().rev().map(|bfe| bfe.value());
51 let other_as_u64s = other_inner.iter().rev().map(|bfe| bfe.value());
52 self_as_u64s.cmp(other_as_u64s)
53 }
54}
55
56impl Digest {
57 pub const LEN: usize = 5;
59
60 pub const BYTES: usize = Self::LEN * BFieldElement::BYTES;
62
63 pub(crate) const ALL_ZERO: Self = Self([BFieldElement::ZERO; Self::LEN]);
65
66 pub const fn values(self) -> [BFieldElement; Self::LEN] {
67 self.0
68 }
69
70 pub const fn new(digest: [BFieldElement; Self::LEN]) -> Self {
71 Self(digest)
72 }
73
74 pub const fn reversed(self) -> Digest {
77 let Digest([d0, d1, d2, d3, d4]) = self;
78 Digest([d4, d3, d2, d1, d0])
79 }
80}
81
82impl Default for Digest {
83 fn default() -> Self {
84 Self::ALL_ZERO
85 }
86}
87
88impl fmt::Display for Digest {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 write!(f, "{}", self.0.map(|elem| elem.to_string()).join(","))
91 }
92}
93
94impl fmt::LowerHex for Digest {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 let bytes = <[u8; Self::BYTES]>::from(*self);
97 write!(f, "{}", hex::encode(bytes))
98 }
99}
100
101impl fmt::UpperHex for Digest {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 let bytes = <[u8; Self::BYTES]>::from(*self);
104 write!(f, "{}", hex::encode_upper(bytes))
105 }
106}
107
108impl Distribution<Digest> for StandardUniform {
109 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Digest {
110 Digest::new(rng.random())
111 }
112}
113
114impl FromStr for Digest {
115 type Err = TryFromDigestError;
116
117 fn from_str(string: &str) -> Result<Self, Self::Err> {
118 let bfes: Vec<_> = string
119 .split(',')
120 .map(str::parse::<BFieldElement>)
121 .try_collect()?;
122 let invalid_len_err = Self::Err::InvalidLength(bfes.len());
123 let digest_innards = bfes.try_into().map_err(|_| invalid_len_err)?;
124
125 Ok(Digest(digest_innards))
126 }
127}
128
129impl TryFrom<&[BFieldElement]> for Digest {
130 type Error = TryFromDigestError;
131
132 fn try_from(value: &[BFieldElement]) -> Result<Self, Self::Error> {
133 let len = value.len();
134 let maybe_digest = value.try_into().map(Digest::new);
135 maybe_digest.map_err(|_| Self::Error::InvalidLength(len))
136 }
137}
138
139impl TryFrom<Vec<BFieldElement>> for Digest {
140 type Error = TryFromDigestError;
141
142 fn try_from(value: Vec<BFieldElement>) -> Result<Self, Self::Error> {
143 Digest::try_from(&value as &[BFieldElement])
144 }
145}
146
147impl From<Digest> for Vec<BFieldElement> {
148 fn from(val: Digest) -> Self {
149 val.0.to_vec()
150 }
151}
152
153impl From<Digest> for [u8; Digest::BYTES] {
154 fn from(Digest(innards): Digest) -> Self {
155 innards
156 .map(<[u8; BFieldElement::BYTES]>::from)
157 .concat()
158 .try_into()
159 .unwrap()
160 }
161}
162
163impl TryFrom<[u8; Digest::BYTES]> for Digest {
164 type Error = TryFromDigestError;
165
166 fn try_from(item: [u8; Self::BYTES]) -> Result<Self, Self::Error> {
167 let digest_innards: Vec<_> = item
168 .chunks_exact(BFieldElement::BYTES)
169 .map(BFieldElement::try_from)
170 .try_collect()?;
171
172 Ok(Self(digest_innards.try_into().unwrap()))
173 }
174}
175
176impl TryFrom<&[u8]> for Digest {
177 type Error = TryFromDigestError;
178
179 fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
180 let array = <[u8; Self::BYTES]>::try_from(slice)
181 .map_err(|_e| TryFromDigestError::InvalidLength(slice.len()))?;
182 Self::try_from(array)
183 }
184}
185
186impl TryFrom<BigUint> for Digest {
187 type Error = TryFromDigestError;
188
189 fn try_from(value: BigUint) -> Result<Self, Self::Error> {
190 let mut remaining = value;
191 let mut digest_innards = [BFieldElement::ZERO; Self::LEN];
192 let modulus: BigUint = BFieldElement::P.into();
193 for digest_element in digest_innards.iter_mut() {
194 let element = u64::try_from(remaining.clone() % modulus.clone()).unwrap();
195 *digest_element = BFieldElement::new(element);
196 remaining /= modulus.clone();
197 }
198
199 if !remaining.is_zero() {
200 return Err(Self::Error::Overflow);
201 }
202
203 Ok(Digest::new(digest_innards))
204 }
205}
206
207impl From<Digest> for BigUint {
208 fn from(digest: Digest) -> Self {
209 let Digest(digest_innards) = digest;
210 let mut ret = BigUint::zero();
211 let modulus: BigUint = BFieldElement::P.into();
212 for i in (0..Digest::LEN).rev() {
213 ret *= modulus.clone();
214 let digest_element: BigUint = digest_innards[i].value().into();
215 ret += digest_element;
216 }
217
218 ret
219 }
220}
221
222impl Digest {
223 pub fn hash(self) -> Digest {
237 Tip5::hash_pair(self, Self::ALL_ZERO)
238 }
239
240 pub fn to_hex(self) -> String {
248 format!("{self:x}")
249 }
250
251 pub fn try_from_hex(data: impl AsRef<[u8]>) -> Result<Self, TryFromHexDigestError> {
253 let slice = hex::decode(data)?;
254 Ok(Self::try_from(&slice as &[u8])?)
255 }
256}
257
258impl Serialize for Digest {
261 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
262 if serializer.is_human_readable() {
263 self.to_hex().serialize(serializer)
264 } else {
265 self.0.serialize(serializer)
266 }
267 }
268}
269
270impl<'de> Deserialize<'de> for Digest {
273 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
274 where
275 D: Deserializer<'de>,
276 {
277 if deserializer.is_human_readable() {
278 let hex_string = String::deserialize(deserializer)?;
279 Self::try_from_hex(hex_string).map_err(serde::de::Error::custom)
280 } else {
281 Ok(Self::new(<[BFieldElement; Self::LEN]>::deserialize(
282 deserializer,
283 )?))
284 }
285 }
286}
287
/// Tests for [`Digest`]: ordering, conversions to and from strings, bytes, hex, and
/// `BigUint`, and serde round-trips in human-readable (JSON) and binary (bincode) formats.
#[cfg(test)]
pub(crate) mod digest_tests {
    use num_traits::One;
    use proptest::collection::vec;
    use proptest::prelude::Arbitrary as ProptestArbitrary;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;
    use crate::error::ParseBFieldElementError;
    use crate::prelude::*;

    // Lets `Digest` be used directly as a `#[proptest]` parameter by adapting the
    // `arbitrary`-crate implementation (derived on `Digest`) through `arb()`.
    // Shrinking is disabled with `no_shrink()`.
    impl ProptestArbitrary for Digest {
        type Parameters = ();
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            arb().prop_map(|d| d).no_shrink().boxed()
        }

        type Strategy = BoxedStrategy<Self>;
    }

    /// Test helper that overwrites between 1 and [`Digest::LEN`] distinct positions of a
    /// digest with arbitrary field elements.
    #[derive(Debug, Clone, PartialEq, Eq, test_strategy::Arbitrary)]
    pub(crate) struct DigestCorruptor {
        // Positions to overwrite; the filter guarantees they are pairwise distinct.
        #[strategy(vec(0..Digest::LEN, 1..=Digest::LEN))]
        #[filter(#corrupt_indices.iter().all_unique())]
        corrupt_indices: Vec<usize>,

        // Replacement elements, one per entry of `corrupt_indices`.
        #[strategy(vec(arb(), #corrupt_indices.len()))]
        corrupt_elements: Vec<BFieldElement>,
    }

    impl DigestCorruptor {
        /// Returns a copy of `digest` with the configured positions overwritten.
        ///
        /// # Errors
        ///
        /// Returns [`TestCaseError::Reject`] if the result equals the input digest (i.e.,
        /// every replacement element matched the element it replaced). A successful result
        /// is therefore always different from `digest`.
        pub fn corrupt_digest(&self, digest: Digest) -> Result<Digest, TestCaseError> {
            let mut corrupt_digest = digest;
            for (&i, &element) in self.corrupt_indices.iter().zip(&self.corrupt_elements) {
                corrupt_digest.0[i] = element;
            }
            if corrupt_digest == digest {
                let reject_reason = "corruption must change digest".into();
                return Err(TestCaseError::Reject(reject_reason));
            }

            Ok(corrupt_digest)
        }
    }

    #[test]
    fn digest_corruptor_rejects_uncorrupting_corruption() {
        // Overwriting element 0 (which is 1) with 1 leaves the digest unchanged.
        let digest = Digest(bfe_array![1, 2, 3, 4, 5]);
        let corruptor = DigestCorruptor {
            corrupt_indices: vec![0],
            corrupt_elements: bfe_vec![1],
        };
        let err = corruptor.corrupt_digest(digest).unwrap_err();
        assert!(matches!(err, TestCaseError::Reject(_)));
    }

    #[test]
    fn get_size() {
        let stack = Digest::get_stack_size();

        let bfes = bfe_array![12, 24, 36, 48, 60];
        let tip5_digest_type_from_array: Digest = Digest::new(bfes);
        let heap = tip5_digest_type_from_array.get_heap_size();
        let total = tip5_digest_type_from_array.get_size();
        println!("stack: {stack} + heap: {heap} = {total}");

        assert_eq!(stack + heap, total)
    }

    #[test]
    fn digest_from_str() {
        // Five comma-separated elements: valid.
        let valid_digest_string = "12063201067205522823,\
            1529663126377206632,\
            2090171368883726200,\
            12975872837767296928,\
            11492877804687889759";
        let valid_digest = Digest::from_str(valid_digest_string);
        assert!(valid_digest.is_ok());

        // Only two elements: wrong length.
        let invalid_digest_string = "00059361073062755064,05168490802189810700";
        let invalid_digest = Digest::from_str(invalid_digest_string);
        assert!(invalid_digest.is_err());

        // A token that is not a field element at all.
        let second_invalid_digest_string = "this_is_not_a_bfield_element,05168490802189810700";
        let second_invalid_digest = Digest::from_str(second_invalid_digest_string);
        assert!(second_invalid_digest.is_err());
    }

    #[proptest]
    fn test_reversed_involution(digest: Digest) {
        prop_assert_eq!(digest, digest.reversed().reversed())
    }

    // Expected digests list base-P digits, least significant digit first.
    #[test]
    fn digest_biguint_conversion_simple_test() {
        let fourteen: BigUint = 14u128.into();
        let fourteen_converted_expected = Digest(bfe_array![14, 0, 0, 0, 0]);

        let bfe_max: BigUint = BFieldElement::MAX.into();
        let bfe_max_converted_expected = Digest(bfe_array![BFieldElement::MAX, 0, 0, 0, 0]);

        // P itself is "10" in base P.
        let bfe_max_plus_one: BigUint = BFieldElement::P.into();
        let bfe_max_plus_one_converted_expected = Digest(bfe_array![0, 1, 0, 0, 0]);

        // 2^64 = 1 * P + (2^32 - 1).
        let two_pow_64: BigUint = (1u128 << 64).into();
        let two_pow_64_converted_expected = Digest(bfe_array![(1u64 << 32) - 1, 1, 0, 0, 0]);

        let two_pow_123: BigUint = (1u128 << 123).into();
        let two_pow_123_converted_expected =
            Digest([18446744069280366593, 576460752437641215, 0, 0, 0].map(BFieldElement::new));

        let two_pow_315: BigUint = BigUint::from(2u128).pow(315);

        // Uses all five digits.
        let two_pow_315_converted_expected = Digest(bfe_array![
            18446744069280366593_u64,
            1729382257312923647_u64,
            13258597298683772929_u64,
            3458764513015234559_u64,
            576460752840294400_u64,
        ]);

        // BigUint -> Digest
        assert_eq!(
            fourteen_converted_expected,
            fourteen.clone().try_into().unwrap()
        );
        assert_eq!(
            bfe_max_converted_expected,
            bfe_max.clone().try_into().unwrap()
        );
        assert_eq!(
            bfe_max_plus_one_converted_expected,
            bfe_max_plus_one.clone().try_into().unwrap()
        );
        assert_eq!(
            two_pow_64_converted_expected,
            two_pow_64.clone().try_into().unwrap()
        );
        assert_eq!(
            two_pow_123_converted_expected,
            two_pow_123.clone().try_into().unwrap()
        );
        assert_eq!(
            two_pow_315_converted_expected,
            two_pow_315.clone().try_into().unwrap()
        );

        // Digest -> BigUint
        assert_eq!(fourteen, fourteen_converted_expected.into());
        assert_eq!(bfe_max, bfe_max_converted_expected.into());
        assert_eq!(bfe_max_plus_one, bfe_max_plus_one_converted_expected.into());
        assert_eq!(two_pow_64, two_pow_64_converted_expected.into());
        assert_eq!(two_pow_123, two_pow_123_converted_expected.into());
        assert_eq!(two_pow_315, two_pow_315_converted_expected.into());
    }

    // The product of four u64s and one u32 is below 2^288. That is well inside the range
    // shown to be representable by the 2^315 case above, so the conversion cannot overflow.
    #[proptest]
    fn digest_biguint_conversion_pbt(components_0: [u64; 4], component_1: u32) {
        let big_uint = components_0
            .into_iter()
            .fold(BigUint::one(), |acc, x| acc * x);
        let big_uint = big_uint * component_1;

        let as_digest: Digest = big_uint.clone().try_into().unwrap();
        let big_uint_again: BigUint = as_digest.into();
        prop_assert_eq!(big_uint, big_uint_again);
    }

    // The element at the highest index is the most significant for ordering.
    #[test]
    fn digest_ordering() {
        let val0 = Digest::new(bfe_array![0; Digest::LEN]);
        let val1 = Digest::new(bfe_array![14, 0, 0, 0, 0]);
        assert!(val1 > val0);

        let val2 = Digest::new(bfe_array![14; Digest::LEN]);
        assert!(val2 > val1);
        assert!(val2 > val0);

        let val3 = Digest::new(bfe_array![15, 14, 14, 14, 14]);
        assert!(val3 > val2);
        assert!(val3 > val1);
        assert!(val3 > val0);

        let val4 = Digest::new(bfe_array![14, 15, 14, 14, 14]);
        assert!(val4 > val3);
        assert!(val4 > val2);
        assert!(val4 > val1);
        assert!(val4 > val0);
    }

    // (2^96)^4 = 2^384 is too large to fit into a digest.
    #[test]
    fn digest_biguint_overflow_test() {
        let mut two_pow_384: BigUint = (1u128 << 96).into();
        two_pow_384 = two_pow_384.pow(4);
        let err = Digest::try_from(two_pow_384).unwrap_err();

        assert_eq!(TryFromDigestError::Overflow, err);
    }

    #[proptest]
    fn forty_bytes_can_be_converted_to_digest(bytes: [u8; Digest::BYTES]) {
        let digest = Digest::try_from(bytes).unwrap();
        let bytes_again: [u8; Digest::BYTES] = digest.into();
        prop_assert_eq!(bytes, bytes_again);
    }

    // All-0xff bytes decode to u64::MAX per element, which is not a canonical field element.
    #[test]
    fn try_from_bytes_not_canonical() -> Result<(), TryFromDigestError> {
        let bytes: [u8; Digest::BYTES] = [255; Digest::BYTES];

        assert!(Digest::try_from(bytes).is_err_and(|e| matches!(
            e,
            TryFromDigestError::InvalidBFieldElement(ParseBFieldElementError::NotCanonical(_))
        )));

        Ok(())
    }

    // u64::MAX is not a canonical field element, so string parsing must reject it.
    #[test]
    fn from_str_not_canonical() -> Result<(), TryFromDigestError> {
        let str = format!("0,0,0,0,{}", u64::MAX);

        assert!(Digest::from_str(&str).is_err_and(|e| matches!(
            e,
            TryFromDigestError::InvalidBFieldElement(ParseBFieldElementError::NotCanonical(_))
        )));

        Ok(())
    }

    #[test]
    fn bytes_in_matches_bytes_out() -> Result<(), TryFromDigestError> {
        let bytes1: [u8; Digest::BYTES] = [254; Digest::BYTES];
        let d1 = Digest::try_from(bytes1)?;

        let bytes2: [u8; Digest::BYTES] = d1.into();
        let d2 = Digest::try_from(bytes2)?;

        assert_eq!(d1, d2);
        assert_eq!(bytes1, bytes2);

        Ok(())
    }

    mod hex_test {
        use super::*;

        /// Known digest/hex pairs. Each element is encoded as 8 little-endian bytes, so the
        /// element `1` becomes `0100000000000000`.
        pub(super) fn hex_examples() -> Vec<(Digest, &'static str)> {
            vec![
                (
                    Digest::default(),
                    concat!(
                        "0000000000000000000000000000000000000000",
                        "0000000000000000000000000000000000000000"
                    ),
                ),
                (
                    Digest::new(bfe_array![0, 1, 10, 15, 255]),
                    concat!(
                        "000000000000000001000000000000000a000000",
                        "000000000f00000000000000ff00000000000000"
                    ),
                ),
            ]
        }

        #[test]
        fn digest_to_hex() {
            for (digest, hex) in hex_examples() {
                assert_eq!(&digest.to_hex(), hex);
            }
        }

        #[proptest]
        fn to_hex_and_from_hex_are_reciprocal_proptest(bytes: [u8; Digest::BYTES]) {
            let digest = Digest::try_from(bytes).unwrap();
            let hex = digest.to_hex();
            let digest_again = Digest::try_from_hex(&hex).unwrap();
            let hex_again = digest_again.to_hex();
            prop_assert_eq!(digest, digest_again);
            prop_assert_eq!(hex, hex_again);

            // Both lowercase and uppercase hex must parse back to the same digest.
            let lower_hex = format!("{digest:x}");
            let digest_from_lower_hex = Digest::try_from_hex(lower_hex).unwrap();
            prop_assert_eq!(digest, digest_from_lower_hex);

            let upper_hex = format!("{digest:X}");
            let digest_from_upper_hex = Digest::try_from_hex(upper_hex).unwrap();
            prop_assert_eq!(digest, digest_from_upper_hex);
        }

        #[test]
        fn to_hex_and_from_hex_are_reciprocal() -> Result<(), TryFromHexDigestError> {
            let hex_vals = vec![
                "00000000000000000000000000000000000000000000000000000000000000000000000000000000",
                "10000000000000000000000000000000000000000000000000000000000000000000000000000000",
                "0000000000000000000000000000000000000000000000000000000000000000000000000000000f",
                "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            ];
            for hex in hex_vals {
                let digest = Digest::try_from_hex(hex)?;
                assert_eq!(hex, &digest.to_hex())
            }
            Ok(())
        }

        #[test]
        fn digest_from_hex() -> Result<(), TryFromHexDigestError> {
            for (digest, hex) in hex_examples() {
                assert_eq!(digest, Digest::try_from_hex(hex)?);
            }

            Ok(())
        }

        // Each failure mode maps to its own error variant.
        #[test]
        fn digest_from_invalid_hex_errors() {
            use hex::FromHexError;

            assert!(Digest::try_from_hex("taco").is_err_and(|e| matches!(
                e,
                TryFromHexDigestError::HexDecode(FromHexError::InvalidHexCharacter { .. })
            )));

            assert!(Digest::try_from_hex("0").is_err_and(|e| matches!(
                e,
                TryFromHexDigestError::HexDecode(FromHexError::OddLength)
            )));

            assert!(Digest::try_from_hex("00").is_err_and(|e| matches!(
                e,
                TryFromHexDigestError::Digest(TryFromDigestError::InvalidLength(_))
            )));

            // Correct length, but all-0xff bytes are not canonical field elements.
            assert!(Digest::try_from_hex(
                "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
            )
            .is_err_and(|e| matches!(
                e,
                TryFromHexDigestError::Digest(TryFromDigestError::InvalidBFieldElement(
                    ParseBFieldElementError::NotCanonical(_)
                ))
            )));
        }
    }

    mod serde_test {
        use super::hex_test::hex_examples;
        use super::*;

        // JSON is human-readable, so digests (de)serialize as quoted hex strings.
        mod json_test {
            use super::*;

            #[test]
            fn serialize() -> Result<(), serde_json::Error> {
                for (digest, hex) in hex_examples() {
                    assert_eq!(serde_json::to_string(&digest)?, format!("\"{hex}\""));
                }
                Ok(())
            }

            #[test]
            fn deserialize() -> Result<(), serde_json::Error> {
                for (digest, hex) in hex_examples() {
                    let json_hex = format!("\"{hex}\"");
                    let digest_deserialized: Digest = serde_json::from_str::<Digest>(&json_hex)?;
                    assert_eq!(digest_deserialized, digest);
                }
                Ok(())
            }
        }

        // bincode is not human-readable, so digests (de)serialize as the raw element array:
        // 8 little-endian bytes per element, with no length prefix.
        mod bincode_test {
            use super::*;

            fn bincode_examples() -> Vec<(Digest, [u8; Digest::BYTES])> {
                vec![
                    (Digest::default(), [0u8; Digest::BYTES]),
                    (
                        Digest::new(bfe_array![0, 1, 10, 15, 255]),
                        [
                            0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0,
                            0, 15, 0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0,
                        ],
                    ),
                ]
            }

            #[test]
            fn serialize() {
                for (digest, bytes) in bincode_examples() {
                    assert_eq!(bincode::serialize(&digest).unwrap(), bytes);
                }
            }

            #[test]
            fn deserialize() {
                for (digest, bytes) in bincode_examples() {
                    assert_eq!(bincode::deserialize::<Digest>(&bytes).unwrap(), digest);
                }
            }
        }
    }
}