// warg_crypto/hash/static.rs

1use digest::generic_array::GenericArray;
2use serde::de::{Error, Visitor};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use thiserror::Error;
5
6use std::fmt;
7
8use crate::{ByteVisitor, VisitBytes};
9
10use super::{Output, SupportedDigest};
11
12#[derive(Default, PartialOrd, Ord)]
13pub struct Hash<D: SupportedDigest> {
14    pub(crate) digest: Output<D>,
15}
16
17struct HashVisitor<D: SupportedDigest> {
18    digest: D,
19}
20
21impl<D> HashVisitor<D>
22where
23    D: SupportedDigest,
24{
25    fn new() -> Self {
26        HashVisitor { digest: D::new() }
27    }
28
29    fn finalize(self) -> Hash<D> {
30        Hash {
31            digest: self.digest.finalize(),
32        }
33    }
34}
35
36impl<D: SupportedDigest> ByteVisitor for HashVisitor<D> {
37    fn visit_bytes(&mut self, bytes: impl AsRef<[u8]>) {
38        self.digest.update(bytes)
39    }
40}
41
42impl<D: SupportedDigest> Hash<D> {
43    pub fn of(content: impl VisitBytes) -> Self {
44        let mut visitor = HashVisitor::new();
45        content.visit(&mut visitor);
46        visitor.finalize()
47    }
48
49    pub fn bytes(&self) -> &[u8] {
50        self.digest.as_slice()
51    }
52
53    #[allow(clippy::len_without_is_empty)]
54    pub fn len(&self) -> usize {
55        self.bytes().len()
56    }
57
58    pub fn bit_len(&self) -> usize {
59        self.bytes().len() * 8
60    }
61}
62
63impl<D: SupportedDigest> VisitBytes for Hash<D> {
64    fn visit<BV: ?Sized + ByteVisitor>(&self, visitor: &mut BV) {
65        visitor.visit_bytes(self.bytes())
66    }
67}
68
69impl<D: SupportedDigest> std::hash::Hash for Hash<D> {
70    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71        self.digest.hash(state);
72    }
73}
74
75// Derived clone does not have precise enough bounds and type info.
76impl<D: SupportedDigest> Clone for Hash<D> {
77    fn clone(&self) -> Self {
78        Self {
79            digest: self.digest.clone(),
80        }
81    }
82}
83
84impl<D: SupportedDigest> Eq for Hash<D> {}
85impl<D: SupportedDigest> PartialEq for Hash<D> {
86    fn eq(&self, other: &Self) -> bool {
87        self.digest == other.digest
88    }
89}
90
91impl<D: SupportedDigest> fmt::Display for Hash<D> {
92    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
93        write!(
94            f,
95            "{}:{}",
96            D::ALGORITHM,
97            hex::encode(self.digest.as_slice())
98        )
99    }
100}
101
102impl<D: SupportedDigest> fmt::Debug for Hash<D> {
103    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
104        write!(
105            f,
106            "Hash<{:?}>({})",
107            D::ALGORITHM,
108            hex::encode(self.digest.as_slice())
109        )
110    }
111}
112
113impl<D: SupportedDigest> From<Output<D>> for Hash<D> {
114    fn from(value: Output<D>) -> Self {
115        Hash { digest: value }
116    }
117}
118
119impl<D: SupportedDigest> TryFrom<Vec<u8>> for Hash<D> {
120    type Error = IncorrectLengthError;
121
122    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
123        let hash = Hash {
124            digest: GenericArray::from_exact_iter(value.into_iter()).ok_or(IncorrectLengthError)?,
125        };
126        Ok(hash)
127    }
128}
129
130#[derive(Error, Debug, Clone, PartialEq, Eq)]
131#[error("the provided vector was not the correct length")]
132pub struct IncorrectLengthError;
133
134impl<D: SupportedDigest> Serialize for Hash<D> {
135    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
136        serializer.serialize_bytes(&self.digest)
137    }
138}
139
140impl<'de, T: SupportedDigest> Deserialize<'de> for Hash<T> {
141    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
142        struct CopyVisitor<T>(T);
143
144        impl<T: AsRef<[u8]> + AsMut<[u8]>> From<T> for CopyVisitor<T> {
145            fn from(buffer: T) -> Self {
146                Self(buffer)
147            }
148        }
149
150        impl<'a, T: AsRef<[u8]> + AsMut<[u8]>> Visitor<'a> for CopyVisitor<T> {
151            type Value = T;
152
153            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
154                formatter.write_fmt(format_args!("{} bytes", self.0.as_ref().len()))
155            }
156
157            fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
158                self.visit_bytes(&v)
159            }
160
161            fn visit_borrowed_bytes<E: Error>(self, v: &'a [u8]) -> Result<Self::Value, E> {
162                self.visit_bytes(v)
163            }
164
165            fn visit_bytes<E: Error>(mut self, v: &[u8]) -> Result<Self::Value, E> {
166                if v.len() != self.0.as_mut().len() {
167                    return Err(E::custom("invalid length"));
168                }
169
170                self.0.as_mut().copy_from_slice(v);
171                Ok(self.0)
172            }
173        }
174
175        let buffer = Output::<T>::default();
176        let visitor = CopyVisitor::from(buffer);
177        Ok(Self {
178            digest: deserializer.deserialize_bytes(visitor)?,
179        })
180    }
181}
182
#[cfg(test)]
mod tests {
    use sha2::Sha256;

    use super::*;

    /// Checks that splicing empty byte slices into the visited tuple, at any
    /// position, yields the same digest as hashing the tuple without them.
    #[test]
    fn test_hash_empties_have_no_impact() {
        let empty: &[u8] = &[];

        let expected: Hash<Sha256> = Hash::of((0u8, 1u8));
        let with_empties: [Hash<Sha256>; 7] = [
            Hash::of((0u8, 1u8, empty)),
            Hash::of((0u8, empty, 1u8)),
            Hash::of((0u8, empty, 1u8, empty)),
            Hash::of((empty, 0u8, 1u8)),
            Hash::of((empty, 0u8, 1u8, empty)),
            Hash::of((empty, 0u8, empty, 1u8)),
            Hash::of((empty, 0u8, empty, 1u8, empty)),
        ];

        for (index, actual) in with_empties.iter().enumerate() {
            assert_eq!(&expected, actual, "h{} differs from h0", index + 1);
        }
    }
}