1use borsh::BorshSerialize;
2use serde::{Deserializer, Serializer};
3use sha2::Digest;
4use std::fmt;
5use std::hash::{Hash, Hasher};
6use std::io::Write;
7
8#[derive(
9 Copy,
10 Clone,
11 PartialEq,
12 Eq,
13 PartialOrd,
14 Ord,
15 derive_more::AsRef,
16 derive_more::AsMut,
17 arbitrary::Arbitrary,
18 borsh::BorshDeserialize,
19 borsh::BorshSerialize,
20)]
21#[as_ref(forward)]
22#[as_mut(forward)]
23pub struct CryptoHash(pub [u8; 32]);
24
25impl CryptoHash {
26 pub const LENGTH: usize = 32;
27
28 pub const fn new() -> Self {
29 Self([0; Self::LENGTH])
30 }
31
32 pub fn hash_bytes(bytes: &[u8]) -> CryptoHash {
34 CryptoHash(sha2::Sha256::digest(bytes).into())
35 }
36
37 pub fn hash_borsh<T: BorshSerialize>(value: T) -> CryptoHash {
44 let mut hasher = sha2::Sha256::default();
45 value.serialize(&mut hasher).unwrap();
46 CryptoHash(hasher.finalize().into())
47 }
48
49 pub fn hash_borsh_iter<I>(values: I) -> CryptoHash
57 where
58 I: IntoIterator,
59 I::IntoIter: ExactSizeIterator,
60 I::Item: BorshSerialize,
61 {
62 let iter = values.into_iter();
63 let n = u32::try_from(iter.len()).unwrap();
64 let mut hasher = sha2::Sha256::default();
65 hasher.write_all(&n.to_le_bytes()).unwrap();
66 let count =
67 iter.inspect(|value| BorshSerialize::serialize(&value, &mut hasher).unwrap()).count();
68 assert_eq!(n as usize, count);
69 CryptoHash(hasher.finalize().into())
70 }
71
72 pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
73 &self.0
74 }
75
76 fn to_base58_impl<Out>(self, visitor: impl FnOnce(&str) -> Out) -> Out {
82 let mut buffer = [0u8; 45];
86 let len = bs58::encode(self).into(&mut buffer[..]).unwrap();
87 let value = std::str::from_utf8(&buffer[..len]).unwrap();
88 visitor(value)
89 }
90
91 fn from_base58_impl(encoded: &str) -> Decode58Result {
97 let mut result = Self::new();
98 match bs58::decode(encoded).into(&mut result.0) {
99 Ok(len) if len == result.0.len() => Decode58Result::Ok(result),
100 Ok(_) | Err(bs58::decode::Error::BufferTooSmall) => Decode58Result::BadLength,
101 Err(err) => Decode58Result::Err(err),
102 }
103 }
104}
105
106enum Decode58Result {
108 Ok(CryptoHash),
110 BadLength,
112 Err(bs58::decode::Error),
115}
116
117impl Default for CryptoHash {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl serde::Serialize for CryptoHash {
124 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
125 where
126 S: Serializer,
127 {
128 self.to_base58_impl(|encoded| serializer.serialize_str(encoded))
129 }
130}
131
/// serde visitor that parses a base58-encoded string into a [`CryptoHash`].
struct Visitor;
137
138impl<'de> serde::de::Visitor<'de> for Visitor {
139 type Value = CryptoHash;
140
141 fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 fmt.write_str("base58-encoded 256-bit hash")
143 }
144
145 fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
146 match CryptoHash::from_base58_impl(s) {
147 Decode58Result::Ok(result) => Ok(result),
148 Decode58Result::BadLength => Err(E::invalid_length(s.len(), &self)),
149 Decode58Result::Err(err) => Err(E::custom(err)),
150 }
151 }
152}
153
154impl<'de> serde::Deserialize<'de> for CryptoHash {
155 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
156 where
157 D: Deserializer<'de>,
158 {
159 deserializer.deserialize_str(Visitor)
160 }
161}
162
163impl std::str::FromStr for CryptoHash {
164 type Err = Box<dyn std::error::Error + Send + Sync>;
165
166 fn from_str(encoded: &str) -> Result<Self, Self::Err> {
168 match Self::from_base58_impl(encoded) {
169 Decode58Result::Ok(result) => Ok(result),
170 Decode58Result::BadLength => Err("incorrect length for hash".into()),
171 Decode58Result::Err(err) => Err(err.into()),
172 }
173 }
174}
175
176impl TryFrom<&[u8]> for CryptoHash {
177 type Error = Box<dyn std::error::Error + Send + Sync>;
178
179 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
180 Ok(CryptoHash(bytes.try_into()?))
181 }
182}
183
184impl From<CryptoHash> for Vec<u8> {
185 fn from(hash: CryptoHash) -> Vec<u8> {
186 hash.0.to_vec()
187 }
188}
189
190impl From<&CryptoHash> for Vec<u8> {
191 fn from(hash: &CryptoHash) -> Vec<u8> {
192 hash.0.to_vec()
193 }
194}
195
196impl From<CryptoHash> for [u8; CryptoHash::LENGTH] {
197 fn from(hash: CryptoHash) -> [u8; CryptoHash::LENGTH] {
198 hash.0
199 }
200}
201
202impl fmt::Debug for CryptoHash {
203 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
204 fmt::Display::fmt(self, fmtr)
205 }
206}
207
208impl fmt::Display for CryptoHash {
209 fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result {
210 self.to_base58_impl(|encoded| fmtr.write_str(encoded))
211 }
212}
213
214impl Hash for CryptoHash {
217 fn hash<H: Hasher>(&self, state: &mut H) {
218 state.write(self.as_ref());
219 }
220}
221
222pub fn hash(data: &[u8]) -> CryptoHash {
233 CryptoHash::hash_bytes(data)
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[derive(serde::Deserialize, serde::Serialize)]
    struct Struct {
        hash: CryptoHash,
    }

    /// Checks `hash_borsh` and `hash_borsh_iter` against known base58 digests.
    ///
    /// - "CuoN…" is the hash of the length-prefixed encoding
    ///   `[3, 0, 0, 0, b'f', b'o', b'o']`, which is what strings, slices and
    ///   sequences produce.
    /// - "3yMA…" is the hash of the bare bytes `foo`, which is what fixed-size
    ///   arrays and tuples produce.
    #[test]
    fn test_hash_borsh() {
        // Hashes a single value through `hash_borsh`.
        fn value<T: BorshSerialize>(want: &str, value: T) {
            assert_eq!(want, CryptoHash::hash_borsh(&value).to_string());
        }

        // Hashes a slice directly and via both iterator forms; all must agree.
        fn slice<T: BorshSerialize>(want: &str, slice: &[T]) {
            assert_eq!(want, CryptoHash::hash_borsh(slice).to_string());
            iter(want, slice.iter());
            iter(want, slice);
        }

        // Hashes an exact-size iterator through `hash_borsh_iter`.
        fn iter<I>(want: &str, iter: I)
        where
            I: IntoIterator,
            I::IntoIter: ExactSizeIterator,
            I::Item: BorshSerialize,
        {
            assert_eq!(want, CryptoHash::hash_borsh_iter(iter).to_string());
        }

        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo");
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &b"foo"[..]);
        value("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", [3, 0, 0, 0, b'f', b'o', b'o']);
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", "foo".as_bytes());
        iter(
            "CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt",
            "FOO".bytes().map(|ch| ch.to_ascii_lowercase()),
        );

        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", b"foo");
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", [b'f', b'o', b'o']);
        // A tuple is encoded as its fields back to back, so it matches the array.
        value("3yMApqCuCjXDWPrbjfR5mjCPTHqFG8Pux1TxQrEM35jj", (b'f', b'o', b'o'));
        slice("CuoNgQBWsXnTqup6FY3UXNz6RRufnYyQVxx8HKZLUaRt", &[b'f', b'o', b'o']);
    }

    /// Round-trips known hashes through `Display`, `FromStr` and serde JSON.
    #[test]
    fn test_base58_successes() {
        for (encoded, hash) in [
            ("11111111111111111111111111111111", CryptoHash::new()),
            ("CjNSmWXTWhC3EhRVtqLhRmWMTkRbU96wUACqxMtV1uGf", hash(&[0, 1, 2])),
        ] {
            assert_eq!(encoded, hash.to_string());
            assert_eq!(hash, CryptoHash::from_str(encoded).unwrap());

            let json = format!("\"{}\"", encoded);
            assert_eq!(json, serde_json::to_string(&hash).unwrap());
            assert_eq!(hash, serde_json::from_str::<CryptoHash>(&json).unwrap());
        }
    }

    /// Checks that `FromStr` rejects malformed and wrong-length input with a descriptive error.
    #[test]
    fn test_from_str_failures() {
        fn test(input: &str, want_err: &str) {
            match CryptoHash::from_str(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("foo-bar-baz", "provided string contained invalid character '-' at byte 3");

        // Base58 input that does not decode to exactly 32 bytes:
        // - a truncated copy of a valid hash;
        // - the empty string;
        // - runs of '1', where each leading '1' decodes to one zero byte.
        for encoded in &[
            "CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf".to_string(),
            "".to_string(),
            "1".repeat(31),
            "1".repeat(33),
            "1".repeat(1000),
        ] {
            test(encoded, "incorrect length for hash");
        }
    }

    /// Checks that serde deserialization rejects malformed and wrong-length strings.
    #[test]
    fn test_serde_deserialise_failures() {
        fn test(input: &str, want_err: &str) {
            match serde_json::from_str::<CryptoHash>(input) {
                Ok(got) => panic!("‘{input}’ should have failed; got ‘{got}’"),
                Err(err) => {
                    assert!(err.to_string().starts_with(want_err), "input: ‘{input}’; err: {err}")
                }
            }
        }

        test("\"foo-bar-baz\"", "provided string contained invalid character");
        // The same wrong-length inputs as `test_from_str_failures`, as JSON strings.
        for encoded in &[
            "\"CjNSmWXTWhC3ELhRmWMTkRbU96wUACqxMtV1uGf\"".to_string(),
            "\"\"".to_string(),
            format!("\"{}\"", "1".repeat(31)),
            format!("\"{}\"", "1".repeat(33)),
            format!("\"{}\"", "1".repeat(1000)),
        ] {
            test(encoded, "invalid length");
        }
    }
}