workhelix_cli_common/
update.rs1use crate::types::RepoInfo;
10use sha2::{Digest, Sha256};
11use std::fs;
12use std::io::{self, Write};
13use std::path::Path;
14
15#[must_use]
29pub fn run_update(
30 repo_info: &RepoInfo,
31 current_version: &str,
32 version: Option<&str>,
33 force: bool,
34 install_dir: Option<&Path>,
35) -> i32 {
36 println!("🔄 Checking for updates...");
37
38 let target_version = if let Some(v) = version {
40 v.to_string()
41 } else {
42 match get_latest_version(repo_info) {
43 Ok(v) => v,
44 Err(e) => {
45 eprintln!("❌ Failed to check for updates: {e}");
46 return 1;
47 }
48 }
49 };
50
51 if target_version == current_version && !force {
53 println!("✅ Already running latest version (v{current_version})");
54 return 2;
55 }
56
57 println!("✨ Update available: v{target_version} (current: v{current_version})");
58
59 let install_path = if let Some(dir) = install_dir {
61 dir.join(repo_info.name)
62 } else {
63 match std::env::current_exe() {
64 Ok(path) => path,
65 Err(e) => {
66 eprintln!("❌ Failed to determine binary location: {e}");
67 return 1;
68 }
69 }
70 };
71
72 println!("📍 Install location: {}", install_path.display());
73 println!();
74
75 if !force {
77 print!("Continue with update? [y/N]: ");
78 io::stdout().flush().unwrap();
79
80 let mut response = String::new();
81 io::stdin().read_line(&mut response).unwrap();
82
83 if !matches!(response.trim().to_lowercase().as_str(), "y" | "yes") {
84 println!("Update cancelled.");
85 return 0;
86 }
87 }
88
89 match perform_update(repo_info, &target_version, &install_path) {
91 Ok(()) => {
92 println!("✅ Successfully updated to v{target_version}");
93 println!();
94 println!("Run '{} --version' to verify the installation.", repo_info.name);
95 0
96 }
97 Err(e) => {
98 eprintln!("❌ Update failed: {e}");
99 1
100 }
101 }
102}
103
104pub fn get_latest_version(repo_info: &RepoInfo) -> Result<String, String> {
110 let client = reqwest::blocking::Client::builder()
111 .user_agent(format!("{}-updater", repo_info.name))
112 .timeout(std::time::Duration::from_secs(10))
113 .build()
114 .map_err(|e| e.to_string())?;
115
116 let response: serde_json::Value = client
117 .get(repo_info.latest_release_url())
118 .send()
119 .map_err(|e| e.to_string())?
120 .json()
121 .map_err(|e| e.to_string())?;
122
123 let tag_name = response["tag_name"]
124 .as_str()
125 .ok_or_else(|| "No tag_name in response".to_string())?;
126
127 let version = tag_name
128 .trim_start_matches(repo_info.tag_prefix)
129 .trim_start_matches('v');
130 Ok(version.to_string())
131}
132
133fn perform_update(repo_info: &RepoInfo, version: &str, install_path: &Path) -> Result<(), String> {
134 let platform = get_platform_string();
136 let archive_ext = if cfg!(target_os = "windows") {
137 "zip"
138 } else {
139 "tar.gz"
140 };
141
142 let download_url = repo_info.download_url(version, &platform, archive_ext);
143
144 println!("📥 Downloading {}-{platform}.{archive_ext}...", repo_info.name);
145
146 let client = reqwest::blocking::Client::builder()
148 .user_agent(format!("{}-updater", repo_info.name))
149 .timeout(std::time::Duration::from_secs(300))
150 .build()
151 .map_err(|e| e.to_string())?;
152
153 let response = client
154 .get(&download_url)
155 .send()
156 .map_err(|e| e.to_string())?;
157
158 if !response.status().is_success() {
159 return Err(format!("Download failed: HTTP {}", response.status()));
160 }
161
162 let bytes = response.bytes().map_err(|e| e.to_string())?;
163
164 let checksum_url = format!("{download_url}.sha256");
166 let checksum_response = client
167 .get(&checksum_url)
168 .send()
169 .map_err(|e| e.to_string())?;
170
171 if checksum_response.status().is_success() {
172 println!("🔐 Verifying checksum...");
173 let expected_checksum = checksum_response.text().map_err(|e| e.to_string())?;
174 let expected_checksum = expected_checksum.split_whitespace().next().unwrap_or(&expected_checksum);
175
176 let mut hasher = Sha256::new();
177 hasher.update(&bytes);
178 let computed_checksum = hex::encode(hasher.finalize());
179
180 if computed_checksum.to_lowercase() != expected_checksum.to_lowercase() {
181 return Err(format!(
182 "Checksum mismatch!\nExpected: {expected_checksum}\nGot: {computed_checksum}"
183 ));
184 }
185 println!("✅ Checksum verified");
186 } else {
187 println!("⚠️ No checksum found, skipping verification");
188 }
189
190 println!("📦 Extracting archive...");
192 let temp_dir = tempfile::tempdir().map_err(|e| e.to_string())?;
193
194 if cfg!(target_os = "windows") {
195 extract_zip(&bytes, temp_dir.path())?;
196 } else {
197 extract_tar_gz(&bytes, temp_dir.path())?;
198 }
199
200 let binary_name = if cfg!(target_os = "windows") {
202 format!("{}.exe", repo_info.name)
203 } else {
204 repo_info.name.to_string()
205 };
206
207 let extracted_binary = temp_dir.path().join(&binary_name);
208 if !extracted_binary.exists() {
209 return Err(format!("Binary {binary_name} not found in archive"));
210 }
211
212 println!("🔧 Installing update...");
214
215 let backup_path = install_path.with_extension("bak");
217 if let Err(e) = fs::copy(install_path, &backup_path) {
218 eprintln!("⚠️ Failed to create backup: {e}");
219 }
220
221 fs::copy(&extracted_binary, install_path)
223 .map_err(|e| format!("Failed to install binary: {e}"))?;
224
225 #[cfg(unix)]
227 {
228 use std::os::unix::fs::PermissionsExt;
229 let mut perms = fs::metadata(install_path)
230 .map_err(|e| format!("Failed to get metadata: {e}"))?
231 .permissions();
232 perms.set_mode(0o755);
233 fs::set_permissions(install_path, perms)
234 .map_err(|e| format!("Failed to set permissions: {e}"))?;
235 }
236
237 if backup_path.exists() {
239 let _ = fs::remove_file(&backup_path);
240 }
241
242 Ok(())
243}
244
/// Map the host OS/architecture (from `std::env::consts`) to the target triple used in
/// release asset names.
///
/// # Panics
/// Panics for any OS/architecture pair not listed in the table below.
fn get_platform_string() -> String {
    // (os, arch, target triple)
    const TARGETS: [(&str, &str, &str); 5] = [
        ("linux", "x86_64", "x86_64-unknown-linux-gnu"),
        ("linux", "aarch64", "aarch64-unknown-linux-gnu"),
        ("macos", "x86_64", "x86_64-apple-darwin"),
        ("macos", "aarch64", "aarch64-apple-darwin"),
        ("windows", "x86_64", "x86_64-pc-windows-msvc"),
    ];

    let (os, arch) = (std::env::consts::OS, std::env::consts::ARCH);
    TARGETS
        .iter()
        .find(|&&(known_os, known_arch, _)| known_os == os && known_arch == arch)
        .map(|&(_, _, triple)| triple.to_string())
        .unwrap_or_else(|| panic!("Unsupported platform: {os}/{arch}"))
}
259
260fn extract_tar_gz(bytes: &[u8], dest: &Path) -> Result<(), String> {
261 use flate2::read::GzDecoder;
262 use tar::Archive;
263
264 let decoder = GzDecoder::new(bytes);
265 let mut archive = Archive::new(decoder);
266 archive
267 .unpack(dest)
268 .map_err(|e| format!("Failed to extract tar.gz: {e}"))
269}
270
/// Extract an in-memory ZIP archive into `dest` (Windows builds only).
///
/// The archive is downloaded data, so each entry name is validated with
/// `ZipFile::enclosed_name`, which rejects absolute paths and `..` components. Joining the
/// raw `name()` onto `dest` would let a crafted archive write outside `dest` ("zip slip").
///
/// # Errors
/// Returns a message if the archive or an entry cannot be read, an entry has an unsafe
/// path, or a file/directory cannot be created or written.
#[cfg(target_os = "windows")]
fn extract_zip(bytes: &[u8], dest: &Path) -> Result<(), String> {
    use std::io::Cursor;
    use zip::ZipArchive;

    let reader = Cursor::new(bytes);
    let mut archive = ZipArchive::new(reader).map_err(|e| format!("Failed to open zip: {e}"))?;

    for i in 0..archive.len() {
        let mut file = archive
            .by_index(i)
            .map_err(|e| format!("Failed to read zip entry: {e}"))?;
        let outpath = match file.enclosed_name() {
            Some(relative) => dest.join(relative),
            None => {
                return Err(format!(
                    "Refusing to extract zip entry with unsafe path: {}",
                    file.name()
                ))
            }
        };

        if file.is_dir() {
            fs::create_dir_all(&outpath)
                .map_err(|e| format!("Failed to create directory: {e}"))?;
        } else {
            if let Some(p) = outpath.parent() {
                fs::create_dir_all(p)
                    .map_err(|e| format!("Failed to create parent directory: {e}"))?;
            }
            let mut outfile = fs::File::create(&outpath)
                .map_err(|e| format!("Failed to create file: {e}"))?;
            io::copy(&mut file, &mut outfile)
                .map_err(|e| format!("Failed to extract file: {e}"))?;
        }
    }

    Ok(())
}
302
/// Non-Windows stand-in for the ZIP extractor; releases here are `.tar.gz`.
///
/// It exists only so the `cfg!(target_os = "windows")` branch in the caller type-checks on
/// every platform. It always returns an error.
#[cfg(not(target_os = "windows"))]
fn extract_zip(_bytes: &[u8], _dest: &Path) -> Result<(), String> {
    let reason = "ZIP extraction not supported on this platform";
    Err(String::from(reason))
}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_platform_string() {
        let platform = get_platform_string();
        assert!(!platform.is_empty());
        // Every supported triple is `<arch>-...`, starting with the Rust arch name.
        assert!(platform.contains('-'));
        assert!(platform.starts_with(std::env::consts::ARCH));
    }

    #[test]
    fn test_repo_info_latest_release_url() {
        let repo = RepoInfo::new("workhelix", "prompter", "prompter-v");
        let url = repo.latest_release_url();
        assert_eq!(url, "https://api.github.com/repos/workhelix/prompter/releases/latest");
    }

    /// Talks to the real GitHub API, so it is opt-in: `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access to api.github.com"]
    fn test_get_latest_version_handles_errors() {
        let repo = RepoInfo::new("nonexistent", "repo", "v");
        // A repository without releases must yield an error, not a bogus version.
        assert!(get_latest_version(&repo).is_err());
    }
}