1use crate::cache::CacheManager;
2use crate::error::{Error, Result};
3use crate::manifest::{GlobalManifest, RuntimeManifest};
4use crate::runtime::{Language, Runtime};
5use reqwest::Client;
6use std::path::PathBuf;
7
8#[cfg(feature = "progress")]
9use futures_util::StreamExt;
10
/// Base URL under which GitHub release assets are addressed (`<base>/<tag>/<asset>`).
const GITHUB_RELEASES_BASE: &str = "https://github.com/anistark/wasmhub/releases/download";
/// Base URL of the jsDelivr mirror of the repository, addressed with the `@latest` ref.
const JSDELIVR_BASE: &str = "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest";

/// A remote location from which manifests and runtime binaries can be fetched.
///
/// The enum is a plain tag without payload, so it is `Copy` and can be compared,
/// de-duplicated, or used as a map key by callers that assemble their own source lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CdnSource {
    /// Assets attached to GitHub releases of the repository.
    GitHubReleases,
    /// The repository contents served through the jsDelivr CDN.
    JsDelivr,
}

impl CdnSource {
    /// Returns the URL prefix (without a trailing slash) for this source.
    fn base_url(&self) -> &'static str {
        match self {
            CdnSource::GitHubReleases => GITHUB_RELEASES_BASE,
            CdnSource::JsDelivr => JSDELIVR_BASE,
        }
    }
}
28
/// Resolves runtimes by consulting a local cache first and otherwise downloading
/// them from a list of CDN sources.
pub struct RuntimeLoader {
    /// Local cache that is queried before downloading and that stores verified downloads.
    cache: CacheManager,
    /// HTTP client shared by all manifest and binary requests.
    client: Client,
    /// Sources to try, in order, until one of them succeeds.
    cdn_sources: Vec<CdnSource>,
    /// When `true`, binary downloads are routed through the progress-bar code path.
    #[cfg(feature = "progress")]
    show_progress: bool,
}
36
37impl RuntimeLoader {
38 pub fn new() -> Result<Self> {
39 Ok(Self {
40 cache: CacheManager::new()?,
41 client: Client::new(),
42 cdn_sources: vec![CdnSource::GitHubReleases, CdnSource::JsDelivr],
43 #[cfg(feature = "progress")]
44 show_progress: false,
45 })
46 }
47
48 pub fn builder() -> RuntimeLoaderBuilder {
49 RuntimeLoaderBuilder::default()
50 }
51
52 pub async fn get_runtime(&self, language: Language, version: &str) -> Result<Runtime> {
53 if let Some(runtime) = self.cache.get(language, version) {
54 return Ok(runtime);
55 }
56
57 self.download_runtime(language, version).await
58 }
59
60 pub async fn download_runtime(&self, language: Language, version: &str) -> Result<Runtime> {
61 let manifest = self.fetch_runtime_manifest(language).await?;
62 let version_info = manifest
63 .get_version(version)
64 .ok_or_else(|| Error::VersionNotFound {
65 language: language.to_string(),
66 version: version.to_string(),
67 })?;
68
69 let mut last_error = None;
70 for source in &self.cdn_sources {
71 let url = self.build_download_url(source, language, version);
72 match self.download_from_url(&url).await {
73 Ok(data) => {
74 let computed_hash = self.compute_hash(&data);
75 if computed_hash != version_info.sha256 {
76 return Err(Error::IntegrityCheckFailed {
77 expected: version_info.sha256.clone(),
78 actual: computed_hash,
79 });
80 }
81
82 return self.cache.store(language, version, &data);
83 }
84 Err(e) => {
85 last_error = Some(e);
86 continue;
87 }
88 }
89 }
90
91 Err(last_error.unwrap_or_else(|| Error::Other("All CDN sources failed".to_string())))
92 }
93
94 fn build_download_url(&self, source: &CdnSource, language: Language, version: &str) -> String {
95 match source {
96 CdnSource::GitHubReleases => {
97 format!(
98 "{}/v{}/{}-{}.wasm",
99 source.base_url(),
100 version,
101 language.as_str(),
102 version
103 )
104 }
105 CdnSource::JsDelivr => {
106 format!(
107 "{}/runtimes/{}/{}.wasm",
108 source.base_url(),
109 language.as_str(),
110 version
111 )
112 }
113 }
114 }
115
116 async fn download_from_url(&self, url: &str) -> Result<Vec<u8>> {
117 #[cfg(feature = "progress")]
118 if self.show_progress {
119 return self.download_with_progress(url).await;
120 }
121
122 let response = self.client.get(url).send().await?;
123 if !response.status().is_success() {
124 return Err(Error::Network(response.error_for_status().unwrap_err()));
125 }
126
127 let bytes = response.bytes().await?;
128 Ok(bytes.to_vec())
129 }
130
131 #[cfg(feature = "progress")]
132 async fn download_with_progress(&self, url: &str) -> Result<Vec<u8>> {
133 use indicatif::{ProgressBar, ProgressStyle};
134
135 let response = self.client.get(url).send().await?;
136 if !response.status().is_success() {
137 return Err(Error::Network(response.error_for_status().unwrap_err()));
138 }
139
140 let total_size = response.content_length().unwrap_or(0);
141 let pb = ProgressBar::new(total_size);
142 pb.set_style(
143 ProgressStyle::default_bar()
144 .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
145 .unwrap()
146 .progress_chars("#>-"),
147 );
148 pb.set_message(format!("Downloading {url}"));
149
150 let mut downloaded: u64 = 0;
151 let mut stream = response.bytes_stream();
152 let mut data = Vec::new();
153
154 while let Some(chunk) = stream.next().await {
155 let chunk = chunk?;
156 data.extend_from_slice(&chunk);
157 downloaded += chunk.len() as u64;
158 pb.set_position(downloaded);
159 }
160
161 pb.finish_with_message("Download complete");
162 Ok(data)
163 }
164
165 fn compute_hash(&self, data: &[u8]) -> String {
166 use sha2::{Digest, Sha256};
167 let mut hasher = Sha256::new();
168 hasher.update(data);
169 format!("{:x}", hasher.finalize())
170 }
171
172 pub async fn list_available(&self) -> Result<GlobalManifest> {
173 self.fetch_global_manifest().await
174 }
175
176 pub async fn get_latest_version(&self, language: Language) -> Result<String> {
177 let manifest = self.fetch_global_manifest().await?;
178 let runtime_info =
179 manifest
180 .get_language(language.as_str())
181 .ok_or_else(|| Error::ManifestNotFound {
182 language: language.to_string(),
183 })?;
184 Ok(runtime_info.latest.clone())
185 }
186
187 pub fn clear_cache(&self, language: Language, version: &str) -> Result<()> {
188 self.cache.clear(language, version)
189 }
190
191 pub fn clear_all_cache(&self) -> Result<()> {
192 self.cache.clear_all()
193 }
194
195 pub fn list_cached(&self) -> Result<Vec<Runtime>> {
196 self.cache.list()
197 }
198
199 async fn fetch_global_manifest(&self) -> Result<GlobalManifest> {
200 let mut last_error = None;
201 for source in &self.cdn_sources {
202 let url = match source {
203 CdnSource::GitHubReleases => {
204 format!("{}/latest/manifest.json", source.base_url())
205 }
206 CdnSource::JsDelivr => {
207 format!("{}/manifest.json", source.base_url())
208 }
209 };
210
211 match self.fetch_json(&url).await {
212 Ok(manifest) => return Ok(manifest),
213 Err(e) => {
214 last_error = Some(e);
215 continue;
216 }
217 }
218 }
219
220 Err(last_error.unwrap_or_else(|| Error::Other("Failed to fetch manifest".to_string())))
221 }
222
223 pub async fn fetch_runtime_manifest(&self, language: Language) -> Result<RuntimeManifest> {
224 let mut last_error = None;
225 for source in &self.cdn_sources {
226 let url = match source {
227 CdnSource::GitHubReleases => {
228 format!(
229 "{}/latest/runtimes/{}/manifest.json",
230 source.base_url(),
231 language.as_str()
232 )
233 }
234 CdnSource::JsDelivr => {
235 format!(
236 "{}/runtimes/{}/manifest.json",
237 source.base_url(),
238 language.as_str()
239 )
240 }
241 };
242
243 match self.fetch_json(&url).await {
244 Ok(manifest) => return Ok(manifest),
245 Err(e) => {
246 last_error = Some(e);
247 continue;
248 }
249 }
250 }
251
252 Err(last_error.unwrap_or_else(|| Error::ManifestNotFound {
253 language: language.to_string(),
254 }))
255 }
256
257 async fn fetch_json<T: serde::de::DeserializeOwned>(&self, url: &str) -> Result<T> {
258 let response = self.client.get(url).send().await?;
259 if !response.status().is_success() {
260 return Err(Error::Network(response.error_for_status().unwrap_err()));
261 }
262 let json = response.json().await?;
263 Ok(json)
264 }
265}
266
267impl Default for RuntimeLoader {
268 fn default() -> Self {
269 Self::new().expect("Failed to create RuntimeLoader")
270 }
271}
272
273#[derive(Default)]
274pub struct RuntimeLoaderBuilder {
275 cache_dir: Option<PathBuf>,
276 cdn_sources: Option<Vec<CdnSource>>,
277 #[cfg(feature = "progress")]
278 show_progress: bool,
279}
280
281impl RuntimeLoaderBuilder {
282 pub fn new() -> Self {
283 Self::default()
284 }
285
286 pub fn cache_dir(mut self, path: PathBuf) -> Self {
287 self.cache_dir = Some(path);
288 self
289 }
290
291 pub fn cdn_sources(mut self, sources: Vec<CdnSource>) -> Self {
292 self.cdn_sources = Some(sources);
293 self
294 }
295
296 #[cfg(feature = "progress")]
297 pub fn show_progress(mut self, show: bool) -> Self {
298 self.show_progress = show;
299 self
300 }
301
302 pub fn build(self) -> Result<RuntimeLoader> {
303 let cache = if let Some(cache_dir) = self.cache_dir {
304 CacheManager::with_cache_dir(cache_dir)
305 } else {
306 CacheManager::new()?
307 };
308
309 Ok(RuntimeLoader {
310 cache,
311 client: Client::new(),
312 cdn_sources: self
313 .cdn_sources
314 .unwrap_or_else(|| vec![CdnSource::GitHubReleases, CdnSource::JsDelivr]),
315 #[cfg(feature = "progress")]
316 show_progress: self.show_progress,
317 })
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Each source maps to its fixed base URL.
    #[test]
    fn test_cdn_source_base_url() {
        assert_eq!(
            CdnSource::GitHubReleases.base_url(),
            "https://github.com/anistark/wasmhub/releases/download"
        );
        assert_eq!(
            CdnSource::JsDelivr.base_url(),
            "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest"
        );
    }

    /// Download URLs are pinned exactly, so a change to either layout is caught.
    #[test]
    fn test_build_download_url() {
        let loader = RuntimeLoader::new().unwrap();

        let url = loader.build_download_url(&CdnSource::GitHubReleases, Language::Python, "3.11.7");
        assert_eq!(
            url,
            "https://github.com/anistark/wasmhub/releases/download/v3.11.7/python-3.11.7.wasm"
        );

        let url = loader.build_download_url(&CdnSource::JsDelivr, Language::Python, "3.11.7");
        assert_eq!(
            url,
            "https://cdn.jsdelivr.net/gh/anistark/wasmhub@latest/runtimes/python/3.11.7.wasm"
        );
    }

    /// The digest must be real SHA-256 in lowercase hex, not merely 64 characters long.
    #[test]
    fn test_compute_hash() {
        let loader = RuntimeLoader::new().unwrap();

        // Standard SHA-256 test vectors.
        assert_eq!(
            loader.compute_hash(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            loader.compute_hash(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );

        let hash = loader.compute_hash(b"test data");
        assert_eq!(hash.len(), 64);
        assert!(hash
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    /// Without overrides the loader tries GitHub releases first, then jsDelivr.
    #[test]
    fn test_default_cdn_sources() {
        let loader = RuntimeLoader::new().unwrap();
        assert!(matches!(
            loader.cdn_sources.as_slice(),
            [CdnSource::GitHubReleases, CdnSource::JsDelivr]
        ));
    }

    /// Sources passed to the builder replace the defaults.
    #[test]
    fn test_builder() {
        let loader = RuntimeLoader::builder()
            .cdn_sources(vec![CdnSource::GitHubReleases])
            .build()
            .unwrap();

        assert_eq!(loader.cdn_sources.len(), 1);
        assert!(matches!(
            loader.cdn_sources.as_slice(),
            [CdnSource::GitHubReleases]
        ));
    }

    /// A custom cache directory is honoured by the cache manager.
    #[test]
    fn test_builder_with_cache_dir() {
        use tempfile::TempDir;
        let temp_dir = TempDir::new().unwrap();

        let loader = RuntimeLoader::builder()
            .cache_dir(temp_dir.path().to_path_buf())
            .build()
            .unwrap();

        assert!(loader
            .cache
            .get_path(Language::Python, "3.11.7")
            .starts_with(temp_dir.path()));
    }
}