wasm_pkg_common/
digest.rs1#[cfg(feature = "tokio")]
2use std::path::Path;
3use std::str::FromStr;
4
5use bytes::Bytes;
6use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::Error;
11
/// A digest identifying a piece of content by hash algorithm and hash value.
///
/// The textual form used by the `Display`, `FromStr`/`TryFrom<&str>` and serde
/// implementations in this module is `sha256:<hex>`.
///
/// `Hash` is derived alongside `Eq` so digests can be used directly as
/// `HashMap` / `HashSet` keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ContentDigest {
    /// A SHA-256 digest.
    Sha256 {
        /// Hex encoding of the hash value.
        ///
        /// The conversions in this module always produce 64 lowercase hex
        /// digits; a value built by naming the variant directly is not
        /// validated.
        hex: String,
    },
}
17
18impl ContentDigest {
19 #[cfg(feature = "tokio")]
20 pub async fn sha256_from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
21 use tokio::io::AsyncReadExt;
22 let mut file = tokio::fs::File::open(path).await?;
23 let mut hasher = Sha256::new();
24 let mut buf = [0; 4096];
25 loop {
26 let n = file.read(&mut buf).await?;
27 if n == 0 {
28 break;
29 }
30 hasher.update(&buf[..n]);
31 }
32 Ok(hasher.into())
33 }
34
35 pub fn validating_stream(
36 &self,
37 stream: impl TryStream<Ok = Bytes, Error = Error>,
38 ) -> impl Stream<Item = Result<Bytes, Error>> {
39 let want = self.clone();
40 stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
41 Sha256::new(),
42 move |hasher, res| {
43 ready(match res {
44 Ok(Some(bytes)) => {
45 hasher.update(&bytes);
46 Some(Ok(bytes))
47 }
48 Ok(None) => {
49 let got: Self = std::mem::take(hasher).into();
50 if got == want {
51 None
52 } else {
53 Some(Err(Error::InvalidContent(format!(
54 "expected digest {want}, got {got}"
55 ))))
56 }
57 }
58 Err(err) => Some(Err(err)),
59 })
60 },
61 )
62 }
63}
64
65impl std::fmt::Display for ContentDigest {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
69 }
70 }
71}
72
73impl From<Sha256> for ContentDigest {
74 fn from(hasher: Sha256) -> Self {
75 Self::Sha256 {
76 hex: format!("{:x}", hasher.finalize()),
77 }
78 }
79}
80
81impl<'a> TryFrom<&'a str> for ContentDigest {
82 type Error = Error;
83
84 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
85 let Some(hex) = value.strip_prefix("sha256:") else {
86 return Err(Error::InvalidContentDigest(
87 "must start with 'sha256:'".into(),
88 ));
89 };
90 let hex = hex.to_lowercase();
91 if hex.len() != 64 {
92 return Err(Error::InvalidContentDigest(format!(
93 "must be 64 hex digits; got {} chars",
94 hex.len()
95 )));
96 }
97 if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
98 return Err(Error::InvalidContentDigest(format!(
99 "must be hex; got {invalid:?}"
100 )));
101 }
102 Ok(Self::Sha256 { hex })
103 }
104}
105
106impl FromStr for ContentDigest {
107 type Err = Error;
108
109 fn from_str(s: &str) -> Result<Self, Self::Err> {
110 s.try_into()
111 }
112}
113
114impl Serialize for ContentDigest {
115 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
116 serializer.serialize_str(&self.to_string())
117 }
118}
119
120impl<'de> Deserialize<'de> for ContentDigest {
121 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
122 where
123 D: serde::Deserializer<'de>,
124 {
125 Self::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
126 }
127}
128
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    /// A stream whose content matches the digest passes through unchanged.
    #[tokio::test]
    async fn test_validating_stream() {
        let payload = b"input";
        let expected = ContentDigest::from(Sha256::new_with_prefix(payload));
        // Split the payload into several chunks so hashing spans more than
        // one stream item.
        let chunks = stream::iter(payload.chunks(2)).map(|chunk| Ok(chunk.into()));
        let collected = expected
            .validating_stream(chunks)
            .try_collect::<BytesMut>()
            .await
            .unwrap();
        assert_eq!(collected, &payload[..]);
    }

    /// A stream whose content does not match the digest ends with
    /// `Error::InvalidContent`.
    #[tokio::test]
    async fn test_invalidating_stream() {
        let payload = b"input";
        let bogus = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let chunks = stream::iter(payload.chunks(2)).map(|chunk| Ok(chunk.into()));
        let outcome = bogus
            .validating_stream(chunks)
            .try_collect::<BytesMut>()
            .await;
        assert!(matches!(outcome, Err(Error::InvalidContent(_))));
    }
}