vx_cli/commands/
self_update.rs

1//! Self-update command implementation
2
3use crate::ui::UI;
4use anyhow::{anyhow, Result};
5use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
6use serde::Deserialize;
7use std::env;
8use std::fs;
9use std::path::PathBuf;
10
11#[derive(Debug, Deserialize)]
12struct GitHubRelease {
13    tag_name: String,
14    #[allow(dead_code)]
15    name: String,
16    body: String,
17    assets: Vec<GitHubAsset>,
18    #[allow(dead_code)]
19    prerelease: bool,
20}
21
22#[derive(Debug, Deserialize)]
23struct GitHubAsset {
24    name: String,
25    browser_download_url: String,
26    size: u64,
27}
28
29pub async fn handle(
30    token: Option<&str>,
31    prerelease: bool,
32    force: bool,
33    check_only: bool,
34) -> Result<()> {
35    UI::info("🔍 Checking for vx updates...");
36
37    let current_version = env!("CARGO_PKG_VERSION");
38    UI::detail(&format!("Current version: {}", current_version));
39
40    // Create HTTP client with optional authentication
41    let client = create_authenticated_client(token)?;
42
43    // Get latest release information with smart channel selection
44    let release = get_latest_release(&client, prerelease, token.is_some()).await?;
45
46    let latest_version = release.tag_name.trim_start_matches('v');
47    UI::detail(&format!("Latest version: {}", latest_version));
48
49    // Check if update is needed
50    if !force && current_version == latest_version {
51        UI::success("✅ vx is already up to date!");
52        return Ok(());
53    }
54
55    if current_version != latest_version {
56        UI::info(&format!(
57            "📦 New version available: {} -> {}",
58            current_version, latest_version
59        ));
60
61        if !release.body.is_empty() {
62            UI::info("📝 Release notes:");
63            println!("{}", release.body);
64        }
65    }
66
67    if check_only {
68        if current_version != latest_version {
69            UI::info("💡 Run 'vx self-update' to update to the latest version");
70        }
71        return Ok(());
72    }
73
74    // Find appropriate asset for current platform
75    let asset = find_platform_asset(&release.assets)?;
76    UI::info(&format!(
77        "📥 Downloading {} ({} bytes)...",
78        asset.name, asset.size
79    ));
80
81    // Download and install update
82    download_and_install(&client, asset, force).await?;
83
84    UI::success(&format!(
85        "🎉 Successfully updated vx to version {}!",
86        latest_version
87    ));
88    UI::hint("Restart your terminal or run 'vx --version' to verify the update");
89
90    Ok(())
91}
92
93fn create_authenticated_client(token: Option<&str>) -> Result<reqwest::Client> {
94    let mut headers = HeaderMap::new();
95
96    // Always set User-Agent
97    headers.insert(
98        USER_AGENT,
99        HeaderValue::from_static("vx-cli/0.3.0 (https://github.com/loonghao/vx)"),
100    );
101
102    // Add authentication if token is provided
103    if let Some(token) = token {
104        let auth_value = format!("Bearer {}", token);
105        headers.insert(
106            AUTHORIZATION,
107            HeaderValue::from_str(&auth_value)
108                .map_err(|e| anyhow!("Invalid token format: {}", e))?,
109        );
110        UI::detail("🔐 Using authenticated requests to GitHub API");
111    } else {
112        UI::detail("🌐 No GitHub token provided, will prefer CDN for downloads");
113        UI::hint("💡 Use --token <TOKEN> to use GitHub API directly and avoid CDN delays");
114    }
115
116    let client = reqwest::Client::builder()
117        .default_headers(headers)
118        .timeout(std::time::Duration::from_secs(30))
119        .build()?;
120
121    Ok(client)
122}
123
124async fn get_latest_release(
125    client: &reqwest::Client,
126    prerelease: bool,
127    has_token: bool,
128) -> Result<GitHubRelease> {
129    // If no token is provided, prefer CDN to avoid rate limits
130    if !has_token {
131        UI::info("🌐 No GitHub token provided, using CDN for version check...");
132
133        // Try jsDelivr API first when no token
134        match try_jsdelivr_api(client, prerelease).await {
135            Ok(release) => {
136                UI::info("✅ Got version info from jsDelivr CDN");
137                return Ok(release);
138            }
139            Err(e) => {
140                UI::warn(&format!("⚠️ CDN fallback failed: {}", e));
141                UI::info("🔄 Falling back to GitHub API...");
142            }
143        }
144    }
145
146    // Try GitHub API (either as primary with token, or as fallback without token)
147    match try_github_api(client, prerelease).await {
148        Ok(release) => Ok(release),
149        Err(e) => {
150            // Check if it's a rate limit error
151            if e.to_string().contains("rate limit") {
152                if has_token {
153                    // If we have a token but still hit rate limit, something's wrong
154                    return Err(anyhow!(
155                        "GitHub API rate limit exceeded even with authentication. \
156                        Check your token permissions or try again later."
157                    ));
158                } else {
159                    // If no token and we already tried CDN, we're out of options
160                    return Err(anyhow!(
161                        "GitHub API rate limit exceeded and CDN fallback also failed. \
162                        Use --token <TOKEN> to authenticate and increase rate limits. \
163                        See: https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api"
164                    ));
165                }
166            }
167
168            // For other errors, try CDN as last resort if we haven't already
169            if has_token {
170                UI::warn(&format!("⚠️ GitHub API failed: {}", e));
171                UI::info("🔄 Trying CDN fallback...");
172
173                if let Ok(release) = try_jsdelivr_api(client, prerelease).await {
174                    UI::info("✅ Got version info from jsDelivr CDN");
175                    return Ok(release);
176                }
177            }
178
179            // Return the original error if all else fails
180            Err(e)
181        }
182    }
183}
184
185async fn try_github_api(client: &reqwest::Client, prerelease: bool) -> Result<GitHubRelease> {
186    let url = if prerelease {
187        "https://api.github.com/repos/loonghao/vx/releases"
188    } else {
189        "https://api.github.com/repos/loonghao/vx/releases/latest"
190    };
191
192    let response = client.get(url).send().await?;
193
194    // Check for rate limiting
195    if response.status() == 403 {
196        let remaining = response
197            .headers()
198            .get("x-ratelimit-remaining")
199            .and_then(|v| v.to_str().ok())
200            .unwrap_or("unknown");
201
202        return Err(anyhow!(
203            "GitHub API rate limit exceeded (remaining: {})",
204            remaining
205        ));
206    }
207
208    if !response.status().is_success() {
209        return Err(anyhow!(
210            "Failed to fetch release information: HTTP {}",
211            response.status()
212        ));
213    }
214
215    if prerelease {
216        let releases: Vec<GitHubRelease> = response.json().await?;
217        releases
218            .into_iter()
219            .next()
220            .ok_or_else(|| anyhow!("No releases found"))
221    } else {
222        Ok(response.json().await?)
223    }
224}
225
226fn find_platform_asset(assets: &[GitHubAsset]) -> Result<&GitHubAsset> {
227    let target_os = env::consts::OS;
228    let target_arch = env::consts::ARCH;
229
230    // Define platform-specific patterns
231    let patterns = match (target_os, target_arch) {
232        ("windows", "x86_64") => vec!["windows", "win64", "x86_64-pc-windows"],
233        ("windows", "x86") => vec!["windows", "win32", "i686-pc-windows"],
234        ("macos", "x86_64") => vec!["macos", "darwin", "x86_64-apple-darwin"],
235        ("macos", "aarch64") => vec!["macos", "darwin", "aarch64-apple-darwin"],
236        ("linux", "x86_64") => vec!["linux", "x86_64-unknown-linux"],
237        ("linux", "aarch64") => vec!["linux", "aarch64-unknown-linux"],
238        _ => {
239            return Err(anyhow!(
240                "Unsupported platform: {}-{}",
241                target_os,
242                target_arch
243            ))
244        }
245    };
246
247    // Find matching asset
248    for asset in assets {
249        let name_lower = asset.name.to_lowercase();
250        if patterns.iter().any(|pattern| name_lower.contains(pattern)) {
251            return Ok(asset);
252        }
253    }
254
255    Err(anyhow!(
256        "No compatible binary found for {}-{}. Available assets: {}",
257        target_os,
258        target_arch,
259        assets
260            .iter()
261            .map(|a| a.name.as_str())
262            .collect::<Vec<_>>()
263            .join(", ")
264    ))
265}
266
267async fn download_and_install(
268    client: &reqwest::Client,
269    asset: &GitHubAsset,
270    force: bool,
271) -> Result<()> {
272    // Get current executable path
273    let current_exe = env::current_exe()?;
274    let backup_path = current_exe.with_extension("bak");
275
276    // Try downloading with multi-channel fallback
277    let content = download_with_fallback(client, asset).await?;
278
279    // Create temporary file for the new binary
280    let temp_path = current_exe.with_extension("tmp");
281
282    // Handle different asset types
283    if asset.name.ends_with(".zip") {
284        extract_from_zip(&content, &temp_path)?;
285    } else if asset.name.ends_with(".tar.gz") {
286        extract_from_tar_gz(&content, &temp_path)?;
287    } else {
288        // Assume it's a raw binary
289        fs::write(&temp_path, content)?;
290    }
291
292    // Make executable on Unix systems
293    #[cfg(unix)]
294    {
295        use std::os::unix::fs::PermissionsExt;
296        let mut perms = fs::metadata(&temp_path)?.permissions();
297        perms.set_mode(0o755);
298        fs::set_permissions(&temp_path, perms)?;
299    }
300
301    // Backup current executable
302    if current_exe.exists() && !force {
303        if backup_path.exists() {
304            fs::remove_file(&backup_path)?;
305        }
306        fs::rename(&current_exe, &backup_path)?;
307        UI::detail(&format!(
308            "📦 Backed up current version to {}",
309            backup_path.display()
310        ));
311    }
312
313    // Replace current executable
314    fs::rename(&temp_path, &current_exe)?;
315
316    UI::detail(&format!(
317        "✅ Installed new version to {}",
318        current_exe.display()
319    ));
320
321    Ok(())
322}
323
324fn extract_from_zip(content: &[u8], output_path: &PathBuf) -> Result<()> {
325    use std::io::Cursor;
326    use zip::ZipArchive;
327
328    let cursor = Cursor::new(content);
329    let mut archive = ZipArchive::new(cursor)?;
330
331    // Find the vx executable in the archive
332    for i in 0..archive.len() {
333        let mut file = archive.by_index(i)?;
334        let name = file.name();
335
336        if name.ends_with("vx") || name.ends_with("vx.exe") {
337            let mut output = fs::File::create(output_path)?;
338            std::io::copy(&mut file, &mut output)?;
339            return Ok(());
340        }
341    }
342
343    Err(anyhow!("vx executable not found in ZIP archive"))
344}
345
346fn extract_from_tar_gz(content: &[u8], output_path: &PathBuf) -> Result<()> {
347    use flate2::read::GzDecoder;
348    use std::io::Cursor;
349    use tar::Archive;
350
351    let cursor = Cursor::new(content);
352    let gz = GzDecoder::new(cursor);
353    let mut archive = Archive::new(gz);
354
355    for entry in archive.entries()? {
356        let mut entry = entry?;
357        let path = entry.path()?;
358
359        if let Some(name) = path.file_name() {
360            if name == "vx" || name == "vx.exe" {
361                let mut output = fs::File::create(output_path)?;
362                std::io::copy(&mut entry, &mut output)?;
363                return Ok(());
364            }
365        }
366    }
367
368    Err(anyhow!("vx executable not found in TAR.GZ archive"))
369}
370
371async fn try_jsdelivr_api(client: &reqwest::Client, prerelease: bool) -> Result<GitHubRelease> {
372    let url = "https://data.jsdelivr.com/v1/package/gh/loonghao/vx";
373
374    let response = client.get(url).send().await?;
375
376    if !response.status().is_success() {
377        return Err(anyhow!(
378            "Failed to fetch from jsDelivr: {}",
379            response.status()
380        ));
381    }
382
383    let json: serde_json::Value = response.json().await?;
384
385    // Extract version information from jsDelivr response
386    let versions = json["versions"]
387        .as_array()
388        .ok_or_else(|| anyhow!("No versions found in jsDelivr response"))?;
389
390    let latest_version = if prerelease {
391        // For prerelease, get the first version (latest)
392        versions.first()
393    } else {
394        // For stable, find the first non-prerelease version
395        versions.iter().find(|v| {
396            if let Some(version_str) = v.as_str() {
397                !version_str.contains("-") // Simple check for prerelease
398            } else {
399                false
400            }
401        })
402    }
403    .and_then(|v| v.as_str())
404    .ok_or_else(|| anyhow!("No suitable version found"))?;
405
406    // Create CDN-based assets for the version
407    let assets = create_cdn_assets(latest_version);
408
409    // Create a minimal GitHubRelease structure from jsDelivr data
410    Ok(GitHubRelease {
411        tag_name: latest_version.to_string(),
412        name: format!("Release {}", latest_version),
413        body: "Release information retrieved from CDN".to_string(),
414        prerelease: latest_version.contains("-"),
415        assets,
416    })
417}
418
419fn create_cdn_assets(version: &str) -> Vec<GitHubAsset> {
420    let base_url = format!("https://cdn.jsdelivr.net/gh/loonghao/vx@v{}", version);
421
422    // Define platform-specific asset names based on our release naming convention
423    let asset_configs = vec![
424        ("vx-Windows-msvc-x86_64.zip", "windows", "x86_64"),
425        ("vx-Windows-msvc-arm64.zip", "windows", "aarch64"),
426        ("vx-Linux-musl-x86_64.tar.gz", "linux", "x86_64"),
427        ("vx-Linux-musl-arm64.tar.gz", "linux", "aarch64"),
428        ("vx-macOS-x86_64.tar.gz", "macos", "x86_64"),
429        ("vx-macOS-arm64.tar.gz", "macos", "aarch64"),
430    ];
431
432    asset_configs
433        .into_iter()
434        .map(|(name, _os, _arch)| GitHubAsset {
435            name: name.to_string(),
436            browser_download_url: format!("{}/{}", base_url, name),
437            size: 0, // Size unknown from CDN
438        })
439        .collect()
440}
441
442async fn download_with_fallback(client: &reqwest::Client, asset: &GitHubAsset) -> Result<Vec<u8>> {
443    // Extract version from the original URL for CDN fallback
444    let version = extract_version_from_url(&asset.browser_download_url);
445
446    // Define download channels in order of preference
447    // If original URL is from CDN (jsDelivr), it means we got version info from CDN
448    // so we should prefer CDN for downloads too
449    let channels = if asset.browser_download_url.contains("jsdelivr.net") {
450        // CDN-first strategy (when version came from CDN)
451        vec![
452            ("jsDelivr CDN", asset.browser_download_url.clone()),
453            (
454                "Fastly CDN",
455                format!(
456                    "https://fastly.jsdelivr.net/gh/loonghao/vx@v{}/{}",
457                    version, asset.name
458                ),
459            ),
460            (
461                "GitHub Releases",
462                format!(
463                    "https://github.com/loonghao/vx/releases/download/v{}/{}",
464                    version, asset.name
465                ),
466            ),
467        ]
468    } else {
469        // GitHub-first strategy (when version came from GitHub API)
470        vec![
471            ("GitHub Releases", asset.browser_download_url.clone()),
472            (
473                "jsDelivr CDN",
474                format!(
475                    "https://cdn.jsdelivr.net/gh/loonghao/vx@v{}/{}",
476                    version, asset.name
477                ),
478            ),
479            (
480                "Fastly CDN",
481                format!(
482                    "https://fastly.jsdelivr.net/gh/loonghao/vx@v{}/{}",
483                    version, asset.name
484                ),
485            ),
486        ]
487    };
488
489    for (channel_name, url) in channels {
490        UI::detail(&format!("🔄 Trying {}: {}", channel_name, url));
491
492        match client.get(&url).send().await {
493            Ok(response) => {
494                if response.status().is_success() {
495                    match response.bytes().await {
496                        Ok(content) => {
497                            if content.len() > 1024 {
498                                // Basic size validation
499                                UI::info(&format!(
500                                    "✅ Downloaded from {} ({} bytes)",
501                                    channel_name,
502                                    content.len()
503                                ));
504                                return Ok(content.to_vec());
505                            } else {
506                                UI::warn(&format!(
507                                    "⚠️ Downloaded file too small from {}, trying next channel...",
508                                    channel_name
509                                ));
510                            }
511                        }
512                        Err(e) => {
513                            UI::warn(&format!(
514                                "⚠️ Failed to read content from {}: {}",
515                                channel_name, e
516                            ));
517                        }
518                    }
519                } else {
520                    UI::warn(&format!(
521                        "⚠️ HTTP {} from {}, trying next channel...",
522                        response.status(),
523                        channel_name
524                    ));
525                }
526            }
527            Err(e) => {
528                UI::warn(&format!("⚠️ Failed to connect to {}: {}", channel_name, e));
529            }
530        }
531    }
532
533    Err(anyhow!("Failed to download from all channels"))
534}
535
536fn extract_version_from_url(url: &str) -> String {
537    // Extract version from GitHub release URL or CDN URL
538    // Look for patterns like "/v1.2.3/" or "@v1.2.3"
539    for part in url.split('/') {
540        if part.starts_with('v') && part.len() > 1 {
541            let version_part = &part[1..]; // Remove 'v' prefix
542            if version_part.chars().next().unwrap_or('a').is_ascii_digit() {
543                return version_part.to_string();
544            }
545        }
546        if part.starts_with("@v") && part.len() > 2 {
547            let version_part = &part[2..]; // Remove '@v' prefix
548            if version_part.chars().next().unwrap_or('a').is_ascii_digit() {
549                return version_part.to_string();
550            }
551        }
552    }
553
554    // Fallback to current version if extraction fails
555    env!("CARGO_PKG_VERSION").to_string()
556}