warg_crypto/hash/
static.rs1use digest::generic_array::GenericArray;
2use serde::de::{Error, Visitor};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use thiserror::Error;
5
6use std::fmt;
7
8use crate::{ByteVisitor, VisitBytes};
9
10use super::{Output, SupportedDigest};
11
12#[derive(Default, PartialOrd, Ord)]
13pub struct Hash<D: SupportedDigest> {
14 pub(crate) digest: Output<D>,
15}
16
17struct HashVisitor<D: SupportedDigest> {
18 digest: D,
19}
20
21impl<D> HashVisitor<D>
22where
23 D: SupportedDigest,
24{
25 fn new() -> Self {
26 HashVisitor { digest: D::new() }
27 }
28
29 fn finalize(self) -> Hash<D> {
30 Hash {
31 digest: self.digest.finalize(),
32 }
33 }
34}
35
36impl<D: SupportedDigest> ByteVisitor for HashVisitor<D> {
37 fn visit_bytes(&mut self, bytes: impl AsRef<[u8]>) {
38 self.digest.update(bytes)
39 }
40}
41
42impl<D: SupportedDigest> Hash<D> {
43 pub fn of(content: impl VisitBytes) -> Self {
44 let mut visitor = HashVisitor::new();
45 content.visit(&mut visitor);
46 visitor.finalize()
47 }
48
49 pub fn bytes(&self) -> &[u8] {
50 self.digest.as_slice()
51 }
52
53 #[allow(clippy::len_without_is_empty)]
54 pub fn len(&self) -> usize {
55 self.bytes().len()
56 }
57
58 pub fn bit_len(&self) -> usize {
59 self.bytes().len() * 8
60 }
61}
62
63impl<D: SupportedDigest> VisitBytes for Hash<D> {
64 fn visit<BV: ?Sized + ByteVisitor>(&self, visitor: &mut BV) {
65 visitor.visit_bytes(self.bytes())
66 }
67}
68
69impl<D: SupportedDigest> std::hash::Hash for Hash<D> {
70 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71 self.digest.hash(state);
72 }
73}
74
75impl<D: SupportedDigest> Clone for Hash<D> {
77 fn clone(&self) -> Self {
78 Self {
79 digest: self.digest.clone(),
80 }
81 }
82}
83
84impl<D: SupportedDigest> Eq for Hash<D> {}
85impl<D: SupportedDigest> PartialEq for Hash<D> {
86 fn eq(&self, other: &Self) -> bool {
87 self.digest == other.digest
88 }
89}
90
91impl<D: SupportedDigest> fmt::Display for Hash<D> {
92 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
93 write!(
94 f,
95 "{}:{}",
96 D::ALGORITHM,
97 hex::encode(self.digest.as_slice())
98 )
99 }
100}
101
102impl<D: SupportedDigest> fmt::Debug for Hash<D> {
103 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
104 write!(
105 f,
106 "Hash<{:?}>({})",
107 D::ALGORITHM,
108 hex::encode(self.digest.as_slice())
109 )
110 }
111}
112
113impl<D: SupportedDigest> From<Output<D>> for Hash<D> {
114 fn from(value: Output<D>) -> Self {
115 Hash { digest: value }
116 }
117}
118
119impl<D: SupportedDigest> TryFrom<Vec<u8>> for Hash<D> {
120 type Error = IncorrectLengthError;
121
122 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
123 let hash = Hash {
124 digest: GenericArray::from_exact_iter(value.into_iter()).ok_or(IncorrectLengthError)?,
125 };
126 Ok(hash)
127 }
128}
129
130#[derive(Error, Debug, Clone, PartialEq, Eq)]
131#[error("the provided vector was not the correct length")]
132pub struct IncorrectLengthError;
133
134impl<D: SupportedDigest> Serialize for Hash<D> {
135 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
136 serializer.serialize_bytes(&self.digest)
137 }
138}
139
140impl<'de, T: SupportedDigest> Deserialize<'de> for Hash<T> {
141 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
142 struct CopyVisitor<T>(T);
143
144 impl<T: AsRef<[u8]> + AsMut<[u8]>> From<T> for CopyVisitor<T> {
145 fn from(buffer: T) -> Self {
146 Self(buffer)
147 }
148 }
149
150 impl<'a, T: AsRef<[u8]> + AsMut<[u8]>> Visitor<'a> for CopyVisitor<T> {
151 type Value = T;
152
153 fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
154 formatter.write_fmt(format_args!("{} bytes", self.0.as_ref().len()))
155 }
156
157 fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
158 self.visit_bytes(&v)
159 }
160
161 fn visit_borrowed_bytes<E: Error>(self, v: &'a [u8]) -> Result<Self::Value, E> {
162 self.visit_bytes(v)
163 }
164
165 fn visit_bytes<E: Error>(mut self, v: &[u8]) -> Result<Self::Value, E> {
166 if v.len() != self.0.as_mut().len() {
167 return Err(E::custom("invalid length"));
168 }
169
170 self.0.as_mut().copy_from_slice(v);
171 Ok(self.0)
172 }
173 }
174
175 let buffer = Output::<T>::default();
176 let visitor = CopyVisitor::from(buffer);
177 Ok(Self {
178 digest: deserializer.deserialize_bytes(visitor)?,
179 })
180 }
181}
182
#[cfg(test)]
mod tests {
    use sha2::Sha256;

    use super::*;

    /// Interleaving empty byte slices with the hashed content must not
    /// change the resulting digest.
    #[test]
    fn test_hash_empties_have_no_impact() {
        let nothing: &[u8] = &[];

        // Reference digest computed without any empty slices.
        let baseline: Hash<Sha256> = Hash::of((0u8, 1u8));

        // Every placement of empty slices around and between the two bytes.
        let padded: [Hash<Sha256>; 7] = [
            Hash::of((0u8, 1u8, nothing)),
            Hash::of((0u8, nothing, 1u8)),
            Hash::of((0u8, nothing, 1u8, nothing)),
            Hash::of((nothing, 0u8, 1u8)),
            Hash::of((nothing, 0u8, 1u8, nothing)),
            Hash::of((nothing, 0u8, nothing, 1u8)),
            Hash::of((nothing, 0u8, nothing, 1u8, nothing)),
        ];

        for candidate in &padded {
            assert_eq!(&baseline, candidate);
        }
    }
}