Skip to main content

zeph_scheduler/
update_check.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::future::Future;
5use std::pin::Pin;
6
7use semver::Version;
8use serde::Deserialize;
9use tokio::sync::mpsc;
10
11use crate::error::SchedulerError;
12use crate::task::TaskHandler;
13
/// GitHub REST API endpoint describing the latest published release of `bug-ops/zeph`.
const GITHUB_RELEASES_URL: &str = "https://api.github.com/repos/bug-ops/zeph/releases/latest";
/// Largest release-metadata body (64 KiB) the update check is willing to parse.
const MAX_RESPONSE_BYTES: usize = 64 << 10;
16
17/// [`TaskHandler`] that polls the GitHub releases API for a newer Zeph version.
18///
19/// On each execution, `UpdateCheckHandler` fetches the latest release from
20/// `https://api.github.com/repos/bug-ops/zeph/releases/latest`, compares the
21/// `tag_name` field against `current_version` using semantic versioning, and sends a
22/// human-readable notification message on `notify_tx` when a newer release is found.
23///
24/// Network and parse errors are logged as warnings; `execute` always returns `Ok(())`
25/// so a transient failure does not stop the scheduler.
26///
27/// # Examples
28///
29/// ```rust,no_run
30/// use tokio::sync::mpsc;
31/// use zeph_scheduler::UpdateCheckHandler;
32///
33/// # #[tokio::main]
34/// # async fn main() {
35/// let (tx, mut rx) = mpsc::channel(1);
36/// let handler = UpdateCheckHandler::new(env!("CARGO_PKG_VERSION"), tx);
37///
38/// use zeph_scheduler::TaskHandler;
39/// handler
40///     .execute(&serde_json::Value::Null)
41///     .await
42///     .expect("update check should not fail");
43///
44/// // A notification is sent only when a newer version exists on GitHub.
45/// if let Ok(msg) = rx.try_recv() {
46///     println!("{msg}");
47/// }
48/// # }
49/// ```
50pub struct UpdateCheckHandler {
51    current_version: &'static str,
52    notify_tx: mpsc::Sender<String>,
53    http_client: reqwest::Client,
54    /// Base URL for the GitHub releases API. Configurable for testing.
55    base_url: String,
56}
57
58#[derive(Deserialize)]
59struct ReleaseInfo {
60    tag_name: Option<String>,
61}
62
63impl UpdateCheckHandler {
64    /// Create a new handler.
65    ///
66    /// `current_version` should be `env!("CARGO_PKG_VERSION")`.
67    /// Notifications are sent as formatted strings via `notify_tx`.
68    ///
69    /// # Panics
70    ///
71    /// Panics if the underlying `reqwest` client cannot be constructed (unreachable in practice).
72    #[must_use]
73    pub fn new(current_version: &'static str, notify_tx: mpsc::Sender<String>) -> Self {
74        let http_client = reqwest::Client::builder()
75            .timeout(std::time::Duration::from_secs(10))
76            .user_agent(format!("zeph/{current_version}"))
77            .build()
78            .expect("reqwest client builder should not fail with timeout and user_agent");
79        Self {
80            current_version,
81            notify_tx,
82            http_client,
83            base_url: GITHUB_RELEASES_URL.to_owned(),
84        }
85    }
86
87    /// Override the GitHub releases API URL.
88    ///
89    /// Intended for tests only: point the handler at a local mock server so the
90    /// test does not make real network requests.
91    #[must_use]
92    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
93        self.base_url = url.into();
94        self
95    }
96
97    /// Extract and compare versions; returns `Some(remote_version_str)` when remote > current.
98    fn newer_version(current: &str, tag_name: &str) -> Option<String> {
99        let remote_str = tag_name.trim_start_matches('v');
100        if remote_str.is_empty() {
101            return None;
102        }
103        let current_v = Version::parse(current).ok()?;
104        let remote_v = Version::parse(remote_str).ok()?;
105        if remote_v > current_v {
106            Some(remote_str.to_owned())
107        } else {
108            None
109        }
110    }
111}
112
113impl TaskHandler for UpdateCheckHandler {
114    fn execute(
115        &self,
116        _config: &serde_json::Value,
117    ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + Send + '_>> {
118        Box::pin(async move {
119            let resp = self
120                .http_client
121                .get(&self.base_url)
122                .header("Accept", "application/vnd.github+json")
123                .send()
124                .await;
125
126            let resp = match resp {
127                Ok(r) => r,
128                Err(e) => {
129                    tracing::warn!("update check request failed: {e}");
130                    return Ok(());
131                }
132            };
133
134            if !resp.status().is_success() {
135                tracing::warn!("update check: HTTP {}", resp.status());
136                return Ok(());
137            }
138
139            let bytes = match resp.bytes().await {
140                Ok(b) => b,
141                Err(e) => {
142                    tracing::warn!("update check: failed to read response body: {e}");
143                    return Ok(());
144                }
145            };
146            if bytes.len() > MAX_RESPONSE_BYTES {
147                tracing::warn!(
148                    "update check: response body too large ({} bytes), skipping",
149                    bytes.len()
150                );
151                return Ok(());
152            }
153            let info: ReleaseInfo = match serde_json::from_slice(&bytes) {
154                Ok(v) => v,
155                Err(e) => {
156                    tracing::warn!("update check response parse failed: {e}");
157                    return Ok(());
158                }
159            };
160
161            let Some(tag_name) = info.tag_name else {
162                tracing::warn!("update check: missing tag_name in response");
163                return Ok(());
164            };
165
166            match Self::newer_version(self.current_version, &tag_name) {
167                Some(remote) => {
168                    let msg = format!(
169                        "New version available: v{remote} (current: v{}).\nUpdate: https://github.com/bug-ops/zeph/releases/tag/v{remote}",
170                        self.current_version
171                    );
172                    tracing::debug!("update available: {remote}");
173                    let _ = self.notify_tx.send(msg).await;
174                }
175                None => {
176                    tracing::debug!(
177                        current = self.current_version,
178                        remote = tag_name,
179                        "no update available"
180                    );
181                }
182            }
183
184            Ok(())
185        })
186    }
187}
188
#[cfg(test)]
mod tests {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Shorthand for the version-comparison helper under test.
    fn newer(current: &str, tag: &str) -> Option<String> {
        UpdateCheckHandler::newer_version(current, tag)
    }

