1use crate::types::GithubRelease;
11use anyhow::{anyhow, bail, Context, Result};
12use semver::Version;
13use std::path::{Path, PathBuf};
14use std::time::{Duration, Instant};
15use tracing::{debug, info, warn};
16
/// `tracing` target used by every log event emitted from this module, so
/// update activity can be filtered separately from the rest of the worker.
const TRACE_TARGET: &str = "studio_worker::update";
21
/// Result of comparing the running version against the release feed
/// (produced by [`decide`] / [`check`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckOutcome {
    /// No eligible release in the feed is strictly newer than `current`.
    UpToDate { current: Version },
    /// `latest` is the highest eligible release and is strictly newer than `current`.
    NewerAvailable { current: Version, latest: Version },
}
27
28pub fn fetch_releases(feed_url: &str) -> Result<Vec<GithubRelease>> {
30 let client = reqwest::blocking::Client::builder()
31 .timeout(Duration::from_secs(15))
32 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
33 .build()
34 .context("building reqwest client")?;
35 let started = Instant::now();
36 let response = client
37 .get(feed_url)
38 .header("accept", "application/vnd.github+json")
39 .send()
40 .with_context(|| format!("GET {feed_url}"))?;
41 let status = response.status();
42 let elapsed_ms = started.elapsed().as_millis() as u64;
43 if !status.is_success() {
44 warn!(
45 target: TRACE_TARGET,
46 feed_url,
47 status = status.as_u16(),
48 elapsed_ms,
49 "feed fetch failed"
50 );
51 bail!("feed {feed_url} returned {status}");
52 }
53 let text = response.text()?;
54 let releases = parse_releases(&text)?;
55 debug!(
56 target: TRACE_TARGET,
57 feed_url,
58 status = status.as_u16(),
59 elapsed_ms,
60 releases = releases.len(),
61 "feed fetched"
62 );
63 Ok(releases)
64}
65
66pub fn parse_releases(text: &str) -> Result<Vec<GithubRelease>> {
68 if let Ok(list) = serde_json::from_str::<Vec<GithubRelease>>(text) {
69 return Ok(list);
70 }
71 let single: GithubRelease = serde_json::from_str(text)
72 .with_context(|| "feed JSON is neither an array nor a single release")?;
73 Ok(vec![single])
74}
75
76pub fn parse_tag(tag: &str) -> Option<Version> {
83 let candidates = [
84 tag,
85 tag.strip_prefix('v').unwrap_or(tag),
86 tag.rsplit_once("-v").map(|(_, v)| v).unwrap_or(tag),
87 ];
88 candidates.iter().find_map(|c| Version::parse(c).ok())
89}
90
91pub fn check(feed_url: &str, current: &Version, prerelease_ok: bool) -> Result<CheckOutcome> {
94 let releases = fetch_releases(feed_url)?;
95 Ok(decide(&releases, current, prerelease_ok))
96}
97
98pub fn decide(releases: &[GithubRelease], current: &Version, prerelease_ok: bool) -> CheckOutcome {
101 let latest = releases
102 .iter()
103 .filter(|r| !r.draft)
104 .filter(|r| prerelease_ok || !r.prerelease)
105 .filter_map(|r| parse_tag(&r.tag_name))
106 .max();
107 match latest {
108 Some(v) if v > *current => CheckOutcome::NewerAvailable {
109 current: current.clone(),
110 latest: v,
111 },
112 _ => CheckOutcome::UpToDate {
113 current: current.clone(),
114 },
115 }
116}
117
/// Name of the release asset holding this platform's installer script.
///
/// Windows builds use the PowerShell installer; every other target uses the
/// POSIX shell installer. The choice is fixed at compile time.
pub fn installer_asset_name() -> &'static str {
    #[cfg(target_os = "windows")]
    const ASSET: &str = "studio-worker-installer.ps1";
    #[cfg(not(target_os = "windows"))]
    const ASSET: &str = "studio-worker-installer.sh";
    ASSET
}
126
127pub fn resolve_installer_url(release: &GithubRelease) -> Option<&str> {
130 let name = installer_asset_name();
131 release
132 .assets
133 .iter()
134 .find(|a| a.name == name)
135 .map(|a| a.browser_download_url.as_str())
136}
137
138fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
147 match expected {
148 Some(expected) if copied != expected => bail!(
149 "size mismatch: wrote {copied} bytes but the server declared \
150 Content-Length {expected} (installer download truncated or corrupt)"
151 ),
152 _ => Ok(()),
153 }
154}
155
156pub fn apply(feed_url: &str, latest: &Version) -> Result<()> {
159 apply_with(feed_url, latest, &RealRunner)
160}
161
/// Side-effecting steps of an update: fetching the installer and executing it.
///
/// [`apply_with`] is generic over this trait, so tests can drive it with a fake
/// runner while production uses [`RealRunner`].
pub trait UpdateRunner {
    /// Fetches `url` and stores the response body at `dest`.
    fn download(&self, url: &str, dest: &Path) -> Result<()>;
    /// Executes the installer previously saved at `installer_path`.
    fn run_installer(&self, installer_path: &Path) -> Result<()>;
}
169
170pub struct RealRunner;
171
172impl UpdateRunner for RealRunner {
173 fn download(&self, url: &str, dest: &Path) -> Result<()> {
174 validate_installer_download_url(url)?;
175 let client = reqwest::blocking::Client::builder()
176 .timeout(Duration::from_secs(300))
177 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
178 .build()?;
179 let started = Instant::now();
180 let mut response = client.get(url).send()?.error_for_status()?;
181 let expected_len = response.content_length();
185 let mut file = std::fs::File::create(dest)?;
186 let bytes = std::io::copy(&mut response, &mut file)?;
187 verify_download_len(bytes, expected_len)
192 .with_context(|| format!("downloading installer from {url}"))?;
193 info!(
194 target: TRACE_TARGET,
195 url,
196 dest = %dest.display(),
197 bytes,
198 elapsed_ms = started.elapsed().as_millis() as u64,
199 "installer downloaded"
200 );
201 Ok(())
202 }
203
204 fn run_installer(&self, installer_path: &Path) -> Result<()> {
205 if cfg!(target_os = "windows") {
206 let status = std::process::Command::new("powershell")
207 .args([
208 "-NoProfile",
209 "-ExecutionPolicy",
210 "Bypass",
211 "-File",
212 installer_path
213 .to_str()
214 .ok_or_else(|| anyhow!("installer path not UTF-8"))?,
215 ])
216 .status()?;
217 if !status.success() {
218 bail!("installer exited with {status}");
219 }
220 } else {
221 let status = std::process::Command::new("sh")
222 .arg(installer_path)
223 .status()?;
224 if !status.success() {
225 bail!("installer exited with {status}");
226 }
227 }
228 Ok(())
229 }
230}
231
232fn validate_installer_download_url(raw: &str) -> Result<()> {
233 let url = url::Url::parse(raw).with_context(|| format!("invalid installer URL {raw:?}"))?;
234 if url.scheme() == "https" {
235 return Ok(());
236 }
237 if url.scheme() == "http" {
238 if let Some(host) = url.host_str() {
239 if host == "localhost"
240 || host
241 .parse::<std::net::IpAddr>()
242 .is_ok_and(|ip| ip.is_loopback())
243 {
244 return Ok(());
245 }
246 }
247 }
248 bail!("installer URL must use https (loopback http is allowed for tests): {raw}");
249}
250
251pub fn apply_with<R: UpdateRunner>(feed_url: &str, latest: &Version, runner: &R) -> Result<()> {
252 info!(
253 target: TRACE_TARGET,
254 feed_url,
255 latest = %latest,
256 "applying update"
257 );
258 let releases = fetch_releases(feed_url)?;
259 let release = releases
260 .iter()
261 .find(|r| parse_tag(&r.tag_name).as_ref() == Some(latest))
262 .ok_or_else(|| anyhow!("release {latest} not present in feed"))?;
263
264 let url = resolve_installer_url(release).ok_or_else(|| {
265 anyhow!(
266 "release {} is missing installer asset {}",
267 latest,
268 installer_asset_name()
269 )
270 })?;
271
272 let tmp = tempfile::tempdir().context("creating tempdir for installer")?;
273 let installer_path = tmp.path().join(installer_asset_name());
274 info!(
275 target: TRACE_TARGET,
276 url,
277 dest = %installer_path.display(),
278 latest = %latest,
279 "downloading installer"
280 );
281 runner.download(url, &installer_path)?;
282 info!(
283 target: TRACE_TARGET,
284 installer = %installer_path.display(),
285 latest = %latest,
286 "running installer"
287 );
288 runner.run_installer(&installer_path)?;
289 info!(
290 target: TRACE_TARGET,
291 latest = %latest,
292 "installer completed; binary replaced"
293 );
294 Ok(())
295}
296
/// Returns the program path and remaining arguments of the current process.
///
/// The path is `argv[0]` as given by the OS, falling back to `studio-worker`
/// when argv is empty. The second element holds every later argument, in order.
/// NOTE(review): `argv[0]` is used rather than `std::env::current_exe()`,
/// presumably so a restart picks up the freshly installed binary at that path
/// rather than the replaced image. Confirm before changing.
pub fn restart_argv() -> (PathBuf, Vec<std::ffi::OsString>) {
    let mut argv = std::env::args_os();
    match argv.next() {
        Some(program) => (PathBuf::from(program), argv.collect()),
        None => (PathBuf::from("studio-worker"), argv.collect()),
    }
}
308
/// Restarts into the (freshly installed) program with the current arguments.
/// Never returns.
///
/// The program path and arguments come from [`restart_argv`].
/// - Unix: replaces the current process image via `exec`. `exec` returns only
///   on failure; the error is then logged and printed to stderr, and the
///   process exits with status 1.
/// - Other platforms: spawns the program as a new process without waiting for
///   it, then exits this process with status 0. If the spawn fails, the error
///   is logged and printed, and the process exits with status 1.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn restart_self() -> ! {
    let (bin, args) = restart_argv();
    info!(
        target: TRACE_TARGET,
        bin = %bin.display(),
        argc = args.len(),
        "restarting into updated binary"
    );
    #[cfg(unix)]
    {
        use std::os::unix::process::CommandExt;
        // On success `exec` replaces this process, so reaching the next
        // statement means it failed.
        let err = std::process::Command::new(&bin).args(&args).exec();
        tracing::error!(
            target: TRACE_TARGET,
            bin = %bin.display(),
            %err,
            "exec into updated binary failed"
        );
        eprintln!("[studio-worker] exec failed: {err}");
        std::process::exit(1);
    }
    #[cfg(not(unix))]
    {
        // No `exec` here: start the new binary as a separate process, drop its
        // handle without waiting, and exit this one.
        match std::process::Command::new(&bin).args(&args).spawn() {
            Ok(_) => std::process::exit(0),
            Err(err) => {
                tracing::error!(
                    target: TRACE_TARGET,
                    bin = %bin.display(),
                    %err,
                    "spawn-restart of updated binary failed"
                );
                eprintln!("[studio-worker] spawn-restart failed: {err}");
                std::process::exit(1);
            }
        }
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{GithubRelease, GithubReleaseAsset};
    use std::cell::RefCell;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Builds an in-memory release. With `with_installer`, it carries this
    /// platform's installer asset pointing at `https://example.com/{tag}`.
    fn rel(tag: &str, prerelease: bool, draft: bool, with_installer: bool) -> GithubRelease {
        let assets = if with_installer {
            vec![GithubReleaseAsset {
                name: installer_asset_name().to_string(),
                browser_download_url: format!("https://example.com/{tag}"),
            }]
        } else {
            vec![]
        };
        GithubRelease {
            tag_name: tag.to_string(),
            prerelease,
            draft,
            assets,
        }
    }

    #[test]
    fn parse_tag_accepts_v_prefix_and_bare() {
        assert_eq!(parse_tag("v1.2.3"), Some(Version::new(1, 2, 3)));
        assert_eq!(parse_tag("1.2.3"), Some(Version::new(1, 2, 3)));
        assert!(parse_tag("garbage").is_none());
    }

    #[test]
    fn parse_tag_accepts_component_prefixed_release_tags() {
        assert_eq!(
            parse_tag("studio-worker-v0.4.2"),
            Some(Version::new(0, 4, 2))
        );
        assert_eq!(
            parse_tag("studio-worker-v1.10.0"),
            Some(Version::new(1, 10, 0))
        );
        assert_eq!(
            parse_tag("studio-worker-v0.5.0-rc.1"),
            Version::parse("0.5.0-rc.1").ok()
        );
    }

    #[test]
    fn decide_detects_newer_with_component_prefixed_tags() {
        let releases = vec![
            rel("studio-worker-v0.4.1", false, false, true),
            rel("studio-worker-v0.4.2", false, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 4, 1), false);
        assert_eq!(
            outcome,
            CheckOutcome::NewerAvailable {
                current: Version::new(0, 4, 1),
                latest: Version::new(0, 4, 2),
            }
        );
    }

    #[test]
    fn parse_releases_accepts_array() {
        let text = serde_json::to_string(&serde_json::json!([
            { "tag_name": "v1.0.0", "prerelease": false, "draft": false, "assets": [] }
        ]))
        .unwrap();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v1.0.0");
    }

    #[test]
    fn parse_releases_accepts_single_object() {
        let text = serde_json::to_string(&serde_json::json!({
            "tag_name": "v2.0.0", "prerelease": false, "draft": false, "assets": []
        }))
        .unwrap();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v2.0.0");
    }

    #[test]
    fn parse_releases_errors_on_garbage() {
        assert!(parse_releases("not json").is_err());
    }

    #[test]
    fn decide_reports_up_to_date_when_no_newer() {
        let releases = vec![rel("v0.1.0", false, false, true)];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert_eq!(
            outcome,
            CheckOutcome::UpToDate {
                current: Version::new(0, 1, 0)
            }
        );
    }

    #[test]
    fn decide_reports_newer_when_higher_present() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.2.0", false, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert_eq!(
            outcome,
            CheckOutcome::NewerAvailable {
                current: Version::new(0, 1, 0),
                latest: Version::new(0, 2, 0),
            }
        );
    }

    #[test]
    fn decide_skips_prereleases_unless_opted_in() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.3.0-rc.1", true, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
        let outcome = decide(&releases, &Version::new(0, 1, 0), true);
        assert!(matches!(outcome, CheckOutcome::NewerAvailable { .. }));
    }

    #[test]
    fn decide_skips_drafts() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.9.0", false, true, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
    }

    #[test]
    fn decide_handles_empty_feed() {
        let outcome = decide(&[], &Version::new(1, 0, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
    }

    #[test]
    fn decide_skips_malformed_tags() {
        let releases = vec![
            rel("garbage", false, false, true),
            rel("v0.1.0", false, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 0, 1), false);
        match outcome {
            CheckOutcome::NewerAvailable { latest, .. } => {
                assert_eq!(latest, Version::new(0, 1, 0))
            }
            _ => panic!("expected newer"),
        }
    }

    #[test]
    fn installer_asset_name_matches_platform() {
        let name = installer_asset_name();
        if cfg!(target_os = "windows") {
            assert_eq!(name, "studio-worker-installer.ps1");
        } else {
            assert_eq!(name, "studio-worker-installer.sh");
        }
    }

    #[test]
    fn resolve_installer_url_finds_the_right_asset() {
        let release = rel("v1.0.0", false, false, true);
        let url = resolve_installer_url(&release).unwrap();
        assert_eq!(url, "https://example.com/v1.0.0");
    }

    #[test]
    fn resolve_installer_url_returns_none_when_missing() {
        let release = rel("v1.0.0", false, false, false);
        assert!(resolve_installer_url(&release).is_none());
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2048, Some(2048)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_installer() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    #[test]
    fn verify_download_len_rejects_overlong_installer() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    #[test]
    fn validate_installer_download_url_allows_https() {
        validate_installer_download_url("https://github.com/owner/repo/releases/download/x/i.sh")
            .unwrap();
    }

    #[test]
    fn validate_installer_download_url_allows_loopback_http_for_tests() {
        validate_installer_download_url("http://127.0.0.1:1234/i.sh").unwrap();
        validate_installer_download_url("http://localhost:1234/i.sh").unwrap();
    }

    #[test]
    fn validate_installer_download_url_rejects_remote_http() {
        let err = validate_installer_download_url("http://example.com/i.sh")
            .unwrap_err()
            .to_string();
        assert!(err.contains("https"), "got: {err}");
    }

    /// `restart_argv` must hand back exactly this process's argv: argv[0] as
    /// the binary path and the rest, in order, as arguments.
    #[test]
    fn restart_argv_mirrors_process_argv() {
        let (bin, args) = restart_argv();
        let mut argv = std::env::args_os();
        let first = argv.next().expect("the test harness always has argv[0]");
        assert_eq!(bin, PathBuf::from(first));
        assert_eq!(args, argv.collect::<Vec<_>>());
        assert!(!bin.as_os_str().is_empty());
    }

    /// Test double for [`UpdateRunner`] that records every call and can be
    /// told to fail either step.
    struct FakeRunner {
        downloaded: RefCell<Vec<(String, PathBuf)>>,
        ran: RefCell<Vec<PathBuf>>,
        fail_download: bool,
        fail_run: bool,
    }

    impl UpdateRunner for FakeRunner {
        fn download(&self, url: &str, dest: &Path) -> Result<()> {
            self.downloaded
                .borrow_mut()
                .push((url.to_string(), dest.to_path_buf()));
            if self.fail_download {
                bail!("simulated download failure");
            }
            std::fs::write(dest, b"#!/bin/sh\necho fake installer\n").unwrap();
            Ok(())
        }
        fn run_installer(&self, installer_path: &Path) -> Result<()> {
            self.ran.borrow_mut().push(installer_path.to_path_buf());
            if self.fail_run {
                bail!("simulated installer failure");
            }
            Ok(())
        }
    }

    /// Writes `releases` as `releases.json` inside `dir` and returns its path.
    /// (Returns a path rather than a `file://` URL: reqwest-backed
    /// `fetch_releases` only speaks HTTP(S), so such a URL is unusable.)
    fn write_fixture_feed(dir: &tempfile::TempDir, releases: serde_json::Value) -> PathBuf {
        let path = dir.path().join("releases.json");
        std::fs::write(&path, releases.to_string()).unwrap();
        path
    }

    fn fake_release_with_installer(tag: &str) -> serde_json::Value {
        serde_json::json!({
            "tag_name": tag,
            "prerelease": false,
            "draft": false,
            "assets": [{
                "name": installer_asset_name(),
                "browser_download_url": format!("https://example.invalid/{tag}/{}", installer_asset_name()),
            }],
        })
    }

    /// `apply_with` needs an HTTP(S) feed, so this pins only the version lookup
    /// it performs: a version absent from the feed must not be found. (The
    /// previous name implied `apply_with` itself was exercised; it is not.)
    #[test]
    fn release_lookup_misses_version_absent_from_feed() {
        let releases: Vec<GithubRelease> = vec![rel("v0.1.0", false, false, true)];
        let missing = Version::new(9, 9, 9);
        let url = releases
            .iter()
            .find(|r| parse_tag(&r.tag_name).as_ref() == Some(&missing));
        assert!(url.is_none(), "v9.9.9 should not be in the fixture");
    }

    #[test]
    fn writing_a_fake_feed_round_trips_through_parse_releases() {
        let dir = tempdir().unwrap();
        let path = write_fixture_feed(
            &dir,
            serde_json::json!([fake_release_with_installer("v0.1.0")]),
        );
        let text = std::fs::read_to_string(&path).unwrap();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v0.1.0");
    }

    #[test]
    fn fake_runner_records_download_and_run() {
        let runner = FakeRunner {
            downloaded: RefCell::new(Vec::new()),
            ran: RefCell::new(Vec::new()),
            fail_download: false,
            fail_run: false,
        };
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        runner.download("https://example.com/a", &dest).unwrap();
        runner.run_installer(&dest).unwrap();
        assert_eq!(runner.downloaded.borrow().len(), 1);
        assert_eq!(runner.ran.borrow().len(), 1);
        assert!(dest.exists());
    }

    #[test]
    fn fake_runner_surfaces_download_errors() {
        let runner = FakeRunner {
            downloaded: RefCell::new(Vec::new()),
            ran: RefCell::new(Vec::new()),
            fail_download: true,
            fail_run: false,
        };
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        let err = runner.download("https://example.com/a", &dest).unwrap_err();
        assert!(err.to_string().contains("simulated download"));
    }

    #[test]
    fn fake_runner_surfaces_install_errors() {
        let runner = FakeRunner {
            downloaded: RefCell::new(Vec::new()),
            ran: RefCell::new(Vec::new()),
            fail_download: false,
            fail_run: true,
        };
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        let err = runner.run_installer(&dest).unwrap_err();
        assert!(err.to_string().contains("simulated installer"));
    }
}