workhelix_cli_common/
update.rs

//! Self-update module.
//!
//! This module provides self-update functionality for CLI tools, including:
//! - Checking GitHub for the latest release
//! - Downloading release archives
//! - Verifying SHA-256 checksums (when the release publishes one)
//! - Replacing the current binary
8
9use crate::types::RepoInfo;
10use sha2::{Digest, Sha256};
11use std::fs;
12use std::io::{self, Write};
13use std::path::Path;
14
15/// Run update command to install latest or specified version.
16///
17/// Returns exit code: 0 if successful, 1 on error, 2 if already up-to-date.
18///
19/// # Arguments
20/// * `repo_info` - Repository information for GitHub integration
21/// * `current_version` - Current version of the tool
22/// * `version` - Optional specific version to install
23/// * `force` - Force installation even if already up-to-date
24/// * `install_dir` - Optional custom installation directory
25///
26/// # Panics
27/// May panic if stdout flush fails or stdin read fails during user confirmation.
28#[must_use]
29pub fn run_update(
30    repo_info: &RepoInfo,
31    current_version: &str,
32    version: Option<&str>,
33    force: bool,
34    install_dir: Option<&Path>,
35) -> i32 {
36    println!("🔄 Checking for updates...");
37
38    // Get target version
39    let target_version = if let Some(v) = version {
40        v.to_string()
41    } else {
42        match get_latest_version(repo_info) {
43            Ok(v) => v,
44            Err(e) => {
45                eprintln!("❌ Failed to check for updates: {e}");
46                return 1;
47            }
48        }
49    };
50
51    // Check if already up-to-date
52    if target_version == current_version && !force {
53        println!("✅ Already running latest version (v{current_version})");
54        return 2;
55    }
56
57    println!("✨ Update available: v{target_version} (current: v{current_version})");
58
59    // Detect current binary location
60    let install_path = if let Some(dir) = install_dir {
61        dir.join(repo_info.name)
62    } else {
63        match std::env::current_exe() {
64            Ok(path) => path,
65            Err(e) => {
66                eprintln!("❌ Failed to determine binary location: {e}");
67                return 1;
68            }
69        }
70    };
71
72    println!("📍 Install location: {}", install_path.display());
73    println!();
74
75    // Confirm unless forced
76    if !force {
77        print!("Continue with update? [y/N]: ");
78        io::stdout().flush().unwrap();
79
80        let mut response = String::new();
81        io::stdin().read_line(&mut response).unwrap();
82
83        if !matches!(response.trim().to_lowercase().as_str(), "y" | "yes") {
84            println!("Update cancelled.");
85            return 0;
86        }
87    }
88
89    // Perform update
90    match perform_update(repo_info, &target_version, &install_path) {
91        Ok(()) => {
92            println!("✅ Successfully updated to v{target_version}");
93            println!();
94            println!("Run '{} --version' to verify the installation.", repo_info.name);
95            0
96        }
97        Err(e) => {
98            eprintln!("❌ Update failed: {e}");
99            1
100        }
101    }
102}
103
104/// Get the latest version from GitHub releases.
105///
106/// # Errors
107/// Returns an error if the HTTP request fails, the response cannot be parsed,
108/// or the `tag_name` field is missing.
109pub fn get_latest_version(repo_info: &RepoInfo) -> Result<String, String> {
110    let client = reqwest::blocking::Client::builder()
111        .user_agent(format!("{}-updater", repo_info.name))
112        .timeout(std::time::Duration::from_secs(10))
113        .build()
114        .map_err(|e| e.to_string())?;
115
116    let response: serde_json::Value = client
117        .get(repo_info.latest_release_url())
118        .send()
119        .map_err(|e| e.to_string())?
120        .json()
121        .map_err(|e| e.to_string())?;
122
123    let tag_name = response["tag_name"]
124        .as_str()
125        .ok_or_else(|| "No tag_name in response".to_string())?;
126
127    let version = tag_name
128        .trim_start_matches(repo_info.tag_prefix)
129        .trim_start_matches('v');
130    Ok(version.to_string())
131}
132
133fn perform_update(repo_info: &RepoInfo, version: &str, install_path: &Path) -> Result<(), String> {
134    // Detect platform
135    let platform = get_platform_string();
136    let archive_ext = if cfg!(target_os = "windows") {
137        "zip"
138    } else {
139        "tar.gz"
140    };
141
142    let download_url = repo_info.download_url(version, &platform, archive_ext);
143
144    println!("📥 Downloading {}-{platform}.{archive_ext}...", repo_info.name);
145
146    // Download file
147    let client = reqwest::blocking::Client::builder()
148        .user_agent(format!("{}-updater", repo_info.name))
149        .timeout(std::time::Duration::from_secs(300))
150        .build()
151        .map_err(|e| e.to_string())?;
152
153    let response = client
154        .get(&download_url)
155        .send()
156        .map_err(|e| e.to_string())?;
157
158    if !response.status().is_success() {
159        return Err(format!("Download failed: HTTP {}", response.status()));
160    }
161
162    let bytes = response.bytes().map_err(|e| e.to_string())?;
163
164    // Download checksum
165    let checksum_url = format!("{download_url}.sha256");
166    let checksum_response = client
167        .get(&checksum_url)
168        .send()
169        .map_err(|e| e.to_string())?;
170
171    if checksum_response.status().is_success() {
172        println!("🔐 Verifying checksum...");
173        let expected_checksum = checksum_response.text().map_err(|e| e.to_string())?;
174        let expected_checksum = expected_checksum.split_whitespace().next().unwrap_or(&expected_checksum);
175
176        let mut hasher = Sha256::new();
177        hasher.update(&bytes);
178        let computed_checksum = hex::encode(hasher.finalize());
179
180        if computed_checksum.to_lowercase() != expected_checksum.to_lowercase() {
181            return Err(format!(
182                "Checksum mismatch!\nExpected: {expected_checksum}\nGot: {computed_checksum}"
183            ));
184        }
185        println!("✅ Checksum verified");
186    } else {
187        println!("⚠️  No checksum found, skipping verification");
188    }
189
190    // Extract archive
191    println!("📦 Extracting archive...");
192    let temp_dir = tempfile::tempdir().map_err(|e| e.to_string())?;
193
194    if cfg!(target_os = "windows") {
195        extract_zip(&bytes, temp_dir.path())?;
196    } else {
197        extract_tar_gz(&bytes, temp_dir.path())?;
198    }
199
200    // Find the binary in the extracted files
201    let binary_name = if cfg!(target_os = "windows") {
202        format!("{}.exe", repo_info.name)
203    } else {
204        repo_info.name.to_string()
205    };
206
207    let extracted_binary = temp_dir.path().join(&binary_name);
208    if !extracted_binary.exists() {
209        return Err(format!("Binary {binary_name} not found in archive"));
210    }
211
212    // Replace the current binary
213    println!("🔧 Installing update...");
214
215    // Backup current binary
216    let backup_path = install_path.with_extension("bak");
217    if let Err(e) = fs::copy(install_path, &backup_path) {
218        eprintln!("⚠️  Failed to create backup: {e}");
219    }
220
221    // Copy new binary
222    fs::copy(&extracted_binary, install_path)
223        .map_err(|e| format!("Failed to install binary: {e}"))?;
224
225    // Set executable permissions on Unix
226    #[cfg(unix)]
227    {
228        use std::os::unix::fs::PermissionsExt;
229        let mut perms = fs::metadata(install_path)
230            .map_err(|e| format!("Failed to get metadata: {e}"))?
231            .permissions();
232        perms.set_mode(0o755);
233        fs::set_permissions(install_path, perms)
234            .map_err(|e| format!("Failed to set permissions: {e}"))?;
235    }
236
237    // Clean up backup
238    if backup_path.exists() {
239        let _ = fs::remove_file(&backup_path);
240    }
241
242    Ok(())
243}
244
/// Map the host OS and CPU architecture to the Rust target triple used in
/// release asset names.
///
/// # Panics
/// Panics if the host OS/architecture pair is not in the supported table.
fn get_platform_string() -> String {
    // (OS, architecture, target triple) for each platform this updater supports.
    const SUPPORTED: [(&str, &str, &str); 5] = [
        ("linux", "x86_64", "x86_64-unknown-linux-gnu"),
        ("linux", "aarch64", "aarch64-unknown-linux-gnu"),
        ("macos", "x86_64", "x86_64-apple-darwin"),
        ("macos", "aarch64", "aarch64-apple-darwin"),
        ("windows", "x86_64", "x86_64-pc-windows-msvc"),
    ];

    let os = std::env::consts::OS;
    let arch = std::env::consts::ARCH;

    SUPPORTED
        .iter()
        .find(|&&(host_os, host_arch, _)| host_os == os && host_arch == arch)
        .map(|&(_, _, triple)| triple.to_owned())
        .unwrap_or_else(|| panic!("Unsupported platform: {os}/{arch}"))
}
259
260fn extract_tar_gz(bytes: &[u8], dest: &Path) -> Result<(), String> {
261    use flate2::read::GzDecoder;
262    use tar::Archive;
263
264    let decoder = GzDecoder::new(bytes);
265    let mut archive = Archive::new(decoder);
266    archive
267        .unpack(dest)
268        .map_err(|e| format!("Failed to extract tar.gz: {e}"))
269}
270
/// Extract an in-memory ZIP archive into `dest` (Windows builds only).
///
/// Entry paths are validated with `enclosed_name()`, so entries with absolute
/// paths or `..` components that would land outside `dest` ("zip slip") are
/// rejected instead of being written.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the archive cannot be opened or an entry cannot be read;
/// - an entry has an unsafe path;
/// - a directory or file cannot be created or written.
#[cfg(target_os = "windows")]
fn extract_zip(bytes: &[u8], dest: &Path) -> Result<(), String> {
    use std::io::Cursor;
    use zip::ZipArchive;

    let reader = Cursor::new(bytes);
    let mut archive = ZipArchive::new(reader).map_err(|e| format!("Failed to open zip: {e}"))?;

    for i in 0..archive.len() {
        let mut file = archive
            .by_index(i)
            .map_err(|e| format!("Failed to read zip entry: {e}"))?;

        // Never trust raw entry names: `enclosed_name` returns None for paths
        // that would escape the extraction directory.
        let outpath = match file.enclosed_name() {
            Some(relative) => dest.join(relative),
            None => return Err(format!("Unsafe path in zip entry: {}", file.name())),
        };

        if file.is_dir() {
            fs::create_dir_all(&outpath)
                .map_err(|e| format!("Failed to create directory: {e}"))?;
        } else {
            if let Some(p) = outpath.parent() {
                fs::create_dir_all(p)
                    .map_err(|e| format!("Failed to create parent directory: {e}"))?;
            }
            let mut outfile = fs::File::create(&outpath)
                .map_err(|e| format!("Failed to create file: {e}"))?;
            io::copy(&mut file, &mut outfile)
                .map_err(|e| format!("Failed to extract file: {e}"))?;
        }
    }

    Ok(())
}
302
/// Non-Windows stand-in for the ZIP extractor.
///
/// `perform_update` only selects ZIP archives on Windows. This stub exists so
/// the shared call site still compiles here, and it always reports an error.
#[cfg(not(target_os = "windows"))]
fn extract_zip(_bytes: &[u8], _dest: &Path) -> Result<(), String> {
    let message = String::from("ZIP extraction not supported on this platform");
    Err(message)
}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_platform_string() {
        let platform = get_platform_string();
        // Every supported triple starts with the host architecture name.
        assert!(platform.starts_with(std::env::consts::ARCH));
        assert!(platform.contains('-'));
    }

    #[test]
    fn test_repo_info_latest_release_url() {
        let repo = RepoInfo::new("workhelix", "prompter", "prompter-v");
        let url = repo.latest_release_url();
        assert_eq!(url, "https://api.github.com/repos/workhelix/prompter/releases/latest");
    }

    /// A repository that does not exist must yield an error, whether the request
    /// is rejected by GitHub or fails for lack of network access.
    #[test]
    #[ignore = "performs a real HTTP request to api.github.com"]
    fn test_get_latest_version_fails_for_missing_repo() {
        let repo = RepoInfo::new("workhelix-nonexistent-owner-8f3c1a", "repo", "v");
        let result = get_latest_version(&repo);
        assert!(result.is_err(), "expected an error, got {result:?}");
    }
}