    /// A 200 response whose JSON body carries only `tag_name`.
    fn release_tagged(tag: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({ "tag_name": tag }))
    }

    /// Serve `template` at `GET /` from a throwaway mock server and run one update
    /// check as `current_version`.
    ///
    /// Asserts that `execute` returned `Ok`, then returns the notification it sent,
    /// if any. `scenario` only labels the assertion messages.
    async fn run_check(
        current_version: &'static str,
        template: ResponseTemplate,
        scenario: &str,
    ) -> Option<String> {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(template)
            .mount(&server)
            .await;

        let (tx, mut rx) = mpsc::channel(1);
        let handler = UpdateCheckHandler::new(current_version, tx).with_base_url(server.uri());

        let outcome = handler.execute(&serde_json::Value::Null).await;
        assert!(outcome.is_ok(), "handler must return Ok on {scenario}");
        rx.try_recv().ok()
    }

    /// Run a check as version 0.11.0 against `template` and require that no
    /// notification is produced.
    async fn assert_silent(template: ResponseTemplate, scenario: &str) {
        let notification = run_check("0.11.0", template, scenario).await;
        assert!(
            notification.is_none(),
            "no notification expected for {scenario}"
        );
    }

    #[test]
    fn newer_version_detects_upgrade() {
        assert_eq!(newer("0.11.0", "v0.12.0").as_deref(), Some("0.12.0"));
    }

    #[test]
    fn newer_version_same_version_no_notify() {
        assert!(newer("0.11.0", "v0.11.0").is_none());
    }

    #[test]
    fn newer_version_older_remote_no_notify() {
        assert!(newer("0.11.0", "v0.10.0").is_none());
    }

    #[test]
    fn newer_version_strips_v_prefix() {
        for tag in ["v2.0.0", "2.0.0"] {
            assert_eq!(newer("1.0.0", tag).as_deref(), Some("2.0.0"), "tag {tag}");
        }
    }

    #[test]
    fn newer_version_invalid_current_returns_none() {
        assert!(newer("not-semver", "v1.0.0").is_none());
    }

    #[test]
    fn newer_version_invalid_remote_returns_none() {
        assert!(newer("1.0.0", "v-garbage").is_none());
    }

    #[test]
    fn newer_version_empty_tag_returns_none() {
        assert!(newer("1.0.0", "").is_none());
    }

    // Per the semver spec, a prerelease such as 0.12.0-rc.1 ranks above 0.11.0.
    // Notifying about it is intentional, should the releases/latest endpoint
    // (which typically returns only stable releases) ever report one.
    #[test]
    fn newer_version_prerelease_is_notified() {
        assert_eq!(
            newer("0.11.0", "v0.12.0-rc.1").as_deref(),
            Some("0.12.0-rc.1")
        );
    }

    #[tokio::test]
    async fn test_execute_newer_version_sends_notification() {
        let msg = run_check("0.11.0", release_tagged("v99.0.0"), "a newer release")
            .await
            .expect("notification must be sent");
        assert!(
            msg.contains("99.0.0"),
            "notification should mention new version"
        );
        assert!(
            msg.contains("0.11.0"),
            "notification should mention current version"
        );
    }

    #[tokio::test]
    async fn test_execute_same_version_no_notification() {
        assert_silent(release_tagged("v0.11.0"), "same version").await;
    }

    #[tokio::test]
    async fn test_execute_http_404_no_notification_no_panic() {
        assert_silent(ResponseTemplate::new(404), "404").await;
    }

    #[tokio::test]
    async fn test_execute_http_429_rate_limit_graceful() {
        assert_silent(ResponseTemplate::new(429), "429").await;
    }

    #[tokio::test]
    async fn test_execute_http_500_server_error_graceful() {
        assert_silent(ResponseTemplate::new(500), "500").await;
    }

    #[tokio::test]
    async fn test_execute_malformed_json_graceful() {
        let garbage = ResponseTemplate::new(200).set_body_string("this is not json {{{");
        assert_silent(garbage, "malformed JSON").await;
    }

    #[tokio::test]
    async fn test_execute_missing_tag_name_graceful() {
        let untagged = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "name": "Latest Release",
            "published_at": "2024-01-01"
        }));
        assert_silent(untagged, "missing tag_name").await;
    }

    #[tokio::test]
    async fn test_execute_oversized_body_graceful() {
        // One byte over MAX_RESPONSE_BYTES (64 KiB), i.e. 65 537 bytes.
        let oversized = "x".repeat(MAX_RESPONSE_BYTES + 1);
        assert_silent(
            ResponseTemplate::new(200).set_body_string(oversized),
            "oversized body",
        )
        .await;
    }
}