wasm_pkg_loader/
release.rs1use std::path::Path;
2
3use bytes::Bytes;
4use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
5use sha2::{Digest, Sha256};
6use tokio::io::AsyncReadExt;
7use wasm_pkg_common::package::Version;
8
9use crate::Error;
10
11#[derive(Clone, Debug)]
12pub struct Release {
13 pub version: Version,
14 pub content_digest: ContentDigest,
15}
16
/// A content digest tagged with the algorithm that produced it.
///
/// Rendered (via `Display`) and parsed (via `FromStr`/`TryFrom<&str>`) in the
/// `sha256:<hex>` form.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` because equality on the
/// wrapped `String` is total, which also lets digests be used as map/set keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ContentDigest {
    /// A SHA-256 digest. `hex` holds the hex encoding of the 32-byte hash;
    /// digests produced by hashing or by parsing in this module are lowercase.
    Sha256 { hex: String },
}
21
22impl ContentDigest {
23 pub async fn sha256_from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
24 let mut file = tokio::fs::File::open(path).await?;
25 let mut hasher = Sha256::new();
26 let mut buf = [0; 4096];
27 loop {
28 let n = file.read(&mut buf).await?;
29 if n == 0 {
30 break;
31 }
32 hasher.update(&buf[..n]);
33 }
34 Ok(hasher.into())
35 }
36
37 pub fn validating_stream(
38 &self,
39 stream: impl TryStream<Ok = Bytes, Error = Error>,
40 ) -> impl Stream<Item = Result<Bytes, Error>> {
41 let want = self.clone();
42 stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
43 Sha256::new(),
44 move |hasher, res| {
45 ready(match res {
46 Ok(Some(bytes)) => {
47 hasher.update(&bytes);
48 Some(Ok(bytes))
49 }
50 Ok(None) => {
51 let got: Self = std::mem::take(hasher).into();
52 if got == want {
53 None
54 } else {
55 Some(Err(Error::InvalidContent(format!(
56 "expected digest {want}, got {got}"
57 ))))
58 }
59 }
60 Err(err) => Some(Err(err)),
61 })
62 },
63 )
64 }
65}
66
67impl std::fmt::Display for ContentDigest {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
71 }
72 }
73}
74
75impl From<Sha256> for ContentDigest {
76 fn from(hasher: Sha256) -> Self {
77 Self::Sha256 {
78 hex: format!("{:x}", hasher.finalize()),
79 }
80 }
81}
82
83impl<'a> TryFrom<&'a str> for ContentDigest {
84 type Error = Error;
85
86 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
87 let Some(hex) = value.strip_prefix("sha256:") else {
88 return Err(Error::InvalidContentDigest(
89 "must start with 'sha256:'".into(),
90 ));
91 };
92 let hex = hex.to_lowercase();
93 if hex.len() != 64 {
94 return Err(Error::InvalidContentDigest(format!(
95 "must be 64 hex digits; got {} chars",
96 hex.len()
97 )));
98 }
99 if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
100 return Err(Error::InvalidContentDigest(format!(
101 "must be hex; got {invalid:?}"
102 )));
103 }
104 Ok(Self::Sha256 { hex })
105 }
106}
107
108impl std::str::FromStr for ContentDigest {
109 type Err = Error;
110
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 s.try_into()
113 }
114}
115
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    #[tokio::test]
    async fn test_validating_stream() {
        let input = b"input";
        let digest = ContentDigest::from(Sha256::new_with_prefix(input));
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert_eq!(
            validating.try_collect::<BytesMut>().await.unwrap(),
            &input[..]
        );
    }

    #[tokio::test]
    async fn test_invalidating_stream() {
        let input = b"input";
        let digest = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert!(matches!(
            validating.try_collect::<BytesMut>().await,
            Err(Error::InvalidContent(_)),
        ));
    }

    #[test]
    fn test_parse_normalizes_case_and_round_trips() {
        let lower = "ab".repeat(32);
        let parsed: ContentDigest = format!("sha256:{}", lower.to_uppercase())
            .parse()
            .unwrap();
        assert_eq!(parsed, ContentDigest::Sha256 { hex: lower.clone() });
        // Display emits exactly the canonical form that parsing accepts.
        assert_eq!(parsed.to_string(), format!("sha256:{lower}"));
        assert_eq!(
            parsed.to_string().parse::<ContentDigest>().unwrap(),
            parsed
        );
    }

    #[test]
    fn test_parse_rejects_malformed_digests() {
        let cases = [
            "ab".repeat(32),                      // missing "sha256:" prefix
            format!("sha256:{}", "a".repeat(63)), // too short
            format!("sha256:{}", "a".repeat(65)), // too long
            format!("sha256:{}", "g".repeat(64)), // right length, not hex
        ];
        for input in cases {
            assert!(
                matches!(
                    input.parse::<ContentDigest>(),
                    Err(Error::InvalidContentDigest(_))
                ),
                "{input:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_hasher_digest_matches_parsed_digest() {
        // SHA-256 of the empty input (standard test vector).
        let expected: ContentDigest =
            "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
                .parse()
                .unwrap();
        assert_eq!(ContentDigest::from(Sha256::new()), expected);
    }
}
148}