1use crate::types::GithubRelease;
11use anyhow::{anyhow, bail, Context, Result};
12use semver::Version;
13use std::path::{Path, PathBuf};
14use std::time::{Duration, Instant};
15use tracing::{debug, info, warn};
16
/// `tracing` target shared by every event this module emits, so update
/// activity can be filtered independently of the rest of the worker.
const TRACE_TARGET: &str = "studio_worker::update";
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum CheckOutcome {
24 UpToDate { current: Version },
25 NewerAvailable { current: Version, latest: Version },
26}
27
28pub fn fetch_releases(feed_url: &str) -> Result<Vec<GithubRelease>> {
30 let client = reqwest::blocking::Client::builder()
31 .timeout(Duration::from_secs(15))
32 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
33 .build()
34 .context("building reqwest client")?;
35 let started = Instant::now();
36 let response = client
37 .get(feed_url)
38 .header("accept", "application/vnd.github+json")
39 .send()
40 .with_context(|| format!("GET {feed_url}"))?;
41 let status = response.status();
42 let elapsed_ms = started.elapsed().as_millis() as u64;
43 if !status.is_success() {
44 warn!(
45 target: TRACE_TARGET,
46 feed_url,
47 status = status.as_u16(),
48 elapsed_ms,
49 "feed fetch failed"
50 );
51 bail!("feed {feed_url} returned {status}");
52 }
53 let text = response.text()?;
54 let releases = parse_releases(&text)?;
55 debug!(
56 target: TRACE_TARGET,
57 feed_url,
58 status = status.as_u16(),
59 elapsed_ms,
60 releases = releases.len(),
61 "feed fetched"
62 );
63 Ok(releases)
64}
65
66pub fn parse_releases(text: &str) -> Result<Vec<GithubRelease>> {
68 if let Ok(list) = serde_json::from_str::<Vec<GithubRelease>>(text) {
69 return Ok(list);
70 }
71 let single: GithubRelease = serde_json::from_str(text)
72 .with_context(|| "feed JSON is neither an array nor a single release")?;
73 Ok(vec![single])
74}
75
76pub fn parse_tag(tag: &str) -> Option<Version> {
79 Version::parse(tag.strip_prefix('v').unwrap_or(tag)).ok()
80}
81
82pub fn check(feed_url: &str, current: &Version, prerelease_ok: bool) -> Result<CheckOutcome> {
85 let releases = fetch_releases(feed_url)?;
86 Ok(decide(&releases, current, prerelease_ok))
87}
88
89pub fn decide(releases: &[GithubRelease], current: &Version, prerelease_ok: bool) -> CheckOutcome {
92 let latest = releases
93 .iter()
94 .filter(|r| !r.draft)
95 .filter(|r| prerelease_ok || !r.prerelease)
96 .filter_map(|r| parse_tag(&r.tag_name))
97 .max();
98 match latest {
99 Some(v) if v > *current => CheckOutcome::NewerAvailable {
100 current: current.clone(),
101 latest: v,
102 },
103 _ => CheckOutcome::UpToDate {
104 current: current.clone(),
105 },
106 }
107}
108
/// Name of the release asset that installs studio-worker on this platform.
///
/// Windows builds use the PowerShell script; every other target uses the
/// POSIX shell script.
pub fn installer_asset_name() -> &'static str {
    const POWERSHELL_INSTALLER: &str = "studio-worker-installer.ps1";
    const SHELL_INSTALLER: &str = "studio-worker-installer.sh";
    if cfg!(target_os = "windows") {
        POWERSHELL_INSTALLER
    } else {
        SHELL_INSTALLER
    }
}
117
118pub fn resolve_installer_url(release: &GithubRelease) -> Option<&str> {
121 let name = installer_asset_name();
122 release
123 .assets
124 .iter()
125 .find(|a| a.name == name)
126 .map(|a| a.browser_download_url.as_str())
127}
128
129fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
138 match expected {
139 Some(expected) if copied != expected => bail!(
140 "size mismatch: wrote {copied} bytes but the server declared \
141 Content-Length {expected} (installer download truncated or corrupt)"
142 ),
143 _ => Ok(()),
144 }
145}
146
147pub fn apply(feed_url: &str, latest: &Version) -> Result<()> {
150 apply_with(feed_url, latest, &RealRunner)
151}
152
153pub trait UpdateRunner {
157 fn download(&self, url: &str, dest: &Path) -> Result<()>;
158 fn run_installer(&self, installer_path: &Path) -> Result<()>;
159}
160
161pub struct RealRunner;
162
163impl UpdateRunner for RealRunner {
164 fn download(&self, url: &str, dest: &Path) -> Result<()> {
165 validate_installer_download_url(url)?;
166 let client = reqwest::blocking::Client::builder()
167 .timeout(Duration::from_secs(300))
168 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
169 .build()?;
170 let started = Instant::now();
171 let mut response = client.get(url).send()?.error_for_status()?;
172 let expected_len = response.content_length();
176 let mut file = std::fs::File::create(dest)?;
177 let bytes = std::io::copy(&mut response, &mut file)?;
178 verify_download_len(bytes, expected_len)
183 .with_context(|| format!("downloading installer from {url}"))?;
184 info!(
185 target: TRACE_TARGET,
186 url,
187 dest = %dest.display(),
188 bytes,
189 elapsed_ms = started.elapsed().as_millis() as u64,
190 "installer downloaded"
191 );
192 Ok(())
193 }
194
195 fn run_installer(&self, installer_path: &Path) -> Result<()> {
196 if cfg!(target_os = "windows") {
197 let status = std::process::Command::new("powershell")
198 .args([
199 "-NoProfile",
200 "-ExecutionPolicy",
201 "Bypass",
202 "-File",
203 installer_path
204 .to_str()
205 .ok_or_else(|| anyhow!("installer path not UTF-8"))?,
206 ])
207 .status()?;
208 if !status.success() {
209 bail!("installer exited with {status}");
210 }
211 } else {
212 let status = std::process::Command::new("sh")
213 .arg(installer_path)
214 .status()?;
215 if !status.success() {
216 bail!("installer exited with {status}");
217 }
218 }
219 Ok(())
220 }
221}
222
223fn validate_installer_download_url(raw: &str) -> Result<()> {
224 let url = url::Url::parse(raw).with_context(|| format!("invalid installer URL {raw:?}"))?;
225 if url.scheme() == "https" {
226 return Ok(());
227 }
228 if url.scheme() == "http" {
229 if let Some(host) = url.host_str() {
230 if host == "localhost"
231 || host
232 .parse::<std::net::IpAddr>()
233 .is_ok_and(|ip| ip.is_loopback())
234 {
235 return Ok(());
236 }
237 }
238 }
239 bail!("installer URL must use https (loopback http is allowed for tests): {raw}");
240}
241
242pub fn apply_with<R: UpdateRunner>(feed_url: &str, latest: &Version, runner: &R) -> Result<()> {
243 info!(
244 target: TRACE_TARGET,
245 feed_url,
246 latest = %latest,
247 "applying update"
248 );
249 let releases = fetch_releases(feed_url)?;
250 let release = releases
251 .iter()
252 .find(|r| parse_tag(&r.tag_name).as_ref() == Some(latest))
253 .ok_or_else(|| anyhow!("release {latest} not present in feed"))?;
254
255 let url = resolve_installer_url(release).ok_or_else(|| {
256 anyhow!(
257 "release {} is missing installer asset {}",
258 latest,
259 installer_asset_name()
260 )
261 })?;
262
263 let tmp = tempfile::tempdir().context("creating tempdir for installer")?;
264 let installer_path = tmp.path().join(installer_asset_name());
265 info!(
266 target: TRACE_TARGET,
267 url,
268 dest = %installer_path.display(),
269 latest = %latest,
270 "downloading installer"
271 );
272 runner.download(url, &installer_path)?;
273 info!(
274 target: TRACE_TARGET,
275 installer = %installer_path.display(),
276 latest = %latest,
277 "running installer"
278 );
279 runner.run_installer(&installer_path)?;
280 info!(
281 target: TRACE_TARGET,
282 latest = %latest,
283 "installer completed; binary replaced"
284 );
285 Ok(())
286}
287
/// Returns the program path and arguments with which to relaunch this process.
///
/// The program is `argv[0]`, falling back to `studio-worker` when the OS
/// supplied no arguments at all. The remaining arguments are passed through
/// unchanged.
///
/// NOTE(review): `argv[0]` may be a relative path or a bare name that is
/// resolved via `PATH`; confirm that this is intended rather than
/// `std::env::current_exe()`.
pub fn restart_argv() -> (PathBuf, Vec<std::ffi::OsString>) {
    let mut argv = std::env::args_os();
    let program = match argv.next() {
        Some(first) => PathBuf::from(first),
        None => PathBuf::from("studio-worker"),
    };
    let rest = argv.collect::<Vec<std::ffi::OsString>>();
    (program, rest)
}
299
/// Relaunches the worker from the program path and arguments given by
/// [`restart_argv`], so that the freshly installed binary takes over. Never
/// returns.
///
/// - Unix: the current process image is replaced via `exec`.
/// - Other targets: a new child process is spawned and this process exits
///   with status 0; the child is not waited on.
///
/// If the exec or spawn fails, the error is logged through `tracing` and
/// printed to stderr, and the process exits with status 1.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn restart_self() -> ! {
    let (bin, args) = restart_argv();
    info!(
        target: TRACE_TARGET,
        bin = %bin.display(),
        argc = args.len(),
        "restarting into updated binary"
    );
    #[cfg(unix)]
    {
        use std::os::unix::process::CommandExt;
        // `exec` only returns on failure; on success this process image is
        // replaced and nothing below runs.
        let err = std::process::Command::new(&bin).args(&args).exec();
        tracing::error!(
            target: TRACE_TARGET,
            bin = %bin.display(),
            %err,
            "exec into updated binary failed"
        );
        // Also written to stderr in case no tracing subscriber is installed.
        eprintln!("[studio-worker] exec failed: {err}");
        std::process::exit(1);
    }
    #[cfg(not(unix))]
    {
        // No exec equivalent here: start the new binary as a separate process,
        // then exit so the old one goes away.
        match std::process::Command::new(&bin).args(&args).spawn() {
            Ok(_) => std::process::exit(0),
            Err(err) => {
                tracing::error!(
                    target: TRACE_TARGET,
                    bin = %bin.display(),
                    %err,
                    "spawn-restart of updated binary failed"
                );
                eprintln!("[studio-worker] spawn-restart failed: {err}");
                std::process::exit(1);
            }
        }
    }
}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{GithubRelease, GithubReleaseAsset};
    use std::cell::RefCell;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Builds a release; `with_installer` adds this platform's installer asset
    /// pointing at `https://example.com/<tag>`.
    fn rel(tag: &str, prerelease: bool, draft: bool, with_installer: bool) -> GithubRelease {
        let assets = if with_installer {
            vec![GithubReleaseAsset {
                name: installer_asset_name().to_string(),
                browser_download_url: format!("https://example.com/{tag}"),
            }]
        } else {
            vec![]
        };
        GithubRelease {
            tag_name: tag.to_string(),
            prerelease,
            draft,
            assets,
        }
    }

    #[test]
    fn parse_tag_accepts_v_prefix_and_bare() {
        assert_eq!(parse_tag("v1.2.3"), Some(Version::new(1, 2, 3)));
        assert_eq!(parse_tag("1.2.3"), Some(Version::new(1, 2, 3)));
        assert!(parse_tag("garbage").is_none());
    }

    #[test]
    fn parse_releases_accepts_array() {
        let text = serde_json::to_string(&serde_json::json!([
            { "tag_name": "v1.0.0", "prerelease": false, "draft": false, "assets": [] }
        ]))
        .unwrap();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v1.0.0");
    }

    #[test]
    fn parse_releases_accepts_single_object() {
        let text = serde_json::to_string(&serde_json::json!({
            "tag_name": "v2.0.0", "prerelease": false, "draft": false, "assets": []
        }))
        .unwrap();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v2.0.0");
    }

    #[test]
    fn parse_releases_errors_on_garbage() {
        assert!(parse_releases("not json").is_err());
    }

    #[test]
    fn decide_reports_up_to_date_when_no_newer() {
        let releases = vec![rel("v0.1.0", false, false, true)];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert_eq!(
            outcome,
            CheckOutcome::UpToDate {
                current: Version::new(0, 1, 0)
            }
        );
    }

    #[test]
    fn decide_reports_newer_when_higher_present() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.2.0", false, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert_eq!(
            outcome,
            CheckOutcome::NewerAvailable {
                current: Version::new(0, 1, 0),
                latest: Version::new(0, 2, 0),
            }
        );
    }

    #[test]
    fn decide_skips_prereleases_unless_opted_in() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.3.0-rc.1", true, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
        let outcome = decide(&releases, &Version::new(0, 1, 0), true);
        assert!(matches!(outcome, CheckOutcome::NewerAvailable { .. }));
    }

    #[test]
    fn decide_skips_drafts() {
        let releases = vec![
            rel("v0.1.0", false, false, true),
            rel("v0.9.0", false, true, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 1, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
    }

    #[test]
    fn decide_handles_empty_feed() {
        let outcome = decide(&[], &Version::new(1, 0, 0), false);
        assert!(matches!(outcome, CheckOutcome::UpToDate { .. }));
    }

    #[test]
    fn decide_skips_malformed_tags() {
        let releases = vec![
            rel("garbage", false, false, true),
            rel("v0.1.0", false, false, true),
        ];
        let outcome = decide(&releases, &Version::new(0, 0, 1), false);
        match outcome {
            CheckOutcome::NewerAvailable { latest, .. } => {
                assert_eq!(latest, Version::new(0, 1, 0))
            }
            _ => panic!("expected newer"),
        }
    }

    #[test]
    fn installer_asset_name_matches_platform() {
        let name = installer_asset_name();
        if cfg!(target_os = "windows") {
            assert_eq!(name, "studio-worker-installer.ps1");
        } else {
            assert_eq!(name, "studio-worker-installer.sh");
        }
    }

    #[test]
    fn resolve_installer_url_finds_the_right_asset() {
        let release = rel("v1.0.0", false, false, true);
        let url = resolve_installer_url(&release).unwrap();
        assert_eq!(url, "https://example.com/v1.0.0");
    }

    #[test]
    fn resolve_installer_url_returns_none_when_missing() {
        let release = rel("v1.0.0", false, false, false);
        assert!(resolve_installer_url(&release).is_none());
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2048, Some(2048)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_installer() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    #[test]
    fn verify_download_len_rejects_overlong_installer() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    #[test]
    fn validate_installer_download_url_allows_https() {
        validate_installer_download_url("https://github.com/owner/repo/releases/download/x/i.sh")
            .unwrap();
    }

    #[test]
    fn validate_installer_download_url_allows_loopback_http_for_tests() {
        validate_installer_download_url("http://127.0.0.1:1234/i.sh").unwrap();
        validate_installer_download_url("http://localhost:1234/i.sh").unwrap();
    }

    #[test]
    fn validate_installer_download_url_rejects_remote_http() {
        let err = validate_installer_download_url("http://example.com/i.sh")
            .unwrap_err()
            .to_string();
        assert!(err.contains("https"), "got: {err}");
    }

    #[test]
    fn restart_argv_uses_current_exe_and_args() {
        let (bin, _args) = restart_argv();
        assert!(!bin.as_os_str().is_empty());
    }

    /// Test double for [`UpdateRunner`] that records every call and can be
    /// told to fail either step.
    struct FakeRunner {
        downloaded: RefCell<Vec<(String, PathBuf)>>,
        ran: RefCell<Vec<PathBuf>>,
        fail_download: bool,
        fail_run: bool,
    }

    impl FakeRunner {
        fn new(fail_download: bool, fail_run: bool) -> Self {
            Self {
                downloaded: RefCell::new(Vec::new()),
                ran: RefCell::new(Vec::new()),
                fail_download,
                fail_run,
            }
        }
    }

    impl UpdateRunner for FakeRunner {
        fn download(&self, url: &str, dest: &Path) -> Result<()> {
            self.downloaded
                .borrow_mut()
                .push((url.to_string(), dest.to_path_buf()));
            if self.fail_download {
                bail!("simulated download failure");
            }
            std::fs::write(dest, b"#!/bin/sh\necho fake installer\n").unwrap();
            Ok(())
        }
        fn run_installer(&self, installer_path: &Path) -> Result<()> {
            self.ran.borrow_mut().push(installer_path.to_path_buf());
            if self.fail_run {
                bail!("simulated installer failure");
            }
            Ok(())
        }
    }

    /// Serves `body` as a single HTTP/1.1 200 JSON response on an ephemeral
    /// loopback port and returns the URL to fetch it from. The server thread
    /// answers exactly one connection. `fetch_releases` speaks HTTP(S) only,
    /// so a `file://` fixture cannot be used.
    fn serve_feed_once(body: String) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || {
            let Ok((mut stream, _)) = listener.accept() else {
                return;
            };
            // Drain the request head; a GET carries no body.
            let mut head = Vec::new();
            let mut chunk = [0u8; 1024];
            while !head.windows(4).any(|w| w == b"\r\n\r\n") {
                match stream.read(&mut chunk) {
                    Ok(0) | Err(_) => break,
                    Ok(n) => head.extend_from_slice(&chunk[..n]),
                }
            }
            let response = format!(
                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\n\
                 content-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            let _ = stream.write_all(response.as_bytes());
        });
        format!("http://{addr}/releases")
    }

    fn fake_release_with_installer(tag: &str) -> serde_json::Value {
        serde_json::json!({
            "tag_name": tag,
            "prerelease": false,
            "draft": false,
            "assets": [{
                "name": installer_asset_name(),
                "browser_download_url": format!("https://example.invalid/{tag}/{}", installer_asset_name()),
            }],
        })
    }

    #[test]
    fn apply_with_errors_when_release_missing() {
        let feed = serve_feed_once(
            serde_json::json!([fake_release_with_installer("v0.1.0")]).to_string(),
        );
        let runner = FakeRunner::new(false, false);
        let err = apply_with(&feed, &Version::new(9, 9, 9), &runner).unwrap_err();
        assert!(err.to_string().contains("not present in feed"), "got: {err}");
        assert!(runner.downloaded.borrow().is_empty());
        assert!(runner.ran.borrow().is_empty());
    }

    #[test]
    fn apply_with_errors_when_installer_asset_missing() {
        let feed = serve_feed_once(
            serde_json::json!([
                { "tag_name": "v0.2.0", "prerelease": false, "draft": false, "assets": [] }
            ])
            .to_string(),
        );
        let runner = FakeRunner::new(false, false);
        let err = apply_with(&feed, &Version::new(0, 2, 0), &runner).unwrap_err();
        assert!(err.to_string().contains("missing installer asset"), "got: {err}");
        assert!(runner.downloaded.borrow().is_empty());
        assert!(runner.ran.borrow().is_empty());
    }

    #[test]
    fn apply_with_downloads_then_runs_the_installer() {
        let feed = serve_feed_once(
            serde_json::json!([
                fake_release_with_installer("v0.1.0"),
                fake_release_with_installer("v0.2.0"),
            ])
            .to_string(),
        );
        let runner = FakeRunner::new(false, false);
        apply_with(&feed, &Version::new(0, 2, 0), &runner).unwrap();
        let downloaded = runner.downloaded.borrow();
        assert_eq!(downloaded.len(), 1);
        assert_eq!(
            downloaded[0].0,
            format!("https://example.invalid/v0.2.0/{}", installer_asset_name())
        );
        // The installer that ran is exactly the file that was downloaded.
        assert_eq!(*runner.ran.borrow(), vec![downloaded[0].1.clone()]);
    }

    #[test]
    fn apply_with_does_not_run_installer_when_download_fails() {
        let feed = serve_feed_once(
            serde_json::json!([fake_release_with_installer("v0.2.0")]).to_string(),
        );
        let runner = FakeRunner::new(true, false);
        let err = apply_with(&feed, &Version::new(0, 2, 0), &runner).unwrap_err();
        assert!(err.to_string().contains("simulated download"), "got: {err}");
        assert_eq!(runner.downloaded.borrow().len(), 1);
        assert!(runner.ran.borrow().is_empty());
    }

    #[test]
    fn apply_with_surfaces_installer_failure() {
        let feed = serve_feed_once(
            serde_json::json!([fake_release_with_installer("v0.2.0")]).to_string(),
        );
        let runner = FakeRunner::new(false, true);
        let err = apply_with(&feed, &Version::new(0, 2, 0), &runner).unwrap_err();
        assert!(err.to_string().contains("simulated installer"), "got: {err}");
        assert_eq!(runner.ran.borrow().len(), 1);
    }

    #[test]
    fn fake_feed_round_trips_through_parse_releases() {
        let text = serde_json::json!([fake_release_with_installer("v0.1.0")]).to_string();
        let releases = parse_releases(&text).unwrap();
        assert_eq!(releases.len(), 1);
        assert_eq!(releases[0].tag_name, "v0.1.0");
        assert_eq!(
            resolve_installer_url(&releases[0]),
            Some(format!("https://example.invalid/v0.1.0/{}", installer_asset_name()).as_str())
        );
    }

    #[test]
    fn fake_runner_records_download_and_run() {
        let runner = FakeRunner::new(false, false);
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        runner.download("https://example.com/a", &dest).unwrap();
        runner.run_installer(&dest).unwrap();
        assert_eq!(runner.downloaded.borrow().len(), 1);
        assert_eq!(runner.ran.borrow().len(), 1);
        assert!(dest.exists());
    }

    #[test]
    fn fake_runner_surfaces_download_errors() {
        let runner = FakeRunner::new(true, false);
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        let err = runner.download("https://example.com/a", &dest).unwrap_err();
        assert!(err.to_string().contains("simulated download"));
    }

    #[test]
    fn fake_runner_surfaces_install_errors() {
        let runner = FakeRunner::new(false, true);
        let dir = tempdir().unwrap();
        let dest = dir.path().join("installer.sh");
        let err = runner.run_installer(&dest).unwrap_err();
        assert!(err.to_string().contains("simulated installer"));
    }
}