workhelix_cli_common/
update.rs

1//! Self-update module.
2//!
3//! This module provides self-update functionality for CLI tools, including:
4//! - Checking for latest releases on GitHub
5//! - Downloading release binaries
6//! - Verifying checksums
7//! - Replacing the current binary
8
9use crate::types::RepoInfo;
10use sha2::{Digest, Sha256};
11use std::fs;
12use std::io::{self, Write};
13use std::path::Path;
14
15/// Run update command to install latest or specified version.
16///
17/// Returns exit code: 0 if successful, 1 on error, 2 if already up-to-date.
18///
19/// # Arguments
20/// * `repo_info` - Repository information for GitHub integration
21/// * `current_version` - Current version of the tool
22/// * `version` - Optional specific version to install
23/// * `force` - Force installation even if already up-to-date
24/// * `install_dir` - Optional custom installation directory
25///
26/// # Panics
27/// May panic if stdout flush fails or stdin read fails during user confirmation.
28#[must_use]
29pub fn run_update(
30    repo_info: &RepoInfo,
31    current_version: &str,
32    version: Option<&str>,
33    force: bool,
34    install_dir: Option<&Path>,
35) -> i32 {
36    println!("🔄 Checking for updates...");
37
38    // Get target version
39    let target_version = if let Some(v) = version {
40        v.to_string()
41    } else {
42        match get_latest_version(repo_info) {
43            Ok(v) => v,
44            Err(e) => {
45                eprintln!("❌ Failed to check for updates: {e}");
46                return 1;
47            }
48        }
49    };
50
51    // Check if already up-to-date
52    if target_version == current_version && !force {
53        println!("✅ Already running latest version (v{current_version})");
54        return 2;
55    }
56
57    println!("✨ Update available: v{target_version} (current: v{current_version})");
58
59    // Detect current binary location
60    let install_path = if let Some(dir) = install_dir {
61        dir.join(repo_info.name)
62    } else {
63        match std::env::current_exe() {
64            Ok(path) => path,
65            Err(e) => {
66                eprintln!("❌ Failed to determine binary location: {e}");
67                return 1;
68            }
69        }
70    };
71
72    println!("📍 Install location: {}", install_path.display());
73    println!();
74
75    // Confirm unless forced
76    if !force {
77        print!("Continue with update? [y/N]: ");
78        io::stdout().flush().unwrap();
79
80        let mut response = String::new();
81        io::stdin().read_line(&mut response).unwrap();
82
83        if !matches!(response.trim().to_lowercase().as_str(), "y" | "yes") {
84            println!("Update cancelled.");
85            return 0;
86        }
87    }
88
89    // Perform update
90    match perform_update(repo_info, &target_version, &install_path) {
91        Ok(()) => {
92            println!("✅ Successfully updated to v{target_version}");
93            println!();
94            println!(
95                "Run '{} --version' to verify the installation.",
96                repo_info.name
97            );
98            0
99        }
100        Err(e) => {
101            eprintln!("❌ Update failed: {e}");
102            1
103        }
104    }
105}
106
107/// Get the latest version from GitHub releases.
108///
109/// # Errors
110/// Returns an error if the HTTP request fails, the response cannot be parsed,
111/// or the `tag_name` field is missing.
112pub fn get_latest_version(repo_info: &RepoInfo) -> Result<String, String> {
113    let client = reqwest::blocking::Client::builder()
114        .user_agent(format!("{}-updater", repo_info.name))
115        .timeout(std::time::Duration::from_secs(10))
116        .build()
117        .map_err(|e| e.to_string())?;
118
119    let response: serde_json::Value = client
120        .get(repo_info.latest_release_url())
121        .send()
122        .map_err(|e| e.to_string())?
123        .json()
124        .map_err(|e| e.to_string())?;
125
126    let tag_name = response["tag_name"]
127        .as_str()
128        .ok_or_else(|| "No tag_name in response".to_string())?;
129
130    let version = tag_name
131        .trim_start_matches(repo_info.tag_prefix)
132        .trim_start_matches('v');
133    Ok(version.to_string())
134}
135
136fn perform_update(repo_info: &RepoInfo, version: &str, install_path: &Path) -> Result<(), String> {
137    // Detect platform
138    let platform = get_platform_string();
139    let archive_ext = if cfg!(target_os = "windows") {
140        "zip"
141    } else {
142        "tar.gz"
143    };
144
145    let download_url = repo_info.download_url(version, &platform, archive_ext);
146
147    println!(
148        "📥 Downloading {}-{platform}.{archive_ext}...",
149        repo_info.name
150    );
151
152    // Download file
153    let client = reqwest::blocking::Client::builder()
154        .user_agent(format!("{}-updater", repo_info.name))
155        .timeout(std::time::Duration::from_secs(300))
156        .build()
157        .map_err(|e| e.to_string())?;
158
159    let response = client
160        .get(&download_url)
161        .send()
162        .map_err(|e| e.to_string())?;
163
164    if !response.status().is_success() {
165        return Err(format!("Download failed: HTTP {}", response.status()));
166    }
167
168    let bytes = response.bytes().map_err(|e| e.to_string())?;
169
170    // Download checksum
171    let checksum_url = format!("{download_url}.sha256");
172    let checksum_response = client
173        .get(&checksum_url)
174        .send()
175        .map_err(|e| e.to_string())?;
176
177    if checksum_response.status().is_success() {
178        println!("🔐 Verifying checksum...");
179        let expected_checksum = checksum_response.text().map_err(|e| e.to_string())?;
180        let expected_checksum = expected_checksum
181            .split_whitespace()
182            .next()
183            .unwrap_or(&expected_checksum);
184
185        let mut hasher = Sha256::new();
186        hasher.update(&bytes);
187        let computed_checksum = hex::encode(hasher.finalize());
188
189        if computed_checksum.to_lowercase() != expected_checksum.to_lowercase() {
190            return Err(format!(
191                "Checksum mismatch!\nExpected: {expected_checksum}\nGot: {computed_checksum}"
192            ));
193        }
194        println!("✅ Checksum verified");
195    } else {
196        println!("⚠️  No checksum found, skipping verification");
197    }
198
199    // Extract archive
200    println!("📦 Extracting archive...");
201    let temp_dir = tempfile::tempdir().map_err(|e| e.to_string())?;
202
203    if cfg!(target_os = "windows") {
204        extract_zip(&bytes, temp_dir.path())?;
205    } else {
206        extract_tar_gz(&bytes, temp_dir.path())?;
207    }
208
209    // Find the binary in the extracted files
210    let binary_name = if cfg!(target_os = "windows") {
211        format!("{}.exe", repo_info.name)
212    } else {
213        repo_info.name.to_string()
214    };
215
216    let extracted_binary = temp_dir.path().join(&binary_name);
217    if !extracted_binary.exists() {
218        return Err(format!("Binary {binary_name} not found in archive"));
219    }
220
221    // Replace the current binary
222    println!("🔧 Installing update...");
223
224    // Backup current binary
225    let backup_path = install_path.with_extension("bak");
226    if let Err(e) = fs::copy(install_path, &backup_path) {
227        eprintln!("⚠️  Failed to create backup: {e}");
228    }
229
230    // Copy new binary
231    fs::copy(&extracted_binary, install_path)
232        .map_err(|e| format!("Failed to install binary: {e}"))?;
233
234    // Set executable permissions on Unix
235    #[cfg(unix)]
236    {
237        use std::os::unix::fs::PermissionsExt;
238        let mut perms = fs::metadata(install_path)
239            .map_err(|e| format!("Failed to get metadata: {e}"))?
240            .permissions();
241        perms.set_mode(0o755);
242        fs::set_permissions(install_path, perms)
243            .map_err(|e| format!("Failed to set permissions: {e}"))?;
244    }
245
246    // Clean up backup
247    if backup_path.exists() {
248        let _ = fs::remove_file(&backup_path);
249    }
250
251    Ok(())
252}
253
/// Map the host OS/architecture to the target-triple string used in release asset names.
///
/// # Panics
/// Panics with "Unsupported platform: {os}/{arch}" when the host combination is not
/// one of the five supported targets.
fn get_platform_string() -> String {
    // (OS, ARCH, target triple) for every platform a release asset exists for.
    const SUPPORTED: &[(&str, &str, &str)] = &[
        ("linux", "x86_64", "x86_64-unknown-linux-gnu"),
        ("linux", "aarch64", "aarch64-unknown-linux-gnu"),
        ("macos", "x86_64", "x86_64-apple-darwin"),
        ("macos", "aarch64", "aarch64-apple-darwin"),
        ("windows", "x86_64", "x86_64-pc-windows-msvc"),
    ];

    let (os, arch) = (std::env::consts::OS, std::env::consts::ARCH);

    SUPPORTED
        .iter()
        .find(|(entry_os, entry_arch, _)| *entry_os == os && *entry_arch == arch)
        .map(|(_, _, triple)| String::from(*triple))
        .unwrap_or_else(|| panic!("Unsupported platform: {os}/{arch}"))
}
268
269fn extract_tar_gz(bytes: &[u8], dest: &Path) -> Result<(), String> {
270    use flate2::read::GzDecoder;
271    use tar::Archive;
272
273    let decoder = GzDecoder::new(bytes);
274    let mut archive = Archive::new(decoder);
275    archive
276        .unpack(dest)
277        .map_err(|e| format!("Failed to extract tar.gz: {e}"))
278}
279
/// Extract an in-memory ZIP archive into the directory `dest`.
///
/// Entry names come from the downloaded archive. Any entry whose path would land
/// outside `dest` ("zip slip": absolute paths or `..` components) is rejected instead
/// of being written.
///
/// # Errors
/// Returns an error if the archive or an entry cannot be read, an entry has an unsafe
/// path, or a directory/file cannot be created or written.
#[cfg(target_os = "windows")]
fn extract_zip(bytes: &[u8], dest: &Path) -> Result<(), String> {
    use std::io::Cursor;
    use zip::ZipArchive;

    let reader = Cursor::new(bytes);
    let mut archive = ZipArchive::new(reader).map_err(|e| format!("Failed to open zip: {e}"))?;

    for i in 0..archive.len() {
        let mut file = archive
            .by_index(i)
            .map_err(|e| format!("Failed to read zip entry: {e}"))?;

        // `enclosed_name` yields `None` for names that would escape `dest`.
        let outpath = match file.enclosed_name() {
            Some(relative) => dest.join(relative),
            None => return Err(format!("Unsafe path in zip entry: {}", file.name())),
        };

        if file.is_dir() {
            fs::create_dir_all(&outpath).map_err(|e| format!("Failed to create directory: {e}"))?;
        } else {
            if let Some(p) = outpath.parent() {
                fs::create_dir_all(p)
                    .map_err(|e| format!("Failed to create parent directory: {e}"))?;
            }
            let mut outfile =
                fs::File::create(&outpath).map_err(|e| format!("Failed to create file: {e}"))?;
            io::copy(&mut file, &mut outfile)
                .map_err(|e| format!("Failed to extract file: {e}"))?;
        }
    }

    Ok(())
}
310
/// Non-Windows stand-in for the ZIP extractor.
///
/// Release archives on these targets are `.tar.gz`. `perform_update` selects the
/// extractor with `cfg!` (a runtime boolean), so both branches must type-check on
/// every platform, and this stub satisfies the ZIP branch.
///
/// # Errors
/// Always returns an error.
#[cfg(not(target_os = "windows"))]
fn extract_zip(_bytes: &[u8], _dest: &Path) -> Result<(), String> {
    let message = "ZIP extraction not supported on this platform";
    Err(String::from(message))
}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_platform_string() {
        let platform = get_platform_string();
        assert!(!platform.is_empty());
        assert!(platform.contains('-'));
        // Every supported target triple begins with the host architecture name.
        assert!(platform.starts_with(std::env::consts::ARCH));
    }

    #[test]
    fn test_repo_info_latest_release_url() {
        let repo = RepoInfo::new("workhelix", "prompter", "prompter-v");
        let url = repo.latest_release_url();
        assert_eq!(
            url,
            "https://api.github.com/repos/workhelix/prompter/releases/latest"
        );
    }

    // Talks to the live GitHub API, so it is excluded from the default `cargo test`
    // run; execute with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access to the GitHub API"]
    fn test_get_latest_version_handles_errors() {
        let repo = RepoInfo::new("workhelix", "this-repo-does-not-exist-3f9c1a", "v");
        // A repository without releases must produce an error, never a version string.
        assert!(get_latest_version(&repo).is_err());
    }
}