wasmhub/
loader.rs

1use crate::cache::CacheManager;
2use crate::error::{Error, Result};
3use crate::manifest::{GlobalManifest, RuntimeManifest};
4use crate::runtime::{Language, Runtime};
5use reqwest::Client;
6use std::path::PathBuf;
7
8#[cfg(feature = "progress")]
9use futures_util::StreamExt;
10
/// Prefix for assets attached to GitHub releases of `anistark/wasmhub`.
const GITHUB_RELEASES_BASE: &'static str = "https://github.com/anistark/wasmhub/releases/download";
/// Prefix for files served by jsDelivr from the repository at `@latest`.
const JSDELIVR_BASE: &'static str = "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest";
13
/// A CDN from which runtime manifests and `.wasm` binaries can be fetched.
///
/// Loaders hold an ordered list of sources and move on to the next one when a
/// request against the current source fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CdnSource {
    /// Assets attached to GitHub releases of the repository.
    GitHubReleases,
    /// Files served by jsDelivr directly from the GitHub repository.
    JsDelivr,
}
19
20impl CdnSource {
21    fn base_url(&self) -> &'static str {
22        match self {
23            CdnSource::GitHubReleases => GITHUB_RELEASES_BASE,
24            CdnSource::JsDelivr => JSDELIVR_BASE,
25        }
26    }
27}
28
29pub struct RuntimeLoader {
30    cache: CacheManager,
31    client: Client,
32    cdn_sources: Vec<CdnSource>,
33    #[cfg(feature = "progress")]
34    show_progress: bool,
35}
36
37impl RuntimeLoader {
38    pub fn new() -> Result<Self> {
39        Ok(Self {
40            cache: CacheManager::new()?,
41            client: Client::new(),
42            cdn_sources: vec![CdnSource::GitHubReleases, CdnSource::JsDelivr],
43            #[cfg(feature = "progress")]
44            show_progress: false,
45        })
46    }
47
48    pub fn builder() -> RuntimeLoaderBuilder {
49        RuntimeLoaderBuilder::default()
50    }
51
52    pub async fn get_runtime(&self, language: Language, version: &str) -> Result<Runtime> {
53        if let Some(runtime) = self.cache.get(language, version) {
54            return Ok(runtime);
55        }
56
57        self.download_runtime(language, version).await
58    }
59
60    pub async fn download_runtime(&self, language: Language, version: &str) -> Result<Runtime> {
61        let manifest = self.fetch_runtime_manifest(language).await?;
62        let version_info = manifest
63            .get_version(version)
64            .ok_or_else(|| Error::VersionNotFound {
65                language: language.to_string(),
66                version: version.to_string(),
67            })?;
68
69        let mut last_error = None;
70        for source in &self.cdn_sources {
71            let url = self.build_download_url(source, language, version);
72            match self.download_from_url(&url).await {
73                Ok(data) => {
74                    let computed_hash = self.compute_hash(&data);
75                    if computed_hash != version_info.sha256 {
76                        return Err(Error::IntegrityCheckFailed {
77                            expected: version_info.sha256.clone(),
78                            actual: computed_hash,
79                        });
80                    }
81
82                    return self.cache.store(language, version, &data);
83                }
84                Err(e) => {
85                    last_error = Some(e);
86                    continue;
87                }
88            }
89        }
90
91        Err(last_error.unwrap_or_else(|| Error::Other("All CDN sources failed".to_string())))
92    }
93
94    fn build_download_url(&self, source: &CdnSource, language: Language, version: &str) -> String {
95        match source {
96            CdnSource::GitHubReleases => {
97                format!(
98                    "{}/v{}/{}-{}.wasm",
99                    source.base_url(),
100                    version,
101                    language.as_str(),
102                    version
103                )
104            }
105            CdnSource::JsDelivr => {
106                format!(
107                    "{}/runtimes/{}/{}.wasm",
108                    source.base_url(),
109                    language.as_str(),
110                    version
111                )
112            }
113        }
114    }
115
116    async fn download_from_url(&self, url: &str) -> Result<Vec<u8>> {
117        #[cfg(feature = "progress")]
118        if self.show_progress {
119            return self.download_with_progress(url).await;
120        }
121
122        let response = self.client.get(url).send().await?;
123        if !response.status().is_success() {
124            return Err(Error::Network(response.error_for_status().unwrap_err()));
125        }
126
127        let bytes = response.bytes().await?;
128        Ok(bytes.to_vec())
129    }
130
131    #[cfg(feature = "progress")]
132    async fn download_with_progress(&self, url: &str) -> Result<Vec<u8>> {
133        use indicatif::{ProgressBar, ProgressStyle};
134
135        let response = self.client.get(url).send().await?;
136        if !response.status().is_success() {
137            return Err(Error::Network(response.error_for_status().unwrap_err()));
138        }
139
140        let total_size = response.content_length().unwrap_or(0);
141        let pb = ProgressBar::new(total_size);
142        pb.set_style(
143            ProgressStyle::default_bar()
144                .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
145                .unwrap()
146                .progress_chars("#>-"),
147        );
148        pb.set_message(format!("Downloading {url}"));
149
150        let mut downloaded: u64 = 0;
151        let mut stream = response.bytes_stream();
152        let mut data = Vec::new();
153
154        while let Some(chunk) = stream.next().await {
155            let chunk = chunk?;
156            data.extend_from_slice(&chunk);
157            downloaded += chunk.len() as u64;
158            pb.set_position(downloaded);
159        }
160
161        pb.finish_with_message("Download complete");
162        Ok(data)
163    }
164
165    fn compute_hash(&self, data: &[u8]) -> String {
166        use sha2::{Digest, Sha256};
167        let mut hasher = Sha256::new();
168        hasher.update(data);
169        format!("{:x}", hasher.finalize())
170    }
171
172    pub async fn list_available(&self) -> Result<GlobalManifest> {
173        self.fetch_global_manifest().await
174    }
175
176    pub async fn get_latest_version(&self, language: Language) -> Result<String> {
177        let manifest = self.fetch_global_manifest().await?;
178        let runtime_info =
179            manifest
180                .get_language(language.as_str())
181                .ok_or_else(|| Error::ManifestNotFound {
182                    language: language.to_string(),
183                })?;
184        Ok(runtime_info.latest.clone())
185    }
186
187    pub fn clear_cache(&self, language: Language, version: &str) -> Result<()> {
188        self.cache.clear(language, version)
189    }
190
191    pub fn clear_all_cache(&self) -> Result<()> {
192        self.cache.clear_all()
193    }
194
195    pub fn list_cached(&self) -> Result<Vec<Runtime>> {
196        self.cache.list()
197    }
198
199    async fn fetch_global_manifest(&self) -> Result<GlobalManifest> {
200        let mut last_error = None;
201        for source in &self.cdn_sources {
202            let url = match source {
203                CdnSource::GitHubReleases => {
204                    format!("{}/latest/manifest.json", source.base_url())
205                }
206                CdnSource::JsDelivr => {
207                    format!("{}/manifest.json", source.base_url())
208                }
209            };
210
211            match self.fetch_json(&url).await {
212                Ok(manifest) => return Ok(manifest),
213                Err(e) => {
214                    last_error = Some(e);
215                    continue;
216                }
217            }
218        }
219
220        Err(last_error.unwrap_or_else(|| Error::Other("Failed to fetch manifest".to_string())))
221    }
222
223    pub async fn fetch_runtime_manifest(&self, language: Language) -> Result<RuntimeManifest> {
224        let mut last_error = None;
225        for source in &self.cdn_sources {
226            let url = match source {
227                CdnSource::GitHubReleases => {
228                    format!(
229                        "{}/latest/runtimes/{}/manifest.json",
230                        source.base_url(),
231                        language.as_str()
232                    )
233                }
234                CdnSource::JsDelivr => {
235                    format!(
236                        "{}/runtimes/{}/manifest.json",
237                        source.base_url(),
238                        language.as_str()
239                    )
240                }
241            };
242
243            match self.fetch_json(&url).await {
244                Ok(manifest) => return Ok(manifest),
245                Err(e) => {
246                    last_error = Some(e);
247                    continue;
248                }
249            }
250        }
251
252        Err(last_error.unwrap_or_else(|| Error::ManifestNotFound {
253            language: language.to_string(),
254        }))
255    }
256
257    async fn fetch_json<T: serde::de::DeserializeOwned>(&self, url: &str) -> Result<T> {
258        let response = self.client.get(url).send().await?;
259        if !response.status().is_success() {
260            return Err(Error::Network(response.error_for_status().unwrap_err()));
261        }
262        let json = response.json().await?;
263        Ok(json)
264    }
265}
266
267impl Default for RuntimeLoader {
268    fn default() -> Self {
269        Self::new().expect("Failed to create RuntimeLoader")
270    }
271}
272
273#[derive(Default)]
274pub struct RuntimeLoaderBuilder {
275    cache_dir: Option<PathBuf>,
276    cdn_sources: Option<Vec<CdnSource>>,
277    #[cfg(feature = "progress")]
278    show_progress: bool,
279}
280
281impl RuntimeLoaderBuilder {
282    pub fn new() -> Self {
283        Self::default()
284    }
285
286    pub fn cache_dir(mut self, path: PathBuf) -> Self {
287        self.cache_dir = Some(path);
288        self
289    }
290
291    pub fn cdn_sources(mut self, sources: Vec<CdnSource>) -> Self {
292        self.cdn_sources = Some(sources);
293        self
294    }
295
296    #[cfg(feature = "progress")]
297    pub fn show_progress(mut self, show: bool) -> Self {
298        self.show_progress = show;
299        self
300    }
301
302    pub fn build(self) -> Result<RuntimeLoader> {
303        let cache = if let Some(cache_dir) = self.cache_dir {
304            CacheManager::with_cache_dir(cache_dir)
305        } else {
306            CacheManager::new()?
307        };
308
309        Ok(RuntimeLoader {
310            cache,
311            client: Client::new(),
312            cdn_sources: self
313                .cdn_sources
314                .unwrap_or_else(|| vec![CdnSource::GitHubReleases, CdnSource::JsDelivr]),
315            #[cfg(feature = "progress")]
316            show_progress: self.show_progress,
317        })
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cdn_source_base_url() {
        assert_eq!(
            CdnSource::GitHubReleases.base_url(),
            "https://github.com/anistark/wasmhub/releases/download"
        );
        assert_eq!(
            CdnSource::JsDelivr.base_url(),
            "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest"
        );
    }

    #[test]
    fn test_build_download_url() {
        let loader = RuntimeLoader::new().unwrap();

        // Pin the full URLs: a `contains` check would accept a wrong prefix or tag.
        assert_eq!(
            loader.build_download_url(&CdnSource::GitHubReleases, Language::Python, "3.11.7"),
            "https://github.com/anistark/wasmhub/releases/download/v3.11.7/python-3.11.7.wasm"
        );
        assert_eq!(
            loader.build_download_url(&CdnSource::JsDelivr, Language::Python, "3.11.7"),
            "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest/runtimes/python/3.11.7.wasm"
        );
    }

    #[test]
    fn test_compute_hash() {
        let loader = RuntimeLoader::new().unwrap();

        // Well-known SHA-256 test vectors (empty input and "abc").
        assert_eq!(
            loader.compute_hash(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            loader.compute_hash(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );

        // Output is always 64 lowercase hex characters.
        let hash = loader.compute_hash(b"test data");
        assert_eq!(hash.len(), 64);
        assert!(hash
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }

    #[test]
    fn test_builder() {
        let loader = RuntimeLoader::builder()
            .cdn_sources(vec![CdnSource::GitHubReleases])
            .build()
            .unwrap();

        assert_eq!(loader.cdn_sources.len(), 1);
        assert!(matches!(loader.cdn_sources[0], CdnSource::GitHubReleases));
    }

    #[test]
    fn test_builder_default_cdn_order() {
        let loader = RuntimeLoader::builder().build().unwrap();

        assert_eq!(loader.cdn_sources.len(), 2);
        assert!(matches!(loader.cdn_sources[0], CdnSource::GitHubReleases));
        assert!(matches!(loader.cdn_sources[1], CdnSource::JsDelivr));
    }

    #[test]
    fn test_builder_with_cache_dir() {
        use tempfile::TempDir;
        let temp_dir = TempDir::new().unwrap();

        let loader = RuntimeLoader::builder()
            .cache_dir(temp_dir.path().to_path_buf())
            .build()
            .unwrap();

        assert!(loader
            .cache
            .get_path(Language::Python, "3.11.7")
            .starts_with(temp_dir.path()));
    }
}