workhelix_cli_common/
update.rs1use crate::types::RepoInfo;
10use sha2::{Digest, Sha256};
11use std::fs;
12use std::io::{self, Write};
13use std::path::Path;
14
15#[must_use]
29pub fn run_update(
30 repo_info: &RepoInfo,
31 current_version: &str,
32 version: Option<&str>,
33 force: bool,
34 install_dir: Option<&Path>,
35) -> i32 {
36 println!("🔄 Checking for updates...");
37
38 let target_version = if let Some(v) = version {
40 v.to_string()
41 } else {
42 match get_latest_version(repo_info) {
43 Ok(v) => v,
44 Err(e) => {
45 eprintln!("❌ Failed to check for updates: {e}");
46 return 1;
47 }
48 }
49 };
50
51 if target_version == current_version && !force {
53 println!("✅ Already running latest version (v{current_version})");
54 return 2;
55 }
56
57 println!("✨ Update available: v{target_version} (current: v{current_version})");
58
59 let install_path = if let Some(dir) = install_dir {
61 dir.join(repo_info.name)
62 } else {
63 match std::env::current_exe() {
64 Ok(path) => path,
65 Err(e) => {
66 eprintln!("❌ Failed to determine binary location: {e}");
67 return 1;
68 }
69 }
70 };
71
72 println!("📍 Install location: {}", install_path.display());
73 println!();
74
75 if !force {
77 print!("Continue with update? [y/N]: ");
78 io::stdout().flush().unwrap();
79
80 let mut response = String::new();
81 io::stdin().read_line(&mut response).unwrap();
82
83 if !matches!(response.trim().to_lowercase().as_str(), "y" | "yes") {
84 println!("Update cancelled.");
85 return 0;
86 }
87 }
88
89 match perform_update(repo_info, &target_version, &install_path) {
91 Ok(()) => {
92 println!("✅ Successfully updated to v{target_version}");
93 println!();
94 println!(
95 "Run '{} --version' to verify the installation.",
96 repo_info.name
97 );
98 0
99 }
100 Err(e) => {
101 eprintln!("❌ Update failed: {e}");
102 1
103 }
104 }
105}
106
107pub fn get_latest_version(repo_info: &RepoInfo) -> Result<String, String> {
113 let client = reqwest::blocking::Client::builder()
114 .user_agent(format!("{}-updater", repo_info.name))
115 .timeout(std::time::Duration::from_secs(10))
116 .build()
117 .map_err(|e| e.to_string())?;
118
119 let response: serde_json::Value = client
120 .get(repo_info.latest_release_url())
121 .send()
122 .map_err(|e| e.to_string())?
123 .json()
124 .map_err(|e| e.to_string())?;
125
126 let tag_name = response["tag_name"]
127 .as_str()
128 .ok_or_else(|| "No tag_name in response".to_string())?;
129
130 let version = tag_name
131 .trim_start_matches(repo_info.tag_prefix)
132 .trim_start_matches('v');
133 Ok(version.to_string())
134}
135
136fn perform_update(repo_info: &RepoInfo, version: &str, install_path: &Path) -> Result<(), String> {
137 let platform = get_platform_string();
139 let archive_ext = if cfg!(target_os = "windows") {
140 "zip"
141 } else {
142 "tar.gz"
143 };
144
145 let download_url = repo_info.download_url(version, &platform, archive_ext);
146
147 println!(
148 "📥 Downloading {}-{platform}.{archive_ext}...",
149 repo_info.name
150 );
151
152 let client = reqwest::blocking::Client::builder()
154 .user_agent(format!("{}-updater", repo_info.name))
155 .timeout(std::time::Duration::from_secs(300))
156 .build()
157 .map_err(|e| e.to_string())?;
158
159 let response = client
160 .get(&download_url)
161 .send()
162 .map_err(|e| e.to_string())?;
163
164 if !response.status().is_success() {
165 return Err(format!("Download failed: HTTP {}", response.status()));
166 }
167
168 let bytes = response.bytes().map_err(|e| e.to_string())?;
169
170 let checksum_url = format!("{download_url}.sha256");
172 let checksum_response = client
173 .get(&checksum_url)
174 .send()
175 .map_err(|e| e.to_string())?;
176
177 if checksum_response.status().is_success() {
178 println!("🔐 Verifying checksum...");
179 let expected_checksum = checksum_response.text().map_err(|e| e.to_string())?;
180 let expected_checksum = expected_checksum
181 .split_whitespace()
182 .next()
183 .unwrap_or(&expected_checksum);
184
185 let mut hasher = Sha256::new();
186 hasher.update(&bytes);
187 let computed_checksum = hex::encode(hasher.finalize());
188
189 if computed_checksum.to_lowercase() != expected_checksum.to_lowercase() {
190 return Err(format!(
191 "Checksum mismatch!\nExpected: {expected_checksum}\nGot: {computed_checksum}"
192 ));
193 }
194 println!("✅ Checksum verified");
195 } else {
196 println!("⚠️ No checksum found, skipping verification");
197 }
198
199 println!("📦 Extracting archive...");
201 let temp_dir = tempfile::tempdir().map_err(|e| e.to_string())?;
202
203 if cfg!(target_os = "windows") {
204 extract_zip(&bytes, temp_dir.path())?;
205 } else {
206 extract_tar_gz(&bytes, temp_dir.path())?;
207 }
208
209 let binary_name = if cfg!(target_os = "windows") {
211 format!("{}.exe", repo_info.name)
212 } else {
213 repo_info.name.to_string()
214 };
215
216 let extracted_binary = temp_dir.path().join(&binary_name);
217 if !extracted_binary.exists() {
218 return Err(format!("Binary {binary_name} not found in archive"));
219 }
220
221 println!("🔧 Installing update...");
223
224 let backup_path = install_path.with_extension("bak");
226 if let Err(e) = fs::copy(install_path, &backup_path) {
227 eprintln!("⚠️ Failed to create backup: {e}");
228 }
229
230 fs::copy(&extracted_binary, install_path)
232 .map_err(|e| format!("Failed to install binary: {e}"))?;
233
234 #[cfg(unix)]
236 {
237 use std::os::unix::fs::PermissionsExt;
238 let mut perms = fs::metadata(install_path)
239 .map_err(|e| format!("Failed to get metadata: {e}"))?
240 .permissions();
241 perms.set_mode(0o755);
242 fs::set_permissions(install_path, perms)
243 .map_err(|e| format!("Failed to set permissions: {e}"))?;
244 }
245
246 if backup_path.exists() {
248 let _ = fs::remove_file(&backup_path);
249 }
250
251 Ok(())
252}
253
/// Maps the host OS and CPU architecture to the Rust target triple that release
/// archives are named after (e.g. `x86_64-unknown-linux-gnu`).
///
/// # Panics
///
/// Panics if the host OS/architecture pair is not one of the supported targets.
fn get_platform_string() -> String {
    // (OS, ARCH, target triple) for every platform that has a release build.
    const SUPPORTED: [(&str, &str, &str); 5] = [
        ("linux", "x86_64", "x86_64-unknown-linux-gnu"),
        ("linux", "aarch64", "aarch64-unknown-linux-gnu"),
        ("macos", "x86_64", "x86_64-apple-darwin"),
        ("macos", "aarch64", "aarch64-apple-darwin"),
        ("windows", "x86_64", "x86_64-pc-windows-msvc"),
    ];

    let os = std::env::consts::OS;
    let arch = std::env::consts::ARCH;

    SUPPORTED
        .iter()
        .find(|&&(host_os, host_arch, _)| host_os == os && host_arch == arch)
        .map(|&(_, _, triple)| triple.to_string())
        .unwrap_or_else(|| panic!("Unsupported platform: {os}/{arch}"))
}
268
269fn extract_tar_gz(bytes: &[u8], dest: &Path) -> Result<(), String> {
270 use flate2::read::GzDecoder;
271 use tar::Archive;
272
273 let decoder = GzDecoder::new(bytes);
274 let mut archive = Archive::new(decoder);
275 archive
276 .unpack(dest)
277 .map_err(|e| format!("Failed to extract tar.gz: {e}"))
278}
279
/// Extracts every entry of an in-memory ZIP archive into `dest` (Windows only).
///
/// Entry names are resolved with `enclosed_name()`, so entries with absolute
/// paths or `..` components ("zip slip") are rejected rather than written
/// outside `dest`.
///
/// # Errors
///
/// Returns a message if the archive cannot be opened, an entry cannot be read,
/// an entry path is unsafe, or a file/directory cannot be created or written.
#[cfg(target_os = "windows")]
fn extract_zip(bytes: &[u8], dest: &Path) -> Result<(), String> {
    use std::io::Cursor;
    use zip::ZipArchive;

    let reader = Cursor::new(bytes);
    let mut archive = ZipArchive::new(reader).map_err(|e| format!("Failed to open zip: {e}"))?;

    for i in 0..archive.len() {
        let mut file = archive
            .by_index(i)
            .map_err(|e| format!("Failed to read zip entry: {e}"))?;

        // Never trust the raw entry name: joining it directly onto `dest`
        // would let a crafted archive write anywhere on disk.
        let outpath = match file.enclosed_name() {
            Some(relative) => dest.join(relative),
            None => {
                return Err(format!(
                    "Refusing to extract unsafe zip entry path: {}",
                    file.name()
                ))
            }
        };

        if file.is_dir() {
            fs::create_dir_all(&outpath).map_err(|e| format!("Failed to create directory: {e}"))?;
        } else {
            if let Some(p) = outpath.parent() {
                fs::create_dir_all(p)
                    .map_err(|e| format!("Failed to create parent directory: {e}"))?;
            }
            let mut outfile =
                fs::File::create(&outpath).map_err(|e| format!("Failed to create file: {e}"))?;
            io::copy(&mut file, &mut outfile)
                .map_err(|e| format!("Failed to extract file: {e}"))?;
        }
    }

    Ok(())
}
310
/// Placeholder for non-Windows targets, where releases ship as `.tar.gz`.
///
/// # Errors
///
/// Always returns an error; ZIP archives are not handled on this platform.
#[cfg(not(target_os = "windows"))]
fn extract_zip(_bytes: &[u8], _dest: &Path) -> Result<(), String> {
    const UNSUPPORTED: &str = "ZIP extraction not supported on this platform";
    Err(String::from(UNSUPPORTED))
}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_platform_string() {
        let platform = get_platform_string();
        assert!(!platform.is_empty());
        // Every supported target triple is dash-separated.
        assert!(platform.contains('-'));
    }

    #[test]
    fn test_repo_info_latest_release_url() {
        let repo = RepoInfo::new("workhelix", "prompter", "prompter-v");
        let url = repo.latest_release_url();
        assert_eq!(
            url,
            "https://api.github.com/repos/workhelix/prompter/releases/latest"
        );
    }

    #[test]
    #[ignore = "requires network access to api.github.com"]
    fn test_get_latest_version_handles_errors() {
        let repo = RepoInfo::new("nonexistent", "repo", "v");
        // An unknown repository (or no network) must be reported as an error,
        // not turned into a bogus version string or a panic.
        assert!(get_latest_version(&repo).is_err());
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_extract_zip_unsupported_off_windows() {
        let err = extract_zip(&[], Path::new(".")).unwrap_err();
        assert_eq!(err, "ZIP extraction not supported on this platform");
    }
}