wasm_pkg_loader/
release.rs

1use std::path::Path;
2
3use bytes::Bytes;
4use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
5use sha2::{Digest, Sha256};
6use tokio::io::AsyncReadExt;
7use wasm_pkg_common::package::Version;
8
9use crate::Error;
10
11#[derive(Clone, Debug)]
12pub struct Release {
13    pub version: Version,
14    pub content_digest: ContentDigest,
15}
16
/// A content digest, rendered and parsed in `<algorithm>:<hex>` form (e.g. `sha256:…`).
///
/// Equality (and hashing) compares the `hex` string verbatim, so it is case-sensitive.
/// Digests produced by this module's parsers and by `From<Sha256>` always hold lowercase
/// hex, so they compare consistently with each other.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ContentDigest {
    /// A SHA-256 digest.
    Sha256 {
        /// Hex encoding of the 32-byte digest.
        hex: String,
    },
}
21
22impl ContentDigest {
23    pub async fn sha256_from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
24        let mut file = tokio::fs::File::open(path).await?;
25        let mut hasher = Sha256::new();
26        let mut buf = [0; 4096];
27        loop {
28            let n = file.read(&mut buf).await?;
29            if n == 0 {
30                break;
31            }
32            hasher.update(&buf[..n]);
33        }
34        Ok(hasher.into())
35    }
36
37    pub fn validating_stream(
38        &self,
39        stream: impl TryStream<Ok = Bytes, Error = Error>,
40    ) -> impl Stream<Item = Result<Bytes, Error>> {
41        let want = self.clone();
42        stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
43            Sha256::new(),
44            move |hasher, res| {
45                ready(match res {
46                    Ok(Some(bytes)) => {
47                        hasher.update(&bytes);
48                        Some(Ok(bytes))
49                    }
50                    Ok(None) => {
51                        let got: Self = std::mem::take(hasher).into();
52                        if got == want {
53                            None
54                        } else {
55                            Some(Err(Error::InvalidContent(format!(
56                                "expected digest {want}, got {got}"
57                            ))))
58                        }
59                    }
60                    Err(err) => Some(Err(err)),
61                })
62            },
63        )
64    }
65}
66
67impl std::fmt::Display for ContentDigest {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
71        }
72    }
73}
74
75impl From<Sha256> for ContentDigest {
76    fn from(hasher: Sha256) -> Self {
77        Self::Sha256 {
78            hex: format!("{:x}", hasher.finalize()),
79        }
80    }
81}
82
83impl<'a> TryFrom<&'a str> for ContentDigest {
84    type Error = Error;
85
86    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
87        let Some(hex) = value.strip_prefix("sha256:") else {
88            return Err(Error::InvalidContentDigest(
89                "must start with 'sha256:'".into(),
90            ));
91        };
92        let hex = hex.to_lowercase();
93        if hex.len() != 64 {
94            return Err(Error::InvalidContentDigest(format!(
95                "must be 64 hex digits; got {} chars",
96                hex.len()
97            )));
98        }
99        if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
100            return Err(Error::InvalidContentDigest(format!(
101                "must be hex; got {invalid:?}"
102            )));
103        }
104        Ok(Self::Sha256 { hex })
105    }
106}
107
108impl std::str::FromStr for ContentDigest {
109    type Err = Error;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        s.try_into()
113    }
114}
115
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    #[tokio::test]
    async fn test_validating_stream() {
        let input = b"input";
        let digest = ContentDigest::from(Sha256::new_with_prefix(input));
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert_eq!(
            validating.try_collect::<BytesMut>().await.unwrap(),
            &input[..]
        );
    }

    #[tokio::test]
    async fn test_invalidating_stream() {
        let input = b"input";
        let digest = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert!(matches!(
            validating.try_collect::<BytesMut>().await,
            Err(Error::InvalidContent(_)),
        ));
    }

    #[tokio::test]
    async fn test_validating_empty_stream() {
        // The digest of zero bytes must validate an empty stream.
        let digest = ContentDigest::from(Sha256::new());
        let validating = digest.validating_stream(stream::empty::<Result<Bytes, Error>>());
        assert!(validating
            .try_collect::<BytesMut>()
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_sha256_from_file() {
        // Large enough that the read loop needs several iterations.
        let content = vec![0xA5u8; 200_000];
        let path = std::env::temp_dir().join(format!(
            "wasm-pkg-loader-release-test-{}.bin",
            std::process::id()
        ));
        tokio::fs::write(&path, &content).await.unwrap();
        let got = ContentDigest::sha256_from_file(&path).await;
        let _ = tokio::fs::remove_file(&path).await;
        assert_eq!(
            got.unwrap(),
            ContentDigest::from(Sha256::new_with_prefix(&content))
        );
    }

    #[test]
    fn test_parse_normalizes_case_and_round_trips() {
        let upper = format!("sha256:{}", "AB".repeat(32));
        let digest: ContentDigest = upper.parse().unwrap();
        assert_eq!(digest, ContentDigest::Sha256 { hex: "ab".repeat(32) });
        let rendered = digest.to_string();
        assert_eq!(rendered, upper.to_lowercase());
        assert_eq!(rendered.parse::<ContentDigest>().unwrap(), digest);
    }

    #[test]
    fn test_parse_rejects_malformed() {
        let non_hex = format!("sha256:{}", "g".repeat(64));
        let too_long = format!("sha256:{}", "a".repeat(65));
        for bad in [
            "",
            "sha256",
            "md5:00",
            "sha256:",
            "sha256:abc",
            non_hex.as_str(),
            too_long.as_str(),
        ] {
            assert!(
                matches!(
                    bad.parse::<ContentDigest>(),
                    Err(Error::InvalidContentDigest(_))
                ),
                "{bad:?} should be rejected"
            );
        }
    }
}