Skip to main content

st/
updater.rs

1// -----------------------------------------------------------------------------
2// Self-Update Module for Smart Tree
3// Checks for updates from GitHub releases and installs new versions
4// -----------------------------------------------------------------------------
5
6use anyhow::{bail, Context, Result};
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::fs;
10use std::io::{self, Write};
11use std::path::{Path, PathBuf};
12use std::process::Command;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15/// GitHub repository for releases
16const GITHUB_REPO: &str = "8b-is/smart-tree";
17
18/// GitHub API endpoint for latest release
19const GITHUB_RELEASES_API: &str = "https://api.github.com/repos/8b-is/smart-tree/releases/latest";
20
21/// Rate limit: check for updates at most once per 24 hours
22const UPDATE_CHECK_INTERVAL_SECS: u64 = 86400;
23
24/// Binaries included in the release tarball
25/// Note: "n8x" replaces "tree" to avoid shadowing the real tree command
26const BINARIES: &[&str] = &["st", "mq", "m8", "n8x"];
27
28/// Current version from Cargo.toml
29const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
30
31/// GitHub release response (partial)
32#[derive(Debug, Deserialize)]
33struct GitHubRelease {
34    tag_name: String,
35    assets: Vec<GitHubAsset>,
36}
37
38#[derive(Debug, Deserialize)]
39struct GitHubAsset {
40    name: String,
41    browser_download_url: String,
42}
43
44/// Update check cache
45#[derive(Debug, Default, Serialize, Deserialize)]
46struct UpdateCache {
47    #[serde(default)]
48    last_check: u64,
49    #[serde(default)]
50    latest_version: Option<String>,
51}
52
53/// Get the cache file path (~/.st/update_check.json)
54fn get_cache_path() -> Result<PathBuf> {
55    let home = dirs::home_dir().context("Could not find home directory")?;
56    let st_dir = home.join(".st");
57    fs::create_dir_all(&st_dir)?;
58    Ok(st_dir.join("update_check.json"))
59}
60
61/// Load the update cache
62fn load_cache() -> UpdateCache {
63    let cache_path = match get_cache_path() {
64        Ok(p) => p,
65        Err(_) => return UpdateCache::default(),
66    };
67
68    match fs::read_to_string(&cache_path) {
69        Ok(contents) => serde_json::from_str(&contents).unwrap_or_default(),
70        Err(_) => UpdateCache::default(),
71    }
72}
73
74/// Save the update cache
75fn save_cache(cache: &UpdateCache) -> Result<()> {
76    let cache_path = get_cache_path()?;
77    let contents = serde_json::to_string_pretty(cache)?;
78    fs::write(&cache_path, contents)?;
79    Ok(())
80}
81
/// Current Unix time in whole seconds, or `0` if the system clock reads
/// earlier than the Unix epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
89
90/// Check if we should perform an update check (rate limiting)
91pub fn should_check_update() -> bool {
92    let cache = load_cache();
93    let now = now_secs();
94    now.saturating_sub(cache.last_check) > UPDATE_CHECK_INTERVAL_SECS
95}
96
/// Return `true` if `latest` denotes a newer release than `current`.
///
/// Accepts `MAJOR.MINOR.PATCH` strings, optionally surrounded by whitespace and
/// prefixed with `v` (GitHub tag style). Parsing rules:
/// - The three core components are compared numerically by position.
/// - Missing or non-numeric components count as `0`; extra components are ignored.
/// - Build metadata (`+...`) is ignored.
/// - A pre-release suffix (`-beta.1`, `-rc.2`, ...) is split off *before* the
///   core is parsed, so it can never be mistaken for a patch number.
///
/// As in SemVer, a pre-release ranks below the release with the same core
/// version. Two pre-releases of the same core version compare as equal here.
fn is_newer_version(current: &str, latest: &str) -> bool {
    /// Split a version into its numeric core and a "has pre-release suffix" flag.
    fn parse_version(version: &str) -> ((u64, u64, u64), bool) {
        let version = version.trim();
        let version = version.strip_prefix('v').unwrap_or(version);
        // Build metadata never affects precedence.
        let version = version.split('+').next().unwrap_or(version);
        let (core, is_prerelease) = match version.find('-') {
            Some(idx) => (&version[..idx], true),
            None => (version, false),
        };
        // Positional parse: a bad component becomes 0 instead of being dropped,
        // so later components never shift into its slot.
        let numbers: Vec<u64> = core.split('.').map(|p| p.parse().unwrap_or(0)).collect();
        let at = |i: usize| numbers.get(i).copied().unwrap_or(0);
        ((at(0), at(1), at(2)), is_prerelease)
    }

    let (current_core, current_is_pre) = parse_version(current);
    let (latest_core, latest_is_pre) = parse_version(latest);

    latest_core > current_core
        || (latest_core == current_core && current_is_pre && !latest_is_pre)
}
117
118/// Check for available updates (network call) - async version
119pub async fn check_for_update() -> Result<Option<String>> {
120    let client = reqwest::Client::builder()
121        .user_agent("smart-tree-updater")
122        .timeout(Duration::from_secs(10))
123        .build()?;
124
125    let response: GitHubRelease = client
126        .get(GITHUB_RELEASES_API)
127        .send()
128        .await
129        .context("Failed to connect to GitHub")?
130        .json()
131        .await
132        .context("Failed to parse GitHub response")?;
133
134    // Update cache
135    let mut cache = load_cache();
136    cache.last_check = now_secs();
137    cache.latest_version = Some(response.tag_name.clone());
138    let _ = save_cache(&cache);
139
140    let latest = response.tag_name;
141    if is_newer_version(CURRENT_VERSION, &latest) {
142        Ok(Some(latest))
143    } else {
144        Ok(None)
145    }
146}
147
148/// Check for update using cache if within rate limit - async version
149pub async fn check_for_update_cached() -> Option<String> {
150    let cache = load_cache();
151
152    if should_check_update() {
153        // Perform actual check
154        match check_for_update().await {
155            Ok(Some(version)) => Some(version),
156            Ok(None) => None,
157            Err(_) => None, // Silently fail on network errors
158        }
159    } else {
160        // Use cached result
161        cache
162            .latest_version
163            .filter(|v| is_newer_version(CURRENT_VERSION, v))
164    }
165}
166
167/// Print update available banner
168pub fn print_update_banner(latest_version: &str) {
169    let current = format!("v{}", CURRENT_VERSION);
170    eprintln!();
171    eprintln!("\x1b[36m╭─────────────────────────────────────────────────────╮\x1b[0m");
172    eprintln!(
173        "\x1b[36m│\x1b[0m \x1b[32m🌳 Smart Tree {} is available!\x1b[0m (you have {})",
174        latest_version, current
175    );
176    eprintln!("\x1b[36m│\x1b[0m    Run '\x1b[1mst --update\x1b[0m' to upgrade");
177    eprintln!("\x1b[36m╰─────────────────────────────────────────────────────╯\x1b[0m");
178    eprintln!();
179}
180
181/// Detect the current platform for download
182fn get_platform() -> Result<(&'static str, &'static str)> {
183    let os = if cfg!(target_os = "macos") {
184        "apple-darwin"
185    } else if cfg!(target_os = "linux") {
186        "unknown-linux-gnu"
187    } else if cfg!(target_os = "windows") {
188        "pc-windows-msvc"
189    } else {
190        bail!("Unsupported operating system");
191    };
192
193    let arch = if cfg!(target_arch = "x86_64") {
194        "x86_64"
195    } else if cfg!(target_arch = "aarch64") {
196        "aarch64"
197    } else {
198        bail!("Unsupported architecture");
199    };
200
201    Ok((arch, os))
202}
203
204/// Create a temporary directory for the update
205fn create_temp_dir() -> Result<PathBuf> {
206    let base = env::temp_dir();
207    let unique_name = format!("st-update-{}", now_secs());
208    let temp_dir = base.join(unique_name);
209    fs::create_dir_all(&temp_dir).context("Failed to create temp directory")?;
210    Ok(temp_dir)
211}
212
/// Recursively delete `path`, ignoring any error (including a missing path).
///
/// Best effort by design: a leftover temporary directory must not turn an
/// otherwise finished update into a failure.
fn cleanup_temp_dir(path: &Path) {
    let _removal_result = fs::remove_dir_all(path);
}
217
218/// Find where the current binary is installed
219fn find_install_dir() -> Result<PathBuf> {
220    // Try to find where 'st' is installed
221    let current_exe = env::current_exe().context("Could not determine current executable path")?;
222    let install_dir = current_exe
223        .parent()
224        .context("Could not determine installation directory")?
225        .to_path_buf();
226
227    Ok(install_dir)
228}
229
/// Return `true` when the current user cannot replace files in `install_dir`.
///
/// On Unix this probes the directory by exclusively creating a uniquely named
/// empty file and deleting it again. Replacing a binary means unlinking and
/// creating directory entries, which requires write permission on the
/// directory itself. A real write probe gives the right answer where ownership
/// checks do not:
/// - group- or world-writable directories owned by someone else (e.g. `/tmp`);
/// - directories the user owns but has made read-only.
///
/// Any probe failure, including a missing directory, is reported as needing
/// elevation. On other platforms this always returns `false`.
fn needs_sudo(install_dir: &Path) -> bool {
    #[cfg(unix)]
    {
        let probe = install_dir.join(format!(".st_write_test_{}", std::process::id()));
        match fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&probe)
        {
            Ok(file) => {
                drop(file);
                // Leave nothing behind in the user's install directory.
                let _ = fs::remove_file(&probe);
                false
            }
            Err(_) => true,
        }
    }
    #[cfg(not(unix))]
    {
        false
    }
}
256
257/// Download and install the update - async version
258pub async fn download_and_install(version: &str, yes: bool) -> Result<()> {
259    let (arch, os) = get_platform()?;
260    let install_dir = find_install_dir()?;
261
262    println!("\x1b[36m🌳 Smart Tree Updater\x1b[0m");
263    println!();
264    println!("Current version: v{}", CURRENT_VERSION);
265    println!("Latest version:  {}", version);
266    println!("Install path:    {}", install_dir.display());
267    println!("Binaries:        {}", BINARIES.join(", "));
268    println!();
269
270    if !yes {
271        print!("Proceed with update? [Y/n] ");
272        io::stdout().flush()?;
273
274        let mut input = String::new();
275        io::stdin().read_line(&mut input)?;
276        let input = input.trim().to_lowercase();
277
278        if !input.is_empty() && input != "y" && input != "yes" {
279            println!("Update cancelled.");
280            return Ok(());
281        }
282    }
283
284    let use_sudo = needs_sudo(&install_dir);
285    if use_sudo {
286        println!("\x1b[33m⚠ Installation directory requires elevated permissions.\x1b[0m");
287        println!("  You may be prompted for your password.\n");
288    }
289
290    // Construct download URL
291    let ext = if cfg!(target_os = "windows") {
292        "zip"
293    } else {
294        "tar.gz"
295    };
296    let archive_name = format!("st-{}-{}-{}.{}", version, arch, os, ext);
297    let download_url = format!(
298        "https://github.com/{}/releases/download/{}/{}",
299        GITHUB_REPO, version, archive_name
300    );
301
302    println!("Downloading {}...", archive_name);
303
304    // Create temp directory
305    let temp_dir = create_temp_dir()?;
306    let archive_path = temp_dir.join(&archive_name);
307
308    // Download
309    let client = reqwest::Client::builder()
310        .user_agent("smart-tree-updater")
311        .timeout(Duration::from_secs(300))
312        .build()?;
313
314    let response = client
315        .get(&download_url)
316        .send()
317        .await
318        .context("Failed to download release")?;
319
320    if !response.status().is_success() {
321        bail!("Download failed: HTTP {}", response.status());
322    }
323
324    let bytes = response.bytes().await?;
325    fs::write(&archive_path, &bytes)?;
326
327    println!("Extracting...");
328
329    // Extract archive
330    #[cfg(unix)]
331    {
332        let output = Command::new("tar")
333            .args(["-xzf", archive_path.to_str().unwrap()])
334            .current_dir(&temp_dir)
335            .output()
336            .context("Failed to extract archive")?;
337
338        if !output.status.success() {
339            bail!(
340                "Failed to extract archive: {}",
341                String::from_utf8_lossy(&output.stderr)
342            );
343        }
344    }
345
346    #[cfg(windows)]
347    {
348        // On Windows, use powershell to extract zip
349        let output = Command::new("powershell")
350            .args([
351                "-Command",
352                &format!(
353                    "Expand-Archive -Path '{}' -DestinationPath '{}' -Force",
354                    archive_path.display(),
355                    &temp_dir.display()
356                ),
357            ])
358            .output()
359            .context("Failed to extract archive")?;
360
361        if !output.status.success() {
362            bail!(
363                "Failed to extract archive: {}",
364                String::from_utf8_lossy(&output.stderr)
365            );
366        }
367    }
368
369    // Install binaries
370    println!("Installing binaries...");
371
372    let mut installed_count = 0;
373    for binary in BINARIES {
374        let binary_name = if cfg!(windows) {
375            format!("{}.exe", binary)
376        } else {
377            binary.to_string()
378        };
379
380        // Find binary in temp dir (might be at root or in subdirectory)
381        let src_path = match find_binary_in_dir(&temp_dir, &binary_name) {
382            Ok(path) => path,
383            Err(_) => {
384                // Binary not in archive - skip it (older releases may not have all binaries)
385                println!("  \x1b[33m⚠\x1b[0m {} (not in archive, skipping)", binary);
386                continue;
387            }
388        };
389        let dest_path = install_dir.join(&binary_name);
390
391        // IMPORTANT: Remove old binary first to avoid macOS zombie process issue
392        #[cfg(unix)]
393        {
394            if use_sudo {
395                let _ = Command::new("sudo")
396                    .args(["rm", "-f", dest_path.to_str().unwrap()])
397                    .status();
398
399                Command::new("sudo")
400                    .args([
401                        "cp",
402                        src_path.to_str().unwrap(),
403                        dest_path.to_str().unwrap(),
404                    ])
405                    .status()
406                    .context(format!("Failed to install {}", binary))?;
407
408                Command::new("sudo")
409                    .args(["chmod", "+x", dest_path.to_str().unwrap()])
410                    .status()?;
411            } else {
412                let _ = fs::remove_file(&dest_path);
413                fs::copy(&src_path, &dest_path).context(format!("Failed to install {}", binary))?;
414
415                // Set executable permission
416                use std::os::unix::fs::PermissionsExt;
417                let mut perms = fs::metadata(&dest_path)?.permissions();
418                perms.set_mode(0o755);
419                fs::set_permissions(&dest_path, perms)?;
420            }
421        }
422
423        #[cfg(windows)]
424        {
425            // On Windows, rename old binary first (can't delete while running)
426            let old_path = install_dir.join(format!("{}.old", binary_name));
427            let _ = fs::remove_file(&old_path);
428            let _ = fs::rename(&dest_path, &old_path);
429
430            fs::copy(&src_path, &dest_path).context(format!("Failed to install {}", binary))?;
431        }
432
433        println!("  \x1b[32m✓\x1b[0m {}", binary);
434        installed_count += 1;
435    }
436
437    // Ensure at least the main binary was installed
438    if installed_count == 0 {
439        bail!("No binaries were installed from the archive");
440    }
441
442    // Update cache
443    let mut cache = load_cache();
444    cache.latest_version = Some(version.to_string());
445    let _ = save_cache(&cache);
446
447    // Clean up temp directory
448    cleanup_temp_dir(&temp_dir);
449
450    println!();
451    println!("\x1b[32m✨ Successfully updated to {}!\x1b[0m", version);
452
453    #[cfg(windows)]
454    {
455        println!();
456        println!(
457            "\x1b[33mNote: Please restart your terminal for the update to take effect.\x1b[0m"
458        );
459    }
460
461    Ok(())
462}
463
464/// Find a binary file within a directory (handles nested extraction)
465fn find_binary_in_dir(dir: &Path, binary_name: &str) -> Result<PathBuf> {
466    // Check root
467    let root_path = dir.join(binary_name);
468    if root_path.exists() {
469        return Ok(root_path);
470    }
471
472    // Search subdirectories
473    for entry in fs::read_dir(dir)? {
474        let entry = entry?;
475        let path = entry.path();
476        if path.is_dir() {
477            let nested = path.join(binary_name);
478            if nested.exists() {
479                return Ok(nested);
480            }
481        }
482    }
483
484    bail!("Could not find {} in downloaded archive", binary_name)
485}
486
487/// Run the update command - async version
488pub async fn run_update(yes: bool) -> Result<()> {
489    println!("Checking for updates...");
490
491    match check_for_update().await? {
492        Some(version) => {
493            download_and_install(&version, yes).await?;
494        }
495        None => {
496            println!(
497                "\x1b[32m✓\x1b[0m Already up to date! (v{})",
498                CURRENT_VERSION
499            );
500        }
501    }
502
503    Ok(())
504}
505
506/// Get current version string
507pub fn current_version() -> &'static str {
508    CURRENT_VERSION
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_comparison() {
        // (current, latest) pairs where `latest` is strictly newer.
        let upgrades = [
            ("5.5.0", "5.5.1"),
            ("5.5.1", "5.6.0"),
            ("5.5.1", "6.0.0"),
            ("v5.5.0", "v5.5.1"),
        ];
        for &(current, latest) in &upgrades {
            assert!(
                is_newer_version(current, latest),
                "{} should be newer than {}",
                latest,
                current
            );
        }

        // (current, latest) pairs where `latest` is the same or older.
        let non_upgrades = [("5.5.1", "5.5.1"), ("5.5.1", "5.5.0"), ("6.0.0", "5.5.1")];
        for &(current, latest) in &non_upgrades {
            assert!(
                !is_newer_version(current, latest),
                "{} should not be newer than {}",
                latest,
                current
            );
        }
    }

    #[test]
    fn test_platform_detection() {
        assert!(get_platform().is_ok());
    }
}