Skip to main content

studio_worker/
update.rs

1//! Auto-update: poll a GitHub Releases feed, download cargo-dist's
2//! platform installer when a newer semver is available, and re-exec
3//! ourselves so the new binary takes over.
4//!
5//! The update task in `runtime.rs` only invokes us when the worker is
6//! idle (no job in flight) so generation runs never get killed mid-flow.
7//!
8//! All side-effecting bits (HTTP, filesystem writes, process spawn) flow
9//! through testable helpers; see `apply_with` for the seam.
10use crate::types::GithubRelease;
11use anyhow::{anyhow, bail, Context, Result};
12use semver::Version;
13use std::path::{Path, PathBuf};
14use std::time::{Duration, Instant};
15use tracing::{debug, info, warn};
16
/// The single `tracing` target shared by every event this module emits.
///
/// Keeping it in one constant lets operators isolate the updater's
/// breadcrumbs with `RUST_LOG=studio_worker::update=debug`.
const TRACE_TARGET: &str = "studio_worker::update";
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum CheckOutcome {
24    UpToDate { current: Version },
25    NewerAvailable { current: Version, latest: Version },
26}
27
28/// Resolve the feed URL to a JSON document and parse a release list.
29pub fn fetch_releases(feed_url: &str) -> Result<Vec<GithubRelease>> {
30    let client = reqwest::blocking::Client::builder()
31        .timeout(Duration::from_secs(15))
32        .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
33        .build()
34        .context("building reqwest client")?;
35    let started = Instant::now();
36    let response = client
37        .get(feed_url)
38        .header("accept", "application/vnd.github+json")
39        .send()
40        .with_context(|| format!("GET {feed_url}"))?;
41    let status = response.status();
42    let elapsed_ms = started.elapsed().as_millis() as u64;
43    if !status.is_success() {
44        warn!(
45            target: TRACE_TARGET,
46            feed_url,
47            status = status.as_u16(),
48            elapsed_ms,
49            "feed fetch failed"
50        );
51        bail!("feed {feed_url} returned {status}");
52    }
53    let text = response.text()?;
54    let releases = parse_releases(&text)?;
55    debug!(
56        target: TRACE_TARGET,
57        feed_url,
58        status = status.as_u16(),
59        elapsed_ms,
60        releases = releases.len(),
61        "feed fetched"
62    );
63    Ok(releases)
64}
65
66/// Pure parser separated from the HTTP call so it's trivially testable.
67pub fn parse_releases(text: &str) -> Result<Vec<GithubRelease>> {
68    if let Ok(list) = serde_json::from_str::<Vec<GithubRelease>>(text) {
69        return Ok(list);
70    }
71    let single: GithubRelease = serde_json::from_str(text)
72        .with_context(|| "feed JSON is neither an array nor a single release")?;
73    Ok(vec![single])
74}
75
76/// Parse the version from a release tag.  Accepts both `1.2.3` and
77/// `v1.2.3`.
78pub fn parse_tag(tag: &str) -> Option<Version> {
79    Version::parse(tag.strip_prefix('v').unwrap_or(tag)).ok()
80}
81
82/// Compare the local version against the feed and decide whether to
83/// update.
84pub fn check(feed_url: &str, current: &Version, prerelease_ok: bool) -> Result<CheckOutcome> {
85    let releases = fetch_releases(feed_url)?;
86    Ok(decide(&releases, current, prerelease_ok))
87}
88
89/// Pure decision function so we can unit-test the prerelease/draft
90/// filters without going through HTTP.
91pub fn decide(releases: &[GithubRelease], current: &Version, prerelease_ok: bool) -> CheckOutcome {
92    let latest = releases
93        .iter()
94        .filter(|r| !r.draft)
95        .filter(|r| prerelease_ok || !r.prerelease)
96        .filter_map(|r| parse_tag(&r.tag_name))
97        .max();
98    match latest {
99        Some(v) if v > *current => CheckOutcome::NewerAvailable {
100            current: current.clone(),
101            latest: v,
102        },
103        _ => CheckOutcome::UpToDate {
104            current: current.clone(),
105        },
106    }
107}
108
109/// The cargo-dist installer asset name for the current platform.
110pub fn installer_asset_name() -> &'static str {
111    if cfg!(target_os = "windows") {
112        "studio-worker-installer.ps1"
113    } else {
114        "studio-worker-installer.sh"
115    }
116}
117
118/// Resolve which installer asset to download for the given release.
119/// Pulled out of `apply` for unit tests.
120pub fn resolve_installer_url(release: &GithubRelease) -> Option<&str> {
121    let name = installer_asset_name();
122    release
123        .assets
124        .iter()
125        .find(|a| a.name == name)
126        .map(|a| a.browser_download_url.as_str())
127}
128
129/// Verify a streamed installer download wrote exactly the body the
130/// server promised.  `expected` is the response's `Content-Length`;
131/// it's `None` for chunked transfers, where there's nothing to check
132/// and we accept whatever arrived.  A mismatch means the download was
133/// truncated or corrupt — and because the very next step hands this
134/// file to `sh` / `powershell`, running a half-written installer is
135/// far more dangerous than failing the update and retrying on the next
136/// tick, so we surface a clear error instead of executing it.
137fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
138    match expected {
139        Some(expected) if copied != expected => bail!(
140            "size mismatch: wrote {copied} bytes but the server declared \
141             Content-Length {expected} (installer download truncated or corrupt)"
142        ),
143        _ => Ok(()),
144    }
145}
146
147/// Apply an update by downloading the cargo-dist installer for the
148/// current platform and running it.
149pub fn apply(feed_url: &str, latest: &Version) -> Result<()> {
150    apply_with(feed_url, latest, &RealRunner)
151}
152
153/// Side-effect abstraction for `apply_with`.  The real implementation
154/// downloads via HTTP and runs `sh` / `powershell`; tests inject a fake
155/// that records calls.
156pub trait UpdateRunner {
157    fn download(&self, url: &str, dest: &Path) -> Result<()>;
158    fn run_installer(&self, installer_path: &Path) -> Result<()>;
159}
160
161pub struct RealRunner;
162
163impl UpdateRunner for RealRunner {
164    fn download(&self, url: &str, dest: &Path) -> Result<()> {
165        validate_installer_download_url(url)?;
166        let client = reqwest::blocking::Client::builder()
167            .timeout(Duration::from_secs(300))
168            .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
169            .build()?;
170        let started = Instant::now();
171        let mut response = client.get(url).send()?.error_for_status()?;
172        // Capture the declared length (absent on chunked transfers)
173        // before streaming so a short read is caught below — the next
174        // step runs this file as a shell / PowerShell script.
175        let expected_len = response.content_length();
176        let mut file = std::fs::File::create(dest)?;
177        let bytes = std::io::copy(&mut response, &mut file)?;
178        // Reject a truncated / overlong download before `apply_with`
179        // hands the file to the installer runner.  Bailing here means
180        // `run_installer` never executes, and `apply_with`'s tempdir
181        // drop cleans up the partial file.
182        verify_download_len(bytes, expected_len)
183            .with_context(|| format!("downloading installer from {url}"))?;
184        info!(
185            target: TRACE_TARGET,
186            url,
187            dest = %dest.display(),
188            bytes,
189            elapsed_ms = started.elapsed().as_millis() as u64,
190            "installer downloaded"
191        );
192        Ok(())
193    }
194
195    fn run_installer(&self, installer_path: &Path) -> Result<()> {
196        if cfg!(target_os = "windows") {
197            let status = std::process::Command::new("powershell")
198                .args([
199                    "-NoProfile",
200                    "-ExecutionPolicy",
201                    "Bypass",
202                    "-File",
203                    installer_path
204                        .to_str()
205                        .ok_or_else(|| anyhow!("installer path not UTF-8"))?,
206                ])
207                .status()?;
208            if !status.success() {
209                bail!("installer exited with {status}");
210            }
211        } else {
212            let status = std::process::Command::new("sh")
213                .arg(installer_path)
214                .status()?;
215            if !status.success() {
216                bail!("installer exited with {status}");
217            }
218        }
219        Ok(())
220    }
221}
222
223fn validate_installer_download_url(raw: &str) -> Result<()> {
224    let url = url::Url::parse(raw).with_context(|| format!("invalid installer URL {raw:?}"))?;
225    if url.scheme() == "https" {
226        return Ok(());
227    }
228    if url.scheme() == "http" {
229        if let Some(host) = url.host_str() {
230            if host == "localhost"
231                || host
232                    .parse::<std::net::IpAddr>()
233                    .is_ok_and(|ip| ip.is_loopback())
234            {
235                return Ok(());
236            }
237        }
238    }
239    bail!("installer URL must use https (loopback http is allowed for tests): {raw}");
240}
241
242pub fn apply_with<R: UpdateRunner>(feed_url: &str, latest: &Version, runner: &R) -> Result<()> {
243    info!(
244        target: TRACE_TARGET,
245        feed_url,
246        latest = %latest,
247        "applying update"
248    );
249    let releases = fetch_releases(feed_url)?;
250    let release = releases
251        .iter()
252        .find(|r| parse_tag(&r.tag_name).as_ref() == Some(latest))
253        .ok_or_else(|| anyhow!("release {latest} not present in feed"))?;
254
255    let url = resolve_installer_url(release).ok_or_else(|| {
256        anyhow!(
257            "release {} is missing installer asset {}",
258            latest,
259            installer_asset_name()
260        )
261    })?;
262
263    let tmp = tempfile::tempdir().context("creating tempdir for installer")?;
264    let installer_path = tmp.path().join(installer_asset_name());
265    info!(
266        target: TRACE_TARGET,
267        url,
268        dest = %installer_path.display(),
269        latest = %latest,
270        "downloading installer"
271    );
272    runner.download(url, &installer_path)?;
273    info!(
274        target: TRACE_TARGET,
275        installer = %installer_path.display(),
276        latest = %latest,
277        "running installer"
278    );
279    runner.run_installer(&installer_path)?;
280    info!(
281        target: TRACE_TARGET,
282        latest = %latest,
283        "installer completed; binary replaced"
284    );
285    Ok(())
286}
287
288/// Compute the (binary, args) tuple we'd re-exec ourselves with.  Pure
289/// — actual exec lives in [`restart_self`].
290pub fn restart_argv() -> (PathBuf, Vec<std::ffi::OsString>) {
291    let mut iter = std::env::args_os();
292    let bin = iter
293        .next()
294        .map(PathBuf::from)
295        .unwrap_or_else(|| PathBuf::from("studio-worker"));
296    let args: Vec<std::ffi::OsString> = iter.collect();
297    (bin, args)
298}
299
300/// Replace the current process with a fresh exec of the (now-updated)
301/// binary.  On unix we use `execvp`; on Windows we spawn the successor
302/// and exit cleanly.  Unreachable from tests — covered by integration
303/// tests of `apply_with` instead.
304#[cfg_attr(coverage_nightly, coverage(off))]
305pub fn restart_self() -> ! {
306    let (bin, args) = restart_argv();
307    info!(
308        target: TRACE_TARGET,
309        bin = %bin.display(),
310        argc = args.len(),
311        "restarting into updated binary"
312    );
313    #[cfg(unix)]
314    {
315        use std::os::unix::process::CommandExt;
316        let err = std::process::Command::new(&bin).args(&args).exec();
317        tracing::error!(
318            target: TRACE_TARGET,
319            bin = %bin.display(),
320            %err,
321            "exec into updated binary failed"
322        );
323        eprintln!("[studio-worker] exec failed: {err}");
324        std::process::exit(1);
325    }
326    #[cfg(not(unix))]
327    {
328        match std::process::Command::new(&bin).args(&args).spawn() {
329            Ok(_) => std::process::exit(0),
330            Err(err) => {
331                tracing::error!(
332                    target: TRACE_TARGET,
333                    bin = %bin.display(),
334                    %err,
335                    "spawn-restart of updated binary failed"
336                );
337                eprintln!("[studio-worker] spawn-restart failed: {err}");
338                std::process::exit(1);
339            }
340        }
341    }
342}
343
344#[cfg(test)]
345mod tests {
346    use super::*;
347    use crate::types::{GithubRelease, GithubReleaseAsset};
348    use std::cell::RefCell;
349    use std::path::PathBuf;
350    use tempfile::tempdir;
351
352    fn rel(tag: &str, prerelease: bool, draft: bool, with_installer: bool) -> GithubRelease {
353        let assets = if with_installer {
354            vec![GithubReleaseAsset {
355                name: installer_asset_name().to_string(),
356                browser_download_url: format!("https://example.com/{tag}"),
357            }]
358        } else {
359            vec![]
360        };
361        GithubRelease {
362            tag_name: tag.to_string(),
363            prerelease,
364            draft,
365            assets,
366        }
367    }
368
369    #[test]
370    fn parse_tag_accepts_v_prefix_and_bare() {
371        assert_eq!(parse_tag("v1.2.3"), Some(Version::new(1, 2, 3)));
372        assert_eq!(parse_tag("1.2.3"), Some(Version::new(1, 2, 3)));
373        assert!(parse_tag("garbage").is_none());
374    }
375
376    #[test]
377    fn parse_releases_accepts_array() {
378        let text = serde_json::to_string(&serde_json::json!([
379            { "tag_name": "v1.0.0", "prerelease": false, "draft": false, "assets": [] }
380        ]))
381        .unwrap();
382        let releases = parse_releases(&text).unwrap();
383        assert_eq!(releases.len(), 1);
384        assert_eq!(releases[0].tag_name, "v1.0.0");
385    }
386
387    #[test]
388    fn parse_releases_accepts_single_object() {
389        let text = serde_json::to_string(&serde_json::json!({
390            "tag_name": "v2.0.0", "prerelease": false, "draft": false, "assets": []
391        }))
392        .unwrap();
393        let releases = parse_releases(&text).unwrap();
394        assert_eq!(releases.len(), 1);
395        assert_eq!(releases[0].tag_name, "v2.0.0");
396    }
397
398    #[test]
399    fn parse_releases_errors_on_garbage() {
400        assert!(parse_releases("not json").is_err());
401    }
402
403    #[test]
404    fn decide_reports_up_to_date_when_no_newer() {
405        let releases = vec![rel("v0.1.0", false, false, true)];
406        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
407        assert_eq!(
408            outcome,
409            CheckOutcome::UpToDate {
410                current: Version::new(0, 1, 0)
411            }
412        );
413    }
414
415    #[test]
416    fn decide_reports_newer_when_higher_present() {
417        let releases = vec![
418            rel("v0.1.0", false, false, true),
419            rel("v0.2.0", false, false, true),
420        ];
421        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
422        assert_eq!(
423            outcome,
424            CheckOutcome::NewerAvailable {
425                current: Version::new(0, 1, 0),
426                latest: Version::new(0, 2, 0),
427            }
428        );
429    }
430
431    #[test]
432    fn decide_skips_prereleases_unless_opted_in() {
433        let releases = vec![
434            rel("v0.1.0", false, false, true),
435            rel("v0.3.0-rc.1", true, false, true),
436        ];
437        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
438        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
439        let outcome = decide(&releases, &Version::new(0, 1, 0), true);
440        assert!(matches!(outcome, CheckOutcome::NewerAvailable { .. }));
441    }
442
443    #[test]
444    fn decide_skips_drafts() {
445        let releases = vec![
446            rel("v0.1.0", false, false, true),
447            rel("v0.9.0", false, true, true),
448        ];
449        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
450        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
451    }
452
453    #[test]
454    fn decide_handles_empty_feed() {
455        let outcome = decide(&[], &Version::new(1, 0, 0), false);
456        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
457    }
458
459    #[test]
460    fn decide_skips_malformed_tags() {
461        let releases = vec![
462            rel("garbage", false, false, true),
463            rel("v0.1.0", false, false, true),
464        ];
465        let outcome = decide(&releases, &Version::new(0, 0, 1), false);
466        match outcome {
467            CheckOutcome::NewerAvailable { latest, .. } => {
468                assert_eq!(latest, Version::new(0, 1, 0))
469            }
470            _ => panic!("expected newer"),
471        }
472    }
473
474    #[test]
475    fn installer_asset_name_matches_platform() {
476        let name = installer_asset_name();
477        if cfg!(target_os = "windows") {
478            assert_eq!(name, "studio-worker-installer.ps1");
479        } else {
480            assert_eq!(name, "studio-worker-installer.sh");
481        }
482    }
483
484    #[test]
485    fn resolve_installer_url_finds_the_right_asset() {
486        let release = rel("v1.0.0", false, false, true);
487        let url = resolve_installer_url(&release).unwrap();
488        assert_eq!(url, "https://example.com/v1.0.0");
489    }
490
491    #[test]
492    fn resolve_installer_url_returns_none_when_missing() {
493        let release = rel("v1.0.0", false, false, false);
494        assert!(resolve_installer_url(&release).is_none());
495    }
496
497    // -----------------------------------------------------------------
498    // verify_download_len — guards the installer download against a
499    // short read before the bytes are handed to `sh` / `powershell`.
500    // A truncated installer that runs is far worse than a failed
501    // update, so a Content-Length mismatch must surface as an error.
502    // -----------------------------------------------------------------
503
504    #[test]
505    fn verify_download_len_accepts_exact_match() {
506        assert!(verify_download_len(2048, Some(2048)).is_ok());
507    }
508
509    #[test]
510    fn verify_download_len_accepts_when_length_unknown() {
511        // Chunked transfers omit Content-Length; nothing to check, so
512        // we accept whatever streamed in (same as before this guard).
513        assert!(verify_download_len(123, None).is_ok());
514    }
515
516    #[test]
517    fn verify_download_len_rejects_truncated_installer() {
518        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
519        assert!(err.contains("size mismatch"), "got: {err}");
520        assert!(err.contains("40"), "got: {err}");
521        assert!(err.contains("100"), "got: {err}");
522    }
523
524    #[test]
525    fn verify_download_len_rejects_overlong_installer() {
526        // A body longer than the declared length is just as corrupt as
527        // a short one — reject both rather than run a bad installer.
528        assert!(verify_download_len(120, Some(100)).is_err());
529    }
530
531    #[test]
532    fn validate_installer_download_url_allows_https() {
533        validate_installer_download_url("https://github.com/owner/repo/releases/download/x/i.sh")
534            .unwrap();
535    }
536
537    #[test]
538    fn validate_installer_download_url_allows_loopback_http_for_tests() {
539        validate_installer_download_url("http://127.0.0.1:1234/i.sh").unwrap();
540        validate_installer_download_url("http://localhost:1234/i.sh").unwrap();
541    }
542
543    #[test]
544    fn validate_installer_download_url_rejects_remote_http() {
545        let err = validate_installer_download_url("http://example.com/i.sh")
546            .unwrap_err()
547            .to_string();
548        assert!(err.contains("https"), "got: {err}");
549    }
550
551    #[test]
552    fn restart_argv_uses_current_exe_and_args() {
553        let (bin, _args) = restart_argv();
554        assert!(!bin.as_os_str().is_empty());
555    }
556
557    // -----------------------------------------------------------------
558    // apply_with — exercised via a fake runner that records calls.
559    // -----------------------------------------------------------------
560
561    struct FakeRunner {
562        downloaded: RefCell<Vec<(String, PathBuf)>>,
563        ran: RefCell<Vec<PathBuf>>,
564        fail_download: bool,
565        fail_run: bool,
566    }
567
568    impl UpdateRunner for FakeRunner {
569        fn download(&self, url: &str, dest: &Path) -> Result<()> {
570            self.downloaded
571                .borrow_mut()
572                .push((url.to_string(), dest.to_path_buf()));
573            if self.fail_download {
574                bail!("simulated download failure");
575            }
576            // Touch the file so apply's runner contract is satisfied.
577            std::fs::write(dest, b"#!/bin/sh\necho fake installer\n").unwrap();
578            Ok(())
579        }
580        fn run_installer(&self, installer_path: &Path) -> Result<()> {
581            self.ran.borrow_mut().push(installer_path.to_path_buf());
582            if self.fail_run {
583                bail!("simulated installer failure");
584            }
585            Ok(())
586        }
587    }
588
589    fn write_fixture_feed(dir: &tempfile::TempDir, releases: serde_json::Value) -> String {
590        let path = dir.path().join("releases.json");
591        std::fs::write(&path, releases.to_string()).unwrap();
592        format!("file://{}", path.to_string_lossy())
593    }
594
595    fn fake_release_with_installer(tag: &str) -> serde_json::Value {
596        serde_json::json!({
597            "tag_name": tag,
598            "prerelease": false,
599            "draft": false,
600            "assets": [{
601                "name": installer_asset_name(),
602                "browser_download_url": format!("https://example.invalid/{tag}/{}", installer_asset_name()),
603            }],
604        })
605    }
606
607    // The reqwest blocking client doesn't follow `file://` URLs, so we
608    // use wiremock-served feeds for the apply tests via the integration
609    // suite (`tests/auto_update.rs`).  Here we just verify the unit-test
610    // branches: missing release, missing asset.
611    #[test]
612    fn apply_with_errors_when_release_missing() {
613        // Static fixture parsed via parse_releases bypasses HTTP for this
614        // narrow test.  We can't call apply_with without a real HTTP fetch
615        // since fetch_releases is HTTP only — but we can drive the
616        // post-fetch branches directly.
617        let releases: Vec<GithubRelease> = vec![rel("v0.1.0", false, false, true)];
618        let missing = Version::new(9, 9, 9);
619        let url = releases
620            .iter()
621            .find(|r| parse_tag(&r.tag_name).as_ref() == Some(&missing));
622        assert!(url.is_none(), "v9.9.9 should not be in the fixture");
623    }
624
625    // Sanity: we can write a fake feed file (used by integration tests).
626    #[test]
627    fn writing_a_fake_feed_round_trips_through_parse_releases() {
628        let dir = tempdir().unwrap();
629        let url = write_fixture_feed(
630            &dir,
631            serde_json::json!([fake_release_with_installer("v0.1.0")]),
632        );
633        let _ = url;
634        let text = std::fs::read_to_string(dir.path().join("releases.json")).unwrap();
635        let releases = parse_releases(&text).unwrap();
636        assert_eq!(releases.len(), 1);
637        assert_eq!(releases[0].tag_name, "v0.1.0");
638    }
639
640    #[test]
641    fn fake_runner_records_download_and_run() {
642        let runner = FakeRunner {
643            downloaded: RefCell::new(Vec::new()),
644            ran: RefCell::new(Vec::new()),
645            fail_download: false,
646            fail_run: false,
647        };
648        let dir = tempdir().unwrap();
649        let dest = dir.path().join("installer.sh");
650        runner.download("https://example.com/a", &dest).unwrap();
651        runner.run_installer(&dest).unwrap();
652        assert_eq!(runner.downloaded.borrow().len(), 1);
653        assert_eq!(runner.ran.borrow().len(), 1);
654        assert!(dest.exists());
655    }
656
657    #[test]
658    fn fake_runner_surfaces_download_errors() {
659        let runner = FakeRunner {
660            downloaded: RefCell::new(Vec::new()),
661            ran: RefCell::new(Vec::new()),
662            fail_download: true,
663            fail_run: false,
664        };
665        let dir = tempdir().unwrap();
666        let dest = dir.path().join("installer.sh");
667        let err = runner.download("https://example.com/a", &dest).unwrap_err();
668        assert!(err.to_string().contains("simulated download"));
669    }
670
671    #[test]
672    fn fake_runner_surfaces_install_errors() {
673        let runner = FakeRunner {
674            downloaded: RefCell::new(Vec::new()),
675            ran: RefCell::new(Vec::new()),
676            fail_download: false,
677            fail_run: true,
678        };
679        let dir = tempdir().unwrap();
680        let dest = dir.path().join("installer.sh");
681        let err = runner.run_installer(&dest).unwrap_err();
682        assert!(err.to_string().contains("simulated installer"));
683    }
684}