1use crate::error::{Error, Result};
2use crate::manifest::HotswapManifest;
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use url::Url;
7
/// Per-check inputs handed to a [`HotswapResolver`].
///
/// Built by the caller for every update check; resolvers read from it but do
/// not keep references to it (the returned future is `'static`).
#[derive(Debug, Clone)]
pub struct CheckContext {
    /// Sequence number of the currently active bundle.
    ///
    /// `HttpResolver` substitutes it for `{{current_sequence}}` in its endpoint;
    /// `StaticFileResolver` only reports manifests whose `sequence` is strictly
    /// greater than this value.
    pub current_sequence: u64,
    /// Version of the host binary, sent by `HttpResolver` as the
    /// `binary_version` query parameter.
    pub binary_version: String,
    /// Platform identifier, sent by `HttpResolver` as the `platform` query
    /// parameter (tests use values such as `"macos"`, `"linux"`, `"windows"`).
    pub platform: &'static str,
    /// CPU architecture, sent by `HttpResolver` as the `arch` query parameter
    /// (tests use values such as `"aarch64"`, `"x86_64"`).
    pub arch: &'static str,
    /// Optional release channel; when `Some`, `HttpResolver` appends it as the
    /// `channel` query parameter.
    pub channel: Option<String>,
    /// Extra request headers for this check. `HttpResolver` merges them over
    /// its own configured headers, with these taking precedence.
    pub headers: HashMap<String, String>,
    /// When `Some`, `HttpResolver` uses this URL template instead of its
    /// configured endpoint (placeholder substitution still applies).
    pub endpoint_override: Option<String>,
}
30
/// Strategy for discovering whether a newer hotswap bundle is available.
///
/// The `Send + Sync + 'static` bounds allow a resolver to be stored and
/// shared across threads and async tasks.
pub trait HotswapResolver: Send + Sync + 'static {
    /// Checks for an update using the inputs in `ctx`.
    ///
    /// The returned future resolves to:
    /// - `Ok(Some(manifest))` when an update is available,
    /// - `Ok(None)` when there is nothing newer,
    /// - `Err(_)` when the check itself failed.
    ///
    /// The future is boxed, `Send`, and `'static`, so it cannot borrow `self`
    /// or `ctx`; implementations copy or clone what they need before building it.
    fn check(
        &self,
        ctx: &CheckContext,
    ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>>;
}
46
47pub struct HttpResolver {
59 endpoint: String,
60 client: reqwest::Client,
61 headers: HashMap<String, String>,
62}
63
64impl HttpResolver {
65 pub fn new(endpoint: impl Into<String>) -> Self {
67 Self {
68 endpoint: endpoint.into(),
69 client: reqwest::Client::new(),
70 headers: HashMap::new(),
71 }
72 }
73
74 pub fn with_client(endpoint: impl Into<String>, client: reqwest::Client) -> Self {
76 Self {
77 endpoint: endpoint.into(),
78 client,
79 headers: HashMap::new(),
80 }
81 }
82
83 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
85 self.headers = headers;
86 self
87 }
88
89 pub fn endpoint(&self) -> &str {
91 &self.endpoint
92 }
93}
94
95impl HotswapResolver for HttpResolver {
96 fn check(
97 &self,
98 ctx: &CheckContext,
99 ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>> {
100 let base = ctx.endpoint_override.as_deref().unwrap_or(&self.endpoint);
101 let raw = base.replace("{{current_sequence}}", &ctx.current_sequence.to_string());
102 let mut parsed =
103 Url::parse(&raw).map_err(|e| Error::Config(format!("invalid endpoint URL: {}", e)));
104 if let Ok(ref mut u) = parsed {
105 u.query_pairs_mut()
106 .append_pair("binary_version", &ctx.binary_version)
107 .append_pair("platform", ctx.platform)
108 .append_pair("arch", ctx.arch);
109 if let Some(ref channel) = ctx.channel {
110 u.query_pairs_mut().append_pair("channel", channel);
111 }
112 }
113 let url = parsed;
114 let client = self.client.clone();
115 let mut headers = self.headers.clone();
117 headers.extend(ctx.headers.clone());
118
119 Box::pin(async move {
120 let url = url?;
121 log::info!("[hotswap] Checking for update at: {}", url);
122
123 let mut req = client
124 .get(url.as_str())
125 .timeout(std::time::Duration::from_secs(15));
126
127 for (key, value) in &headers {
128 req = req.header(key.as_str(), value.as_str());
129 }
130
131 let response = req
132 .send()
133 .await
134 .map_err(|e| Error::Network(e.to_string()))?;
135
136 if response.status().as_u16() == 204 {
137 log::info!("[hotswap] No update available (204)");
138 return Ok(None);
139 }
140
141 if !response.status().is_success() {
142 return Err(Error::Http {
143 status: response.status().as_u16(),
144 message: "update check failed".into(),
145 });
146 }
147
148 let manifest: HotswapManifest = response
149 .json()
150 .await
151 .map_err(|e| Error::InvalidManifest(e.to_string()))?;
152
153 Ok(Some(manifest))
154 })
155 }
156}
157
158pub struct StaticFileResolver {
164 source: String,
165 client: reqwest::Client,
166}
167
168impl StaticFileResolver {
169 pub fn new(source: impl Into<String>) -> Self {
171 Self {
172 source: source.into(),
173 client: reqwest::Client::new(),
174 }
175 }
176
177 pub fn with_client(source: impl Into<String>, client: reqwest::Client) -> Self {
179 Self {
180 source: source.into(),
181 client,
182 }
183 }
184
185 pub fn source(&self) -> &str {
187 &self.source
188 }
189}
190
191impl HotswapResolver for StaticFileResolver {
192 fn check(
193 &self,
194 ctx: &CheckContext,
195 ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>> {
196 let source = self.source.clone();
197 let client = self.client.clone();
198 let current_sequence = ctx.current_sequence;
199
200 Box::pin(async move {
201 let content = if source.starts_with("http://") || source.starts_with("https://") {
202 client
203 .get(&source)
204 .timeout(std::time::Duration::from_secs(15))
205 .send()
206 .await
207 .map_err(|e| Error::Network(e.to_string()))?
208 .text()
209 .await
210 .map_err(|e| Error::Network(e.to_string()))?
211 } else {
212 tokio::fs::read_to_string(&source)
213 .await
214 .map_err(Error::Io)?
215 };
216
217 let manifest: HotswapManifest = serde_json::from_str(&content)
218 .map_err(|e| Error::InvalidManifest(e.to_string()))?;
219
220 if manifest.sequence <= current_sequence {
221 return Ok(None);
222 }
223
224 Ok(Some(manifest))
225 })
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Context for a macOS/aarch64 host on binary `1.0.0` at sequence `seq`,
    /// with no channel, extra headers, or endpoint override.
    fn test_ctx(seq: u64) -> CheckContext {
        CheckContext {
            current_sequence: seq,
            binary_version: "1.0.0".into(),
            platform: "macos",
            arch: "aarch64",
            channel: None,
            headers: HashMap::new(),
            endpoint_override: None,
        }
    }

    /// A fully-populated manifest JSON document for `sequence`.
    fn sample_manifest_json(sequence: u64) -> String {
        serde_json::json!({
            "version": format!("1.0.0-ota.{}", sequence),
            "sequence": sequence,
            "url": "https://cdn.example.com/ota/bundle.tar.gz",
            "signature": "untrusted comment: test\nRUTl2E==",
            "min_binary_version": "1.0.0",
            "notes": "Test release",
            "pub_date": "2026-04-05T00:00:00Z"
        })
        .to_string()
    }

    /// Writes `contents` to `latest.json` inside `tmp` and returns a resolver
    /// that reads it. `tmp` must stay alive while the resolver is used, since
    /// dropping it deletes the directory.
    fn resolver_for(tmp: &TempDir, contents: impl AsRef<[u8]>) -> StaticFileResolver {
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, contents).unwrap();
        StaticFileResolver::new(manifest_path.to_string_lossy().to_string())
    }

    #[tokio::test]
    async fn test_static_file_resolver_update_available() {
        let tmp = TempDir::new().unwrap();
        let resolver = resolver_for(&tmp, sample_manifest_json(5));

        let manifest = resolver
            .check(&test_ctx(3))
            .await
            .unwrap()
            .expect("a newer sequence should be offered as an update");
        assert_eq!(manifest.sequence, 5);
        assert_eq!(manifest.version, "1.0.0-ota.5");
    }

    #[tokio::test]
    async fn test_static_file_resolver_no_update_same_sequence() {
        let tmp = TempDir::new().unwrap();
        let resolver = resolver_for(&tmp, sample_manifest_json(5));
        assert!(resolver.check(&test_ctx(5)).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_static_file_resolver_no_update_higher_sequence() {
        let tmp = TempDir::new().unwrap();
        let resolver = resolver_for(&tmp, sample_manifest_json(3));
        assert!(resolver.check(&test_ctx(10)).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_static_file_resolver_missing_file() {
        let resolver = StaticFileResolver::new("/nonexistent/path/manifest.json");
        let result = resolver.check(&test_ctx(0)).await;
        assert!(matches!(result, Err(Error::Io(_))));
    }

    #[tokio::test]
    async fn test_static_file_resolver_invalid_json() {
        let tmp = TempDir::new().unwrap();
        let resolver = resolver_for(&tmp, "not json at all");
        let result = resolver.check(&test_ctx(0)).await;
        assert!(matches!(result, Err(Error::InvalidManifest(_))));
    }

    #[tokio::test]
    async fn test_static_file_resolver_minimal_manifest() {
        let tmp = TempDir::new().unwrap();
        let json = serde_json::json!({
            "version": "2.0.0",
            "sequence": 1,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0"
        })
        .to_string();
        let resolver = resolver_for(&tmp, json);

        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert_eq!(manifest.version, "2.0.0");
        assert!(manifest.notes.is_none());
        assert!(manifest.mandatory.is_none());
        assert!(manifest.bundle_size.is_none());
    }

    #[test]
    fn test_http_resolver_url_substitution() {
        let resolver = HttpResolver::new("https://example.com/ota/{{current_sequence}}");
        let url = resolver
            .endpoint()
            .replace("{{current_sequence}}", &42.to_string());
        assert_eq!(url, "https://example.com/ota/42");
    }

    #[tokio::test]
    async fn test_manifest_with_mandatory_and_size() {
        let tmp = TempDir::new().unwrap();
        let json = serde_json::json!({
            "version": "1.0.1",
            "sequence": 5,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0",
            "mandatory": true,
            "bundle_size": 5242880
        })
        .to_string();
        let resolver = resolver_for(&tmp, json);

        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert_eq!(manifest.mandatory, Some(true));
        assert_eq!(manifest.bundle_size, Some(5242880));
    }

    #[tokio::test]
    async fn test_manifest_without_optional_fields_defaults_to_none() {
        let tmp = TempDir::new().unwrap();
        let json = serde_json::json!({
            "version": "1.0.1",
            "sequence": 1,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0"
        })
        .to_string();
        let resolver = resolver_for(&tmp, json);

        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert!(manifest.mandatory.is_none());
        assert!(manifest.bundle_size.is_none());
        assert!(manifest.notes.is_none());
        assert!(manifest.pub_date.is_none());
    }

    #[test]
    fn test_check_context_with_channel() {
        let ctx = CheckContext {
            current_sequence: 5,
            binary_version: "1.0.0".into(),
            platform: "linux",
            arch: "x86_64",
            channel: Some("beta".into()),
            headers: HashMap::new(),
            endpoint_override: None,
        };
        assert_eq!(ctx.channel, Some("beta".to_string()));
        assert_eq!(ctx.platform, "linux");
        assert_eq!(ctx.arch, "x86_64");
    }

    #[test]
    fn test_http_resolver_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer token123".into());
        let resolver = HttpResolver::new("https://example.com/ota").with_headers(headers);
        assert_eq!(resolver.endpoint(), "https://example.com/ota");
    }

    #[test]
    fn test_http_resolver_url_sequence_zero() {
        let endpoint = "https://example.com/ota/{{current_sequence}}/check";
        let replaced = endpoint.replace("{{current_sequence}}", &0u64.to_string());
        assert_eq!(replaced, "https://example.com/ota/0/check");
    }

    #[test]
    fn test_http_resolver_url_sequence_large() {
        let endpoint = "https://example.com/ota/{{current_sequence}}";
        let replaced = endpoint.replace("{{current_sequence}}", &999999u64.to_string());
        assert_eq!(replaced, "https://example.com/ota/999999");
    }

    #[test]
    fn test_http_resolver_endpoint_override_used() {
        let resolver = HttpResolver::new("https://original.example.com/ota");
        let ctx = CheckContext {
            current_sequence: 10,
            binary_version: "2.0.0".into(),
            platform: "windows",
            arch: "x86_64",
            channel: None,
            headers: HashMap::new(),
            endpoint_override: Some("https://override.example.com/v2/{{current_sequence}}".into()),
        };
        let base = ctx
            .endpoint_override
            .as_deref()
            .unwrap_or(resolver.endpoint());
        let raw = base.replace("{{current_sequence}}", &ctx.current_sequence.to_string());
        assert!(raw.contains("override.example.com"));
        assert!(raw.contains("10"));
        assert!(!raw.contains("original.example.com"));
    }

    #[test]
    fn test_http_resolver_runtime_headers_override_init() {
        let mut init_headers = HashMap::new();
        init_headers.insert("Authorization".into(), "Bearer old".into());
        init_headers.insert("X-Keep".into(), "kept".into());
        let resolver = HttpResolver::new("https://example.com/ota").with_headers(init_headers);

        let mut runtime_headers = HashMap::new();
        runtime_headers.insert("Authorization".into(), "Bearer new".into());

        let mut merged = resolver.headers.clone();
        merged.extend(runtime_headers);

        assert_eq!(merged.get("Authorization").unwrap(), "Bearer new");
        assert_eq!(merged.get("X-Keep").unwrap(), "kept");
    }

    /// An unparsable endpoint must fail with `Error::Config` before any
    /// network I/O is attempted.
    #[tokio::test]
    async fn test_http_resolver_invalid_endpoint_is_config_error() {
        let resolver = HttpResolver::new("not a url");
        let result = resolver.check(&test_ctx(0)).await;
        assert!(matches!(result, Err(Error::Config(_))));
    }

    /// The override, not the configured endpoint, is the URL that gets parsed.
    #[tokio::test]
    async fn test_http_resolver_invalid_endpoint_override_is_config_error() {
        let resolver = HttpResolver::new("https://example.com/ota");
        let mut ctx = test_ctx(0);
        ctx.endpoint_override = Some("::not-a-url::".into());
        let result = resolver.check(&ctx).await;
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[tokio::test]
    async fn test_static_file_resolver_empty_json_object_errors() {
        let tmp = TempDir::new().unwrap();
        let resolver = resolver_for(&tmp, "{}");
        let result = resolver.check(&test_ctx(0)).await;
        assert!(matches!(result, Err(Error::InvalidManifest(_))));
    }

    #[tokio::test]
    async fn test_static_file_resolver_extra_unknown_fields_ok() {
        let tmp = TempDir::new().unwrap();
        let json = serde_json::json!({
            "version": "3.0.0",
            "sequence": 99,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0",
            "some_future_field": "hello",
            "another_unknown": 42
        })
        .to_string();
        let resolver = resolver_for(&tmp, json);

        let manifest = resolver
            .check(&test_ctx(0))
            .await
            .unwrap()
            .expect("unknown fields must not prevent an update");
        assert_eq!(manifest.sequence, 99);
    }
}