// rusty_commit/update.rs

1use anyhow::{bail, Context, Result};
2use colored::*;
3use semver::Version;
4use std::env;
5use std::fs;
6use std::io;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::time::Duration;
10
/// GitHub `owner/repo` slug used to build the release API and download URLs.
const GITHUB_REPO: &str = "hongkongkiwi/rusty-commit";
/// Largest number of bytes a single download may reach (100 MiB).
const MAX_DOWNLOAD_SIZE: u64 = 100 * 1024 * 1024;
/// Overall time limit applied to each HTTP request (5 minutes).
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Time limit for establishing an HTTP connection (30 seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
15
/// How the running executable was installed, which decides the update strategy.
#[derive(Debug, Clone, PartialEq)]
pub enum InstallMethod {
    /// Installed through Homebrew; updated with `brew upgrade`.
    Homebrew,
    /// Installed with `cargo install`; updated by re-running it with `--force`.
    Cargo,
    /// Installed from a `.deb` package.
    Deb,
    /// Installed from an `.rpm` package.
    Rpm,
    /// Standalone binary; updated by downloading a release archive.
    Binary,
    /// Installed as a Snap; updated with `snap refresh`.
    Snap,
    /// No known installation method matched.
    Unknown,
}
26
27#[derive(Debug)]
28pub struct UpdateInfo {
29    pub current_version: String,
30    pub latest_version: String,
31    pub install_method: InstallMethod,
32    pub executable_path: PathBuf,
33    pub needs_update: bool,
34}
35
36/// Validate version string format
37fn validate_version(version: &str) -> Result<()> {
38    let clean_version = version.trim_start_matches('v');
39    Version::parse(clean_version).context("Invalid version format")?;
40
41    // Additional validation: no path traversal
42    if version.contains("..") || version.contains('/') || version.contains('\\') {
43        bail!("Invalid characters in version string");
44    }
45
46    Ok(())
47}
48
49/// Validate and sanitize file paths
50fn sanitize_path(path: &Path) -> Result<PathBuf> {
51    // Try to resolve to absolute path
52    let canonical = match path.canonicalize() {
53        Ok(p) => p,
54        Err(e) => {
55            // Fallback: use absolute path from current directory
56            tracing::warn!(
57                "Failed to canonicalize path {}: {}. Using absolute path fallback.",
58                path.display(),
59                e
60            );
61            if path.is_absolute() {
62                path.to_path_buf()
63            } else {
64                std::env::current_dir()
65                    .context("Failed to get current directory")?
66                    .join(path)
67            }
68        }
69    };
70
71    // Ensure path doesn't escape expected directories
72    let path_str = canonical.to_string_lossy();
73    if path_str.contains("..") {
74        bail!("Path traversal detected");
75    }
76
77    Ok(canonical)
78}
79
80/// Create a secure HTTP client with proper timeouts and limits
81fn create_http_client() -> Result<reqwest::Client> {
82    reqwest::Client::builder()
83        .user_agent(format!("rusty-commit/{}", env!("CARGO_PKG_VERSION")))
84        .timeout(DOWNLOAD_TIMEOUT)
85        .connect_timeout(CONNECT_TIMEOUT)
86        .https_only(true)
87        .build()
88        .context("Failed to create HTTP client")
89}
90
91/// Detect how rusty-commit was installed (more secure version)
92pub fn detect_install_method() -> Result<InstallMethod> {
93    let exe_path = env::current_exe().context("Failed to get current executable path")?;
94
95    // Sanitize the path
96    let exe_path = sanitize_path(&exe_path)?;
97    let exe_str = exe_path.to_string_lossy();
98
99    // Check for Homebrew installation
100    if exe_str.contains("/Cellar/") || exe_str.contains("homebrew") {
101        return Ok(InstallMethod::Homebrew);
102    }
103
104    // Check for Cargo installation
105    if exe_str.contains(".cargo/bin") {
106        return Ok(InstallMethod::Cargo);
107    }
108
109    // Check for Snap installation
110    if exe_str.contains("/snap/") {
111        return Ok(InstallMethod::Snap);
112    }
113
114    // Check for system package manager installations
115    if exe_str.starts_with("/usr/bin/") || exe_str.starts_with("/usr/local/bin/") {
116        // Try to detect package manager using safer methods
117        if Path::new("/etc/debian_version").exists() {
118            // Use dpkg-query which is safer than dpkg -S
119            if let Ok(output) = Command::new("dpkg-query")
120                .args(["-S", &exe_path.to_string_lossy()])
121                .output()
122            {
123                if output.status.success() {
124                    return Ok(InstallMethod::Deb);
125                }
126            }
127        }
128
129        if Path::new("/etc/redhat-release").exists() || Path::new("/etc/fedora-release").exists() {
130            // Check if installed via rpm (safer query)
131            if let Ok(output) = Command::new("rpm")
132                .args(["-qf", &exe_path.to_string_lossy()])
133                .output()
134            {
135                if output.status.success() {
136                    return Ok(InstallMethod::Rpm);
137                }
138            }
139        }
140
141        // Likely a binary installation
142        return Ok(InstallMethod::Binary);
143    }
144
145    // Check if it's in a typical binary install location
146    if exe_str.contains("/usr/local/bin/") || exe_str.contains("/opt/") {
147        return Ok(InstallMethod::Binary);
148    }
149
150    Ok(InstallMethod::Unknown)
151}
152
153/// Get the latest version from GitHub releases (with validation)
154pub async fn get_latest_version() -> Result<String> {
155    let client = create_http_client()?;
156
157    let url = format!(
158        "https://api.github.com/repos/{}/releases/latest",
159        GITHUB_REPO
160    );
161    let response = client
162        .get(&url)
163        .send()
164        .await
165        .context("Failed to fetch latest release")?;
166
167    if !response.status().is_success() {
168        bail!("GitHub API returned status: {}", response.status());
169    }
170
171    let release: serde_json::Value = response
172        .json()
173        .await
174        .context("Failed to parse release JSON")?;
175
176    let tag_name = release["tag_name"]
177        .as_str()
178        .context("Failed to get tag_name from release")?;
179
180    // Validate version format
181    validate_version(tag_name)?;
182
183    // Remove 'v' prefix if present
184    Ok(tag_name.trim_start_matches('v').to_string())
185}
186
187/// Check if an update is available
188pub async fn check_for_update() -> Result<UpdateInfo> {
189    let current_version = env!("CARGO_PKG_VERSION").to_string();
190    let latest_version = get_latest_version().await?;
191    let install_method = detect_install_method()?;
192    let executable_path = env::current_exe()?;
193
194    let current = Version::parse(&current_version)?;
195    let latest = Version::parse(&latest_version)?;
196
197    let needs_update = latest > current;
198
199    Ok(UpdateInfo {
200        current_version,
201        latest_version,
202        install_method,
203        executable_path,
204        needs_update,
205    })
206}
207
208/// Update using Homebrew (safer version)
209async fn update_homebrew() -> Result<()> {
210    println!("{}", "Updating via Homebrew...".blue());
211
212    // Check if brew exists first
213    which::which("brew").context("Homebrew not found in PATH")?;
214
215    // Update Homebrew
216    let output = Command::new("brew")
217        .args(["update"])
218        .output()
219        .context("Failed to run brew update")?;
220
221    if !output.status.success() {
222        bail!(
223            "brew update failed: {}",
224            String::from_utf8_lossy(&output.stderr)
225        );
226    }
227
228    // Upgrade rusty-commit
229    let output = Command::new("brew")
230        .args(["upgrade", "rusty-commit"])
231        .output()
232        .context("Failed to run brew upgrade")?;
233
234    if !output.status.success() {
235        let stderr = String::from_utf8_lossy(&output.stderr);
236        if stderr.contains("already installed") {
237            println!("{}", "Already up to date!".green());
238            return Ok(());
239        }
240        bail!("brew upgrade failed: {}", stderr);
241    }
242
243    println!("{}", "Successfully updated via Homebrew!".green());
244    Ok(())
245}
246
247/// Update using Cargo (safer version)
248async fn update_cargo() -> Result<()> {
249    println!("{}", "Updating via Cargo...".blue());
250
251    // Check if cargo exists
252    which::which("cargo").context("Cargo not found in PATH")?;
253
254    let output = Command::new("cargo")
255        .args([
256            "install",
257            "rusty-commit",
258            "--force",
259            "--features",
260            "secure-storage",
261        ])
262        .output()
263        .context("Failed to run cargo install")?;
264
265    if !output.status.success() {
266        bail!(
267            "cargo install failed: {}",
268            String::from_utf8_lossy(&output.stderr)
269        );
270    }
271
272    println!("{}", "Successfully updated via Cargo!".green());
273    Ok(())
274}
275
276/// Download with size limit and checksum verification
277async fn download_with_verification(
278    url: &str,
279    expected_checksum: Option<&str>,
280    max_size: u64,
281) -> Result<Vec<u8>> {
282    println!("{}", format!("Downloading from: {}", url).blue());
283
284    let client = create_http_client()?;
285    let response = client
286        .get(url)
287        .send()
288        .await
289        .context("Failed to start download")?;
290
291    if !response.status().is_success() {
292        bail!("Download failed with status: {}", response.status());
293    }
294
295    // Check content length if available
296    if let Some(content_length) = response.content_length() {
297        if content_length > max_size {
298            bail!(
299                "File too large: {} bytes (max: {} bytes)",
300                content_length,
301                max_size
302            );
303        }
304    }
305
306    // Download with size limit
307    let mut bytes = Vec::new();
308    let mut stream = response.bytes_stream();
309    use futures::StreamExt;
310
311    while let Some(chunk) = stream.next().await {
312        let chunk = chunk.context("Failed to read chunk")?;
313        bytes.extend_from_slice(&chunk);
314
315        if bytes.len() as u64 > max_size {
316            bail!("Download exceeded maximum size of {} bytes", max_size);
317        }
318    }
319
320    // Verify checksum if provided
321    if let Some(expected) = expected_checksum {
322        use sha2::{Digest, Sha256};
323        let mut hasher = Sha256::new();
324        hasher.update(&bytes);
325        let actual = format!("{:x}", hasher.finalize());
326
327        if actual != expected {
328            bail!("Checksum verification failed");
329        }
330
331        println!("{}", "Checksum verified".green());
332    }
333
334    // TODO: Add Cosign signature verification in future version
335    // This would require integrating with cosign binary or sigstore-rs crate
336    // For now, checksums provide integrity verification
337
338    Ok(bytes)
339}
340
341/// Get SHA256 checksum for a release file
342async fn get_release_checksum(version: &str, filename: &str) -> Result<Option<String>> {
343    let client = create_http_client()?;
344    let url = format!(
345        "https://github.com/{}/releases/download/v{}/SHA256SUMS.txt",
346        GITHUB_REPO, version
347    );
348
349    let response = client.get(&url).send().await;
350
351    match response {
352        Ok(resp) if resp.status().is_success() => {
353            let text = resp.text().await?;
354            for line in text.lines() {
355                if line.contains(filename) {
356                    if let Some(checksum) = line.split_whitespace().next() {
357                        return Ok(Some(checksum.to_string()));
358                    }
359                }
360            }
361            Ok(None)
362        }
363        _ => Ok(None),
364    }
365}
366
367/// Atomic file replacement with proper error handling
368async fn atomic_replace_file(source: &Path, target: &Path) -> Result<()> {
369    use std::fs::OpenOptions;
370    use std::io::copy;
371
372    // Create a unique temporary file in the same directory as target
373    let temp_path = target.with_extension(format!(".tmp.{}", std::process::id()));
374
375    // Copy source to temp location
376    {
377        let mut source_file = fs::File::open(source).context("Failed to open source file")?;
378        let mut temp_file = OpenOptions::new()
379            .write(true)
380            .create(true)
381            .truncate(true)
382            .open(&temp_path)
383            .context("Failed to create temp file")?;
384
385        copy(&mut source_file, &mut temp_file).context("Failed to copy to temp file")?;
386    }
387
388    // Set executable permissions on Unix
389    #[cfg(unix)]
390    {
391        use std::os::unix::fs::PermissionsExt;
392        let mut perms = fs::metadata(&temp_path)?.permissions();
393        perms.set_mode(0o755);
394        fs::set_permissions(&temp_path, perms)?;
395    }
396
397    // Atomic rename
398    fs::rename(&temp_path, target).context("Failed to perform atomic rename")?;
399
400    Ok(())
401}
402
403/// Update Debian package (secure version)
404async fn update_deb(version: &str) -> Result<()> {
405    println!("{}", "Updating via apt/dpkg...".blue());
406
407    // Validate version
408    validate_version(version)?;
409
410    let arch = get_system_arch()?;
411    let deb_arch = match arch.as_str() {
412        "x86_64" => "amd64",
413        "aarch64" => "arm64",
414        "armv7" => "armhf",
415        _ => bail!("Unsupported architecture for .deb: {}", arch),
416    };
417
418    let filename = format!("rusty-commit_{}_{}.deb", version, deb_arch);
419    let url = format!(
420        "https://github.com/{}/releases/download/v{}/{}",
421        GITHUB_REPO, version, filename
422    );
423
424    // Get checksum
425    let checksum = get_release_checksum(version, &filename).await?;
426
427    // Download with verification
428    let package_data =
429        download_with_verification(&url, checksum.as_deref(), MAX_DOWNLOAD_SIZE).await?;
430
431    // Save to secure temp directory
432    let temp_dir = tempfile::TempDir::new()?;
433    let temp_path = temp_dir.path().join(&filename);
434    fs::write(&temp_path, package_data)?;
435
436    // Install with dpkg or apt
437    let result = if which::which("apt-get").is_ok() {
438        Command::new("sudo")
439            .args(["apt-get", "install", "-y"])
440            .arg(&temp_path)
441            .output()
442    } else if which::which("dpkg").is_ok() {
443        Command::new("sudo")
444            .args(["dpkg", "-i"])
445            .arg(&temp_path)
446            .output()
447    } else {
448        bail!("Neither apt-get nor dpkg found");
449    };
450
451    match result {
452        Ok(output) if output.status.success() => {
453            println!("{}", "Successfully updated via package manager!".green());
454            Ok(())
455        }
456        Ok(output) => bail!(
457            "Package installation failed: {}",
458            String::from_utf8_lossy(&output.stderr)
459        ),
460        Err(e) => Err(e.into()),
461    }
462}
463
464/// Update RPM package (secure version)
465async fn update_rpm(version: &str) -> Result<()> {
466    println!("{}", "Updating via rpm/dnf/yum...".blue());
467
468    // Validate version
469    validate_version(version)?;
470
471    let arch = get_system_arch()?;
472    let rpm_arch = match arch.as_str() {
473        "x86_64" => "x86_64",
474        "aarch64" => "aarch64",
475        _ => bail!("Unsupported architecture for .rpm: {}", arch),
476    };
477
478    let filename = format!("rusty-commit-{}-1.{}.rpm", version, rpm_arch);
479    let url = format!(
480        "https://github.com/{}/releases/download/v{}/{}",
481        GITHUB_REPO, version, filename
482    );
483
484    // Get checksum
485    let checksum = get_release_checksum(version, &filename).await?;
486
487    // Download with verification
488    let package_data =
489        download_with_verification(&url, checksum.as_deref(), MAX_DOWNLOAD_SIZE).await?;
490
491    // Save to secure temp directory
492    let temp_dir = tempfile::TempDir::new()?;
493    let temp_path = temp_dir.path().join(&filename);
494    fs::write(&temp_path, package_data)?;
495
496    // Install with package manager
497    let result = if which::which("dnf").is_ok() {
498        Command::new("sudo")
499            .args(["dnf", "install", "-y"])
500            .arg(&temp_path)
501            .output()
502    } else if which::which("yum").is_ok() {
503        Command::new("sudo")
504            .args(["yum", "install", "-y"])
505            .arg(&temp_path)
506            .output()
507    } else if which::which("rpm").is_ok() {
508        Command::new("sudo")
509            .args(["rpm", "-Uvh"])
510            .arg(&temp_path)
511            .output()
512    } else {
513        bail!("No suitable package manager found");
514    };
515
516    match result {
517        Ok(output) if output.status.success() => {
518            println!("{}", "Successfully updated via package manager!".green());
519            Ok(())
520        }
521        Ok(output) => bail!(
522            "Package installation failed: {}",
523            String::from_utf8_lossy(&output.stderr)
524        ),
525        Err(e) => Err(e.into()),
526    }
527}
528
529/// Update binary installation (secure version)
530async fn update_binary(version: &str, exe_path: &Path) -> Result<()> {
531    println!("{}", "Updating binary installation...".blue());
532
533    // Validate inputs
534    validate_version(version)?;
535    let exe_path = sanitize_path(exe_path)?;
536
537    let os = get_system_os()?;
538    let arch = get_system_arch()?;
539
540    // Prefer musl tarballs when running on Alpine/musl
541    let is_musl = if os == "linux" {
542        // Best-effort detection: check /etc/alpine-release or ldd output
543        if Path::new("/etc/alpine-release").exists() {
544            true
545        } else {
546            let output = Command::new("sh")
547                .arg("-lc")
548                .arg("ldd --version 2>&1 || true")
549                .output();
550            if let Ok(out) = output {
551                String::from_utf8_lossy(&out.stdout)
552                    .to_lowercase()
553                    .contains("musl")
554                    || String::from_utf8_lossy(&out.stderr)
555                        .to_lowercase()
556                        .contains("musl")
557            } else {
558                false
559            }
560        }
561    } else {
562        false
563    };
564
565    let archive_name = match (os.as_str(), arch.as_str(), is_musl) {
566        ("linux", "x86_64", true) => "rustycommit-linux-musl-x86_64.tar.gz",
567        ("linux", "aarch64", true) => "rustycommit-linux-musl-aarch64.tar.gz",
568        ("linux", "riscv64", true) => "rustycommit-linux-musl-riscv64.tar.gz",
569        ("linux", "x86_64", false) => "rustycommit-linux-x86_64.tar.gz",
570        ("linux", "aarch64", false) => "rustycommit-linux-aarch64.tar.gz",
571        ("linux", "armv7", false) => "rustycommit-linux-armv7.tar.gz",
572        ("linux", "riscv64", false) => "rustycommit-linux-riscv64.tar.gz",
573        ("macos", "x86_64", _) => "rustycommit-macos-x86_64.tar.gz",
574        ("macos", "aarch64", _) => "rustycommit-macos-aarch64.tar.gz",
575        ("windows", "x86_64", _) => "rustycommit-windows-x86_64.zip",
576        ("windows", "i686", _) => "rustycommit-windows-i686.zip",
577        _ => bail!(
578            "Unsupported OS/architecture: {}-{} (musl={})",
579            os,
580            arch,
581            is_musl
582        ),
583    };
584
585    let url = format!(
586        "https://github.com/{}/releases/download/v{}/{}",
587        GITHUB_REPO, version, archive_name
588    );
589
590    // Get checksum
591    let checksum = get_release_checksum(version, archive_name).await?;
592
593    // Download with verification
594    let archive_data =
595        download_with_verification(&url, checksum.as_deref(), MAX_DOWNLOAD_SIZE).await?;
596
597    // Extract to secure temp directory
598    let temp_dir = tempfile::TempDir::new()?;
599    let archive_path = temp_dir.path().join(archive_name);
600    fs::write(&archive_path, archive_data)?;
601
602    // Extract archive using built-in libraries when possible
603    let binary_name = if cfg!(windows) { "rco.exe" } else { "rco" };
604    let extracted_binary = temp_dir.path().join(binary_name);
605
606    if archive_name.ends_with(".tar.gz") {
607        // Use tar crate for extraction (safer than shell command)
608        use flate2::read::GzDecoder;
609        use tar::Archive;
610
611        let tar_gz = fs::File::open(&archive_path)?;
612        let tar = GzDecoder::new(tar_gz);
613        let mut archive = Archive::new(tar);
614        archive.unpack(temp_dir.path())?;
615    } else if archive_name.ends_with(".zip") {
616        // Use zip crate for extraction
617        use zip::ZipArchive;
618
619        let file = fs::File::open(&archive_path)?;
620        let mut archive = ZipArchive::new(file)?;
621
622        for i in 0..archive.len() {
623            let mut file = archive.by_index(i)?;
624            if file.name() == binary_name {
625                let mut outfile = fs::File::create(&extracted_binary)?;
626                io::copy(&mut file, &mut outfile)?;
627                break;
628            }
629        }
630    }
631
632    if !extracted_binary.exists() {
633        bail!("Binary not found in archive");
634    }
635
636    // Create backup of current binary
637    let backup_path = exe_path.with_extension(format!("bak.{}", std::process::id()));
638    fs::copy(&exe_path, &backup_path).context("Failed to create backup")?;
639
640    // Try to perform atomic replacement
641    let replace_result = atomic_replace_file(&extracted_binary, &exe_path).await;
642
643    match replace_result {
644        Ok(_) => {
645            // Success - remove backup
646            let _ = fs::remove_file(&backup_path);
647            println!("{}", "Successfully updated binary!".green());
648            Ok(())
649        }
650        Err(e) => {
651            // Try to restore backup
652            if let Err(restore_err) = fs::rename(&backup_path, &exe_path) {
653                eprintln!(
654                    "{}",
655                    format!("Critical: Failed to restore backup: {}", restore_err).red()
656                );
657            }
658            Err(e)
659        }
660    }
661}
662
663/// Update Snap package
664async fn update_snap() -> Result<()> {
665    println!("{}", "Updating via Snap...".blue());
666
667    which::which("snap").context("Snap not found in PATH")?;
668
669    let output = Command::new("sudo")
670        .args(["snap", "refresh", "rusty-commit"])
671        .output()
672        .context("Failed to run snap refresh")?;
673
674    if !output.status.success() {
675        let stderr = String::from_utf8_lossy(&output.stderr);
676        if stderr.contains("has no updates available") {
677            println!("{}", "Already up to date!".green());
678            return Ok(());
679        }
680        bail!("snap refresh failed: {}", stderr);
681    }
682
683    println!("{}", "Successfully updated via Snap!".green());
684    Ok(())
685}
686
687/// Perform the update based on installation method
688pub async fn perform_update(info: &UpdateInfo) -> Result<()> {
689    if !info.needs_update {
690        println!("{}", "Already running the latest version!".green());
691        return Ok(());
692    }
693
694    println!(
695        "{}",
696        format!(
697            "Updating from v{} to v{}...",
698            info.current_version, info.latest_version
699        )
700        .blue()
701    );
702
703    match info.install_method {
704        InstallMethod::Homebrew => update_homebrew().await,
705        InstallMethod::Cargo => update_cargo().await,
706        InstallMethod::Deb => update_deb(&info.latest_version).await,
707        InstallMethod::Rpm => update_rpm(&info.latest_version).await,
708        InstallMethod::Binary => update_binary(&info.latest_version, &info.executable_path).await,
709        InstallMethod::Snap => update_snap().await,
710        InstallMethod::Unknown => {
711            bail!(
712                "Could not detect installation method. Please update manually or use the install script:\n\
713                curl -fsSL https://raw.githubusercontent.com/{}/main/install.sh | bash",
714                GITHUB_REPO
715            )
716        }
717    }
718}
719
/// Name of the operating system this binary was compiled for.
///
/// Returns `"linux"`, `"macos"` or `"windows"`; every other target yields
/// `"unknown"`. The call never fails.
fn get_system_os() -> Result<String> {
    let os = match env::consts::OS {
        known @ ("linux" | "macos" | "windows") => known,
        _ => "unknown",
    };
    Ok(os.to_string())
}
732
/// Name of the CPU architecture this binary was compiled for, in release-asset terms.
///
/// Returns `"x86_64"`, `"aarch64"`, `"armv7"` (for `arm`), `"i686"` (for
/// `x86`) or `"riscv64"`; every other target yields `"unknown"`. The call
/// never fails.
fn get_system_arch() -> Result<String> {
    let arch = match env::consts::ARCH {
        "x86_64" => "x86_64",
        "aarch64" => "aarch64",
        "arm" => "armv7",
        "x86" => "i686",
        "riscv64" => "riscv64",
        _ => "unknown",
    };
    Ok(arch.to_string())
}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed versions are accepted; path-like or non-semver input is rejected.
    #[test]
    fn test_version_validation() {
        for accepted in ["1.0.0", "v1.0.0", "1.0.0-beta.1"] {
            assert!(validate_version(accepted).is_ok());
        }

        for rejected in ["../etc/passwd", "1.0.0/../../etc", "invalid"] {
            assert!(validate_version(rejected).is_err());
        }
    }

    /// A higher patch number compares as the newer version.
    #[test]
    fn test_version_comparison() {
        let older = Version::parse("1.0.0").unwrap();
        let newer = Version::parse("1.0.1").unwrap();
        assert!(newer > older);
    }
}