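//! Self-update support for `stmo-cli` (`stmo_cli/update_checker.rs`): looks up the latest
//! published version on crates.io and, for binaries installed with `cargo install`, upgrades
//! and relaunches the program.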

1use serde::{Deserialize, Serialize};
2use std::path::PathBuf;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// How long, in seconds, a cached "latest version" answer is trusted (one day).
const CACHE_MAX_AGE_SECS: u64 = 24 * 60 * 60;
/// Timeout, in seconds, configured on the HTTP client used to query crates.io.
const FETCH_TIMEOUT_SECS: u64 = 2;
8
9#[derive(Debug, Serialize, Deserialize)]
10struct VersionCache {
11    latest_version: String,
12    last_checked: u64,
13}
14
15#[must_use]
16pub fn installed_via_cargo() -> bool {
17    let Ok(exe) = std::env::current_exe() else {
18        return false;
19    };
20    let Ok(exe) = exe.canonicalize() else {
21        return false;
22    };
23    let Some(home) = dirs::home_dir() else {
24        return false;
25    };
26    let cargo_bin = home.join(".cargo").join("bin");
27    exe.starts_with(cargo_bin)
28}
29
30fn cache_path() -> Option<PathBuf> {
31    Some(dirs::cache_dir()?.join("stmo-cli").join("version-check.json"))
32}
33
34fn read_cache() -> Option<VersionCache> {
35    let path = cache_path()?;
36    let data = std::fs::read_to_string(path).ok()?;
37    serde_json::from_str(&data).ok()
38}
39
40fn write_cache(version: &str, timestamp: u64) {
41    let Some(path) = cache_path() else { return };
42    let Some(parent) = path.parent() else { return };
43    let _ = std::fs::create_dir_all(parent);
44    let cache = VersionCache {
45        latest_version: version.to_string(),
46        last_checked: timestamp,
47    };
48    if let Ok(data) = serde_json::to_string(&cache) {
49        let _ = std::fs::write(path, data);
50    }
51}
52
53fn should_check(cache: Option<&VersionCache>) -> bool {
54    let Some(cache) = cache else { return true };
55    let now = SystemTime::now()
56        .duration_since(UNIX_EPOCH)
57        .map(|d| d.as_secs())
58        .unwrap_or(0);
59    now.saturating_sub(cache.last_checked) > CACHE_MAX_AGE_SECS
60}
61
62async fn fetch_latest_version(base_url: &str) -> Option<String> {
63    let url = format!("{base_url}/api/v1/crates/stmo-cli");
64    let user_agent = format!("stmo-cli/{CURRENT_VERSION}");
65    let client = reqwest::Client::builder()
66        .timeout(std::time::Duration::from_secs(FETCH_TIMEOUT_SECS))
67        .build()
68        .ok()?;
69    let resp = client
70        .get(&url)
71        .header("User-Agent", user_agent)
72        .send()
73        .await
74        .ok()?;
75    if !resp.status().is_success() {
76        return None;
77    }
78    let body: serde_json::Value = resp.json().await.ok()?;
79    body["crate"]["max_version"]
80        .as_str()
81        .map(str::to_string)
82}
83
84#[must_use]
85pub fn is_newer(latest: &str, current: &str) -> bool {
86    let Ok(latest) = latest.parse::<semver::Version>() else {
87        return false;
88    };
89    let Ok(current) = current.parse::<semver::Version>() else {
90        return false;
91    };
92    latest > current
93}
94
95#[allow(dead_code)]
96pub async fn check_for_update_from(base_url: &str) -> Option<String> {
97    let latest = fetch_latest_version(base_url).await?;
98    if is_newer(&latest, CURRENT_VERSION) {
99        Some(latest)
100    } else {
101        None
102    }
103}
104
105pub async fn check_and_auto_update() {
106    if !installed_via_cargo() {
107        return;
108    }
109
110    let cache = read_cache();
111    let latest = if should_check(cache.as_ref()) {
112        let now = SystemTime::now()
113            .duration_since(UNIX_EPOCH)
114            .map(|d| d.as_secs())
115            .unwrap_or(0);
116        let Some(version) = fetch_latest_version("https://crates.io").await else {
117            return;
118        };
119        write_cache(&version, now);
120        version
121    } else {
122        let Some(cache) = cache else { return };
123        cache.latest_version
124    };
125
126    if !is_newer(&latest, CURRENT_VERSION) {
127        return;
128    }
129
130    eprintln!("Updating stmo-cli {CURRENT_VERSION} → {latest}...");
131
132    let status = std::process::Command::new("cargo")
133        .args(["install", "stmo-cli"])
134        .status();
135
136    match status {
137        Ok(s) if s.success() => {
138            let args: Vec<String> = std::env::args().collect();
139            #[cfg(unix)]
140            {
141                use std::os::unix::process::CommandExt as _;
142                let mut cmd = std::process::Command::new(&args[0]);
143                cmd.args(&args[1..]);
144                let err = cmd.exec();
145                eprintln!("Failed to re-exec: {err}");
146            }
147            #[cfg(not(unix))]
148            {
149                let _ = std::process::Command::new(&args[0])
150                    .args(&args[1..])
151                    .status();
152                std::process::exit(0);
153            }
154        }
155        _ => {
156            eprintln!("Update failed, continuing with current version.");
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Current Unix time in whole seconds (0 if the clock is before the epoch).
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs())
    }

    /// A cache record for version "0.3.0" checked at `last_checked`.
    fn cache_checked_at(last_checked: u64) -> VersionCache {
        VersionCache {
            latest_version: "0.3.0".to_string(),
            last_checked,
        }
    }

    #[test]
    fn is_newer_detects_patch_bump() {
        assert!(is_newer("0.3.1", "0.3.0"));
    }

    #[test]
    fn is_newer_detects_minor_bump() {
        assert!(is_newer("0.4.0", "0.3.0"));
    }

    #[test]
    fn is_newer_detects_major_bump() {
        assert!(is_newer("1.0.0", "0.3.0"));
    }

    #[test]
    fn is_newer_same_version_false() {
        assert!(!is_newer("0.3.0", "0.3.0"));
    }

    #[test]
    fn is_newer_older_version_false() {
        assert!(!is_newer("0.2.0", "0.3.0"));
    }

    #[test]
    fn is_newer_invalid_latest_false() {
        assert!(!is_newer("not-a-version", "0.3.0"));
    }

    #[test]
    fn is_newer_invalid_current_false() {
        assert!(!is_newer("0.4.0", "not-a-version"));
    }

    #[test]
    fn should_check_returns_true_when_no_cache() {
        assert!(should_check(None));
    }

    #[test]
    fn should_check_returns_true_for_stale_cache() {
        assert!(should_check(Some(&cache_checked_at(0))));
    }

    #[test]
    fn should_check_returns_false_for_fresh_cache() {
        assert!(!should_check(Some(&cache_checked_at(unix_now()))));
    }

    #[test]
    fn cache_roundtrip() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("version-check.json");

        let original = VersionCache {
            latest_version: "1.2.3".to_string(),
            last_checked: 9999,
        };
        let json = serde_json::to_string(&original).unwrap();
        std::fs::write(&file, json).unwrap();

        let text = std::fs::read_to_string(&file).unwrap();
        let loaded: VersionCache = serde_json::from_str(&text).unwrap();
        assert_eq!(loaded.latest_version, "1.2.3");
        assert_eq!(loaded.last_checked, 9999);
    }
}