Skip to main content

tauri_plugin_hotswap/
resolver.rs

1use crate::error::{Error, Result};
2use crate::manifest::HotswapManifest;
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use url::Url;
7
/// Snapshot of the app's runtime state handed to a [`HotswapResolver`] on
/// every update check.
///
/// Resolvers read what they need from it (sequence, binary version,
/// platform/arch, channel) and may honor the per-check overrides carried in
/// `headers` and `endpoint_override`.
#[derive(Debug, Clone)]
pub struct CheckContext {
    /// Monotonic sequence of the currently cached bundle; `0` when nothing is cached.
    pub current_sequence: u64,
    /// The binary's own version, as declared in `tauri.conf.json`.
    pub binary_version: String,
    /// Platform identifier such as `"macos"`, `"windows"`, `"linux"` or `"android"`.
    pub platform: &'static str,
    /// CPU architecture identifier such as `"x86_64"` or `"aarch64"`.
    pub arch: &'static str,
    /// Release channel to check against, or `None` when no channel is configured.
    pub channel: Option<String>,
    /// Extra HTTP headers set at runtime via `configure()`.
    ///
    /// They are combined with any headers the resolver was built with; on a
    /// name clash the runtime value is the one that should win.
    pub headers: HashMap<String, String>,
    /// Endpoint set at runtime via `configure()`.
    ///
    /// When `Some`, `HttpResolver` uses it in place of the endpoint it was
    /// constructed with.
    pub endpoint_override: Option<String>,
}
30
31/// Trait for resolving update availability.
32///
33/// Implement this to use a custom update source (e.g. a static file,
34/// a custom API, or a database). The plugin ships two built-in
35/// implementations: [`HttpResolver`] and [`StaticFileResolver`].
36pub trait HotswapResolver: Send + Sync + 'static {
37    /// Check whether an update is available.
38    ///
39    /// Return `Ok(Some(manifest))` if an update is available,
40    /// `Ok(None)` if not, or `Err(...)` on failure.
41    fn check(
42        &self,
43        ctx: &CheckContext,
44    ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>>;
45}
46
47/// HTTP-based resolver that calls an endpoint URL.
48///
49/// The URL may contain `{{current_sequence}}` which is replaced at runtime.
50/// Query params `binary_version`, `platform`, `arch`, and `channel` are
51/// appended automatically.
52///
53/// Custom headers (e.g. `Authorization`) are sent on every request.
54///
55/// Expected responses:
56/// - **204 No Content** → no update available
57/// - **200** with JSON body → [`HotswapManifest`]
58pub struct HttpResolver {
59    endpoint: String,
60    client: reqwest::Client,
61    headers: HashMap<String, String>,
62}
63
64impl HttpResolver {
65    /// Create a new resolver with the given endpoint URL template.
66    pub fn new(endpoint: impl Into<String>) -> Self {
67        Self {
68            endpoint: endpoint.into(),
69            client: reqwest::Client::new(),
70            headers: HashMap::new(),
71        }
72    }
73
74    /// Create with a shared `reqwest::Client` for connection pooling.
75    pub fn with_client(endpoint: impl Into<String>, client: reqwest::Client) -> Self {
76        Self {
77            endpoint: endpoint.into(),
78            client,
79            headers: HashMap::new(),
80        }
81    }
82
83    /// Add custom headers sent on every check request.
84    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
85        self.headers = headers;
86        self
87    }
88
89    /// Returns the configured endpoint URL template.
90    pub fn endpoint(&self) -> &str {
91        &self.endpoint
92    }
93}
94
95impl HotswapResolver for HttpResolver {
96    fn check(
97        &self,
98        ctx: &CheckContext,
99    ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>> {
100        let base = ctx.endpoint_override.as_deref().unwrap_or(&self.endpoint);
101        let raw = base.replace("{{current_sequence}}", &ctx.current_sequence.to_string());
102        let mut parsed =
103            Url::parse(&raw).map_err(|e| Error::Config(format!("invalid endpoint URL: {}", e)));
104        if let Ok(ref mut u) = parsed {
105            u.query_pairs_mut()
106                .append_pair("binary_version", &ctx.binary_version)
107                .append_pair("platform", ctx.platform)
108                .append_pair("arch", ctx.arch);
109            if let Some(ref channel) = ctx.channel {
110                u.query_pairs_mut().append_pair("channel", channel);
111            }
112        }
113        let url = parsed;
114        let client = self.client.clone();
115        // Merge: init-time headers as base, runtime headers override.
116        let mut headers = self.headers.clone();
117        headers.extend(ctx.headers.clone());
118
119        Box::pin(async move {
120            let url = url?;
121            log::info!("[hotswap] Checking for update at: {}", url);
122
123            let mut req = client
124                .get(url.as_str())
125                .timeout(std::time::Duration::from_secs(15));
126
127            for (key, value) in &headers {
128                req = req.header(key.as_str(), value.as_str());
129            }
130
131            let response = req
132                .send()
133                .await
134                .map_err(|e| Error::Network(e.to_string()))?;
135
136            if response.status().as_u16() == 204 {
137                log::info!("[hotswap] No update available (204)");
138                return Ok(None);
139            }
140
141            if !response.status().is_success() {
142                return Err(Error::Http {
143                    status: response.status().as_u16(),
144                    message: "update check failed".into(),
145                });
146            }
147
148            let manifest: HotswapManifest = response
149                .json()
150                .await
151                .map_err(|e| Error::InvalidManifest(e.to_string()))?;
152
153            Ok(Some(manifest))
154        })
155    }
156}
157
158/// Static file resolver that reads a manifest from a local path or URL.
159///
160/// Useful for simple setups without a dynamic server. The manifest is
161/// fetched from the URL (or read from a local file) and compared against
162/// the current sequence.
163pub struct StaticFileResolver {
164    source: String,
165    client: reqwest::Client,
166}
167
168impl StaticFileResolver {
169    /// Create a new resolver with the given source path or URL.
170    pub fn new(source: impl Into<String>) -> Self {
171        Self {
172            source: source.into(),
173            client: reqwest::Client::new(),
174        }
175    }
176
177    /// Create with a shared `reqwest::Client`.
178    pub fn with_client(source: impl Into<String>, client: reqwest::Client) -> Self {
179        Self {
180            source: source.into(),
181            client,
182        }
183    }
184
185    /// Returns the configured source path/URL.
186    pub fn source(&self) -> &str {
187        &self.source
188    }
189}
190
191impl HotswapResolver for StaticFileResolver {
192    fn check(
193        &self,
194        ctx: &CheckContext,
195    ) -> Pin<Box<dyn Future<Output = Result<Option<HotswapManifest>>> + Send>> {
196        let source = self.source.clone();
197        let client = self.client.clone();
198        let current_sequence = ctx.current_sequence;
199
200        Box::pin(async move {
201            let content = if source.starts_with("http://") || source.starts_with("https://") {
202                client
203                    .get(&source)
204                    .timeout(std::time::Duration::from_secs(15))
205                    .send()
206                    .await
207                    .map_err(|e| Error::Network(e.to_string()))?
208                    .text()
209                    .await
210                    .map_err(|e| Error::Network(e.to_string()))?
211            } else {
212                tokio::fs::read_to_string(&source)
213                    .await
214                    .map_err(Error::Io)?
215            };
216
217            let manifest: HotswapManifest = serde_json::from_str(&content)
218                .map_err(|e| Error::InvalidManifest(e.to_string()))?;
219
220            if manifest.sequence <= current_sequence {
221                return Ok(None);
222            }
223
224            Ok(Some(manifest))
225        })
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Context for macOS/aarch64 on binary 1.0.0, with no channel, no runtime
    /// headers and no endpoint override.
    fn test_ctx(seq: u64) -> CheckContext {
        CheckContext {
            current_sequence: seq,
            binary_version: "1.0.0".into(),
            platform: "macos",
            arch: "aarch64",
            channel: None,
            headers: HashMap::new(),
            endpoint_override: None,
        }
    }

    /// Manifest JSON with `notes` and `pub_date` set, but no `mandatory` or
    /// `bundle_size`, for the given sequence.
    fn sample_manifest_json(sequence: u64) -> String {
        serde_json::json!({
            "version": format!("1.0.0-ota.{}", sequence),
            "sequence": sequence,
            "url": "https://cdn.example.com/ota/bundle.tar.gz",
            "signature": "untrusted comment: test\nRUTl2E==",
            "min_binary_version": "1.0.0",
            "notes": "Test release",
            "pub_date": "2026-04-05T00:00:00Z"
        })
        .to_string()
    }

    #[tokio::test]
    async fn test_static_file_resolver_update_available() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, sample_manifest_json(5)).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let result = resolver.check(&test_ctx(3)).await.unwrap();

        assert!(result.is_some());
        let manifest = result.unwrap();
        assert_eq!(manifest.sequence, 5);
        assert_eq!(manifest.version, "1.0.0-ota.5");
    }

    #[tokio::test]
    async fn test_static_file_resolver_no_update_same_sequence() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, sample_manifest_json(5)).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        assert!(resolver.check(&test_ctx(5)).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_static_file_resolver_no_update_higher_sequence() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, sample_manifest_json(3)).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        assert!(resolver.check(&test_ctx(10)).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_static_file_resolver_missing_file() {
        let resolver = StaticFileResolver::new("/nonexistent/path/manifest.json");
        // A local read failure must surface as an I/O error, not a parse error.
        assert!(matches!(
            resolver.check(&test_ctx(0)).await,
            Err(Error::Io(_))
        ));
    }

    #[tokio::test]
    async fn test_static_file_resolver_invalid_json() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, "not json at all").unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        assert!(matches!(
            resolver.check(&test_ctx(0)).await,
            Err(Error::InvalidManifest(_))
        ));
    }

    #[tokio::test]
    async fn test_static_file_resolver_minimal_manifest() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        let json = serde_json::json!({
            "version": "2.0.0",
            "sequence": 1,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0"
        })
        .to_string();
        fs::write(&manifest_path, json).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert_eq!(manifest.version, "2.0.0");
        assert!(manifest.notes.is_none());
        assert!(manifest.mandatory.is_none());
        assert!(manifest.bundle_size.is_none());
    }

    // The next test checks the template substitution rule on a plain string;
    // it does not go through `check()`.
    #[test]
    fn test_http_resolver_url_substitution() {
        let resolver = HttpResolver::new("https://example.com/ota/{{current_sequence}}");
        let url = resolver
            .endpoint()
            .replace("{{current_sequence}}", &42.to_string());
        assert_eq!(url, "https://example.com/ota/42");
    }

    #[tokio::test]
    async fn test_manifest_with_mandatory_and_size() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        let json = serde_json::json!({
            "version": "1.0.1",
            "sequence": 5,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0",
            "mandatory": true,
            "bundle_size": 5242880
        })
        .to_string();
        fs::write(&manifest_path, json).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert_eq!(manifest.mandatory, Some(true));
        assert_eq!(manifest.bundle_size, Some(5242880));
    }

    #[tokio::test]
    async fn test_manifest_without_optional_fields_defaults_to_none() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        let json = serde_json::json!({
            "version": "1.0.1",
            "sequence": 1,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0"
        })
        .to_string();
        fs::write(&manifest_path, json).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let manifest = resolver.check(&test_ctx(0)).await.unwrap().unwrap();
        assert!(manifest.mandatory.is_none());
        assert!(manifest.bundle_size.is_none());
        assert!(manifest.notes.is_none());
        assert!(manifest.pub_date.is_none());
    }

    #[test]
    fn test_check_context_with_channel() {
        let ctx = CheckContext {
            current_sequence: 5,
            binary_version: "1.0.0".into(),
            platform: "linux",
            arch: "x86_64",
            channel: Some("beta".into()),
            headers: HashMap::new(),
            endpoint_override: None,
        };
        assert_eq!(ctx.channel, Some("beta".to_string()));
        assert_eq!(ctx.platform, "linux");
        assert_eq!(ctx.arch, "x86_64");
    }

    #[test]
    fn test_http_resolver_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer token123".into());
        let resolver = HttpResolver::new("https://example.com/ota").with_headers(headers);
        assert_eq!(resolver.endpoint(), "https://example.com/ota");
        // The configured headers must actually be stored on the resolver.
        assert_eq!(
            resolver.headers.get("Authorization").map(String::as_str),
            Some("Bearer token123")
        );
    }

    // ── HttpResolver URL building ──────────────────────────────────────
    // The string-based tests below exercise the substitution rule in isolation
    // (no server needed). The `check()`-based tests rely on URL parsing failing
    // before any request is built, so they also run offline.

    #[test]
    fn test_http_resolver_url_sequence_zero() {
        let endpoint = "https://example.com/ota/{{current_sequence}}/check";
        let replaced = endpoint.replace("{{current_sequence}}", &0u64.to_string());
        assert_eq!(replaced, "https://example.com/ota/0/check");
    }

    #[test]
    fn test_http_resolver_url_sequence_large() {
        let endpoint = "https://example.com/ota/{{current_sequence}}";
        let replaced = endpoint.replace("{{current_sequence}}", &999999u64.to_string());
        assert_eq!(replaced, "https://example.com/ota/999999");
    }

    #[test]
    fn test_http_resolver_endpoint_override_used() {
        let resolver = HttpResolver::new("https://original.example.com/ota");
        let ctx = CheckContext {
            current_sequence: 10,
            binary_version: "2.0.0".into(),
            platform: "windows",
            arch: "x86_64",
            channel: None,
            headers: HashMap::new(),
            endpoint_override: Some("https://override.example.com/v2/{{current_sequence}}".into()),
        };
        let base = ctx
            .endpoint_override
            .as_deref()
            .unwrap_or(resolver.endpoint());
        let raw = base.replace("{{current_sequence}}", &ctx.current_sequence.to_string());
        assert!(raw.contains("override.example.com"));
        assert!(raw.contains("10"));
        assert!(!raw.contains("original.example.com"));
    }

    #[tokio::test]
    async fn test_http_resolver_invalid_endpoint_is_config_error() {
        let resolver = HttpResolver::new("not a valid url");
        let result = resolver.check(&test_ctx(0)).await;
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[tokio::test]
    async fn test_http_resolver_invalid_override_takes_precedence() {
        // The base endpoint is valid; failing with Config proves the (invalid)
        // override was the URL actually parsed.
        let resolver = HttpResolver::new("https://example.com/ota");
        let mut ctx = test_ctx(0);
        ctx.endpoint_override = Some("not a valid url".into());
        let result = resolver.check(&ctx).await;
        assert!(matches!(result, Err(Error::Config(_))));
    }

    // ── HttpResolver header merging ────────────────────────────────────

    // This re-creates the init/runtime precedence rule on a plain map instead
    // of going through `check()`.
    #[test]
    fn test_http_resolver_runtime_headers_override_init() {
        let mut init_headers = HashMap::new();
        init_headers.insert("Authorization".into(), "Bearer old".into());
        init_headers.insert("X-Keep".into(), "kept".into());
        let resolver = HttpResolver::new("https://example.com/ota").with_headers(init_headers);

        let mut runtime_headers = HashMap::new();
        runtime_headers.insert("Authorization".into(), "Bearer new".into());

        let mut merged = resolver.headers.clone();
        merged.extend(runtime_headers);

        assert_eq!(merged.get("Authorization").unwrap(), "Bearer new");
        assert_eq!(merged.get("X-Keep").unwrap(), "kept");
    }

    // ── StaticFileResolver edge cases ──────────────────────────────────

    #[tokio::test]
    async fn test_static_file_resolver_empty_json_object_errors() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        fs::write(&manifest_path, "{}").unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let result = resolver.check(&test_ctx(0)).await;
        // Missing required fields are a manifest problem, not an I/O problem.
        assert!(matches!(result, Err(Error::InvalidManifest(_))));
    }

    #[tokio::test]
    async fn test_static_file_resolver_extra_unknown_fields_ok() {
        let tmp = TempDir::new().unwrap();
        let manifest_path = tmp.path().join("latest.json");
        let json = serde_json::json!({
            "version": "3.0.0",
            "sequence": 99,
            "url": "https://cdn.example.com/bundle.tar.gz",
            "signature": "sig",
            "min_binary_version": "1.0.0",
            "some_future_field": "hello",
            "another_unknown": 42
        })
        .to_string();
        fs::write(&manifest_path, json).unwrap();

        let resolver = StaticFileResolver::new(manifest_path.to_string_lossy().to_string());
        let result = resolver.check(&test_ctx(0)).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().sequence, 99);
    }
}