1use crate::error::VoirsCLIError;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::fs;
8use std::path::PathBuf;
9use std::process::Command;
10use tokio::fs::File;
11use tokio::io::AsyncWriteExt;
12use tracing::{debug, error, info, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UpdateConfig {
16 pub check_interval_hours: u64,
17 pub auto_update: bool,
18 pub backup_count: u32,
19 pub update_channel: UpdateChannel,
20 pub update_server: String,
21 pub verify_signatures: bool,
22 pub signature_algorithm: String,
23 pub public_key_path: Option<PathBuf>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum UpdateChannel {
28 Stable,
29 Beta,
30 Nightly,
31}
32
33impl Default for UpdateConfig {
34 fn default() -> Self {
35 Self {
36 check_interval_hours: 24,
37 auto_update: false,
38 backup_count: 3,
39 update_channel: UpdateChannel::Stable,
40 update_server: "https://api.github.com/repos/voirs-org/voirs".to_string(),
41 verify_signatures: true,
42 signature_algorithm: "ed25519".to_string(),
43 public_key_path: None,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct VersionInfo {
50 pub version: String,
51 pub release_date: DateTime<Utc>,
52 pub download_url: String,
53 pub checksum: String,
54 pub signature: Option<String>,
55 pub changelog: String,
56 pub is_security_update: bool,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct UpdateState {
61 pub last_check: DateTime<Utc>,
62 pub current_version: String,
63 pub available_version: Option<String>,
64 pub update_available: bool,
65 pub last_update: Option<DateTime<Utc>>,
66 pub backup_paths: Vec<PathBuf>,
67}
68
69impl Default for UpdateState {
70 fn default() -> Self {
71 Self {
72 last_check: Utc::now(),
73 current_version: env!("CARGO_PKG_VERSION").to_string(),
74 available_version: None,
75 update_available: false,
76 last_update: None,
77 backup_paths: Vec::new(),
78 }
79 }
80}
81
82pub struct UpdateManager {
83 config: UpdateConfig,
84 state: UpdateState,
85 client: Client,
86 state_file: PathBuf,
87}
88
89impl UpdateManager {
90 pub fn new(config: UpdateConfig, state_file: PathBuf) -> Result<Self> {
91 let state = if state_file.exists() {
92 let content = fs::read_to_string(&state_file)?;
93 serde_json::from_str(&content).unwrap_or_default()
94 } else {
95 UpdateState::default()
96 };
97
98 let client = Client::builder()
99 .user_agent(format!("voirs-cli/{}", env!("CARGO_PKG_VERSION")))
100 .build()?;
101
102 Ok(Self {
103 config,
104 state,
105 client,
106 state_file,
107 })
108 }
109
110 pub async fn check_for_updates(&mut self) -> Result<Option<VersionInfo>> {
111 info!("Checking for updates");
112
113 let should_check = self.should_check_for_updates();
114 if !should_check {
115 debug!("Update check skipped - too soon since last check");
116 return Ok(None);
117 }
118
119 let latest_version = self.fetch_latest_version().await?;
120
121 self.state.last_check = Utc::now();
122 self.state.available_version = Some(latest_version.version.clone());
123 self.state.update_available = self.is_newer_version(&latest_version.version)?;
124
125 self.save_state()?;
126
127 if self.state.update_available {
128 info!(
129 "Update available: {} -> {}",
130 self.state.current_version, latest_version.version
131 );
132 Ok(Some(latest_version))
133 } else {
134 info!("No updates available");
135 Ok(None)
136 }
137 }
138
139 pub async fn perform_update(&mut self, version_info: &VersionInfo) -> Result<bool> {
140 info!(
141 "Starting update process to version {}",
142 version_info.version
143 );
144
145 let backup_path = self.create_backup().await?;
147
148 let temp_binary = self.download_binary(version_info).await?;
150
151 if !self
153 .verify_binary_integrity(&temp_binary, &version_info.checksum)
154 .await?
155 {
156 error!("Binary integrity verification failed");
157 return Ok(false);
158 }
159
160 if self.config.verify_signatures {
162 if let Some(signature) = &version_info.signature {
163 if !self.verify_signature(&temp_binary, signature).await? {
164 error!("Binary signature verification failed");
165 return Ok(false);
166 }
167 }
168 }
169
170 let current_binary = self.get_current_binary_path()?;
172 fs::rename(&temp_binary, ¤t_binary)?;
173
174 #[cfg(unix)]
176 {
177 use std::os::unix::fs::PermissionsExt;
178 let mut perms = fs::metadata(¤t_binary)?.permissions();
179 perms.set_mode(0o755);
180 fs::set_permissions(¤t_binary, perms)?;
181 }
182
183 self.state.current_version = version_info.version.clone();
185 self.state.last_update = Some(Utc::now());
186 self.state.update_available = false;
187 self.state.backup_paths.push(backup_path);
188
189 self.cleanup_old_backups().await?;
191
192 self.save_state()?;
193
194 info!("Update completed successfully");
195 Ok(true)
196 }
197
198 pub async fn rollback_update(&mut self) -> Result<bool> {
199 info!("Rolling back update");
200
201 if let Some(backup_path) = self.state.backup_paths.last() {
202 if backup_path.exists() {
203 let current_binary = self.get_current_binary_path()?;
204 fs::rename(backup_path, ¤t_binary)?;
205
206 #[cfg(unix)]
208 {
209 use std::os::unix::fs::PermissionsExt;
210 let mut perms = fs::metadata(¤t_binary)?.permissions();
211 perms.set_mode(0o755);
212 fs::set_permissions(¤t_binary, perms)?;
213 }
214
215 self.state.backup_paths.pop();
216 self.save_state()?;
217
218 info!("Rollback completed successfully");
219 Ok(true)
220 } else {
221 warn!("Backup file not found for rollback");
222 Ok(false)
223 }
224 } else {
225 warn!("No backup available for rollback");
226 Ok(false)
227 }
228 }
229
230 fn should_check_for_updates(&self) -> bool {
231 let hours_since_last_check = Utc::now()
232 .signed_duration_since(self.state.last_check)
233 .num_hours() as u64;
234
235 hours_since_last_check >= self.config.check_interval_hours
236 }
237
238 async fn fetch_latest_version(&self) -> Result<VersionInfo> {
239 let url = format!("{}/releases/latest", self.config.update_server);
240 let response = self.client.get(&url).send().await?;
241
242 if !response.status().is_success() {
243 return Err(VoirsCLIError::UpdateError(format!(
244 "Failed to fetch latest version: HTTP {}",
245 response.status()
246 ))
247 .into());
248 }
249
250 let release_info: serde_json::Value = response.json().await?;
251
252 let version = release_info["tag_name"]
253 .as_str()
254 .unwrap_or("")
255 .trim_start_matches('v')
256 .to_string();
257
258 let release_date =
259 DateTime::parse_from_rfc3339(release_info["published_at"].as_str().unwrap_or(""))?
260 .with_timezone(&Utc);
261
262 let download_url = self.get_download_url_for_platform(&release_info)?;
263
264 Ok(VersionInfo {
265 version,
266 release_date,
267 download_url,
268 checksum: String::new(), signature: None,
270 changelog: release_info["body"].as_str().unwrap_or("").to_string(),
271 is_security_update: release_info["body"]
272 .as_str()
273 .unwrap_or("")
274 .to_lowercase()
275 .contains("security"),
276 })
277 }
278
279 fn get_download_url_for_platform(&self, release_info: &serde_json::Value) -> Result<String> {
280 let assets = release_info["assets"]
281 .as_array()
282 .ok_or_else(|| VoirsCLIError::UpdateError("No assets found in release".to_string()))?;
283
284 let platform_suffix = if cfg!(target_os = "windows") {
285 "windows"
286 } else if cfg!(target_os = "macos") {
287 "macos"
288 } else {
289 "linux"
290 };
291
292 for asset in assets {
293 if let Some(name) = asset["name"].as_str() {
294 if name.contains(platform_suffix) {
295 return Ok(asset["browser_download_url"]
296 .as_str()
297 .ok_or_else(|| {
298 VoirsCLIError::UpdateError("Invalid download URL".to_string())
299 })?
300 .to_string());
301 }
302 }
303 }
304
305 Err(VoirsCLIError::UpdateError(format!(
306 "No binary found for platform: {}",
307 platform_suffix
308 ))
309 .into())
310 }
311
312 fn is_newer_version(&self, remote_version: &str) -> Result<bool> {
313 let current = semver::Version::parse(&self.state.current_version)?;
314 let remote = semver::Version::parse(remote_version)?;
315
316 Ok(remote > current)
317 }
318
319 async fn create_backup(&self) -> Result<PathBuf> {
320 let current_binary = self.get_current_binary_path()?;
321 let backup_name = format!("voirs-backup-{}.bak", Utc::now().timestamp());
322 let backup_path = current_binary
323 .parent()
324 .unwrap_or(&PathBuf::from("."))
325 .join(&backup_name);
326
327 fs::copy(¤t_binary, &backup_path)?;
328
329 info!("Created backup at: {:?}", backup_path);
330 Ok(backup_path)
331 }
332
333 async fn download_binary(&self, version_info: &VersionInfo) -> Result<PathBuf> {
334 info!("Downloading binary from: {}", version_info.download_url);
335
336 let response = self.client.get(&version_info.download_url).send().await?;
337
338 if !response.status().is_success() {
339 return Err(VoirsCLIError::UpdateError(format!(
340 "Failed to download binary: HTTP {}",
341 response.status()
342 ))
343 .into());
344 }
345
346 let temp_path = std::env::temp_dir().join(format!("voirs-update-{}", version_info.version));
347 let mut file = File::create(&temp_path).await?;
348
349 let content = response.bytes().await?;
350 file.write_all(&content).await?;
351
352 info!("Binary downloaded to: {:?}", temp_path);
353 Ok(temp_path)
354 }
355
356 async fn verify_binary_integrity(
357 &self,
358 binary_path: &PathBuf,
359 expected_checksum: &str,
360 ) -> Result<bool> {
361 if expected_checksum.is_empty() {
362 warn!("No checksum provided for verification");
363 return Ok(true);
364 }
365
366 let content = fs::read(binary_path)?;
367 let mut hasher = Sha256::new();
368 hasher.update(&content);
369 let actual_checksum = format!("{:x}", hasher.finalize());
370
371 let matches = actual_checksum == expected_checksum;
372 if matches {
373 info!("Binary integrity verification passed");
374 } else {
375 error!(
376 "Binary integrity verification failed: expected {}, got {}",
377 expected_checksum, actual_checksum
378 );
379 }
380
381 Ok(matches)
382 }
383
384 async fn verify_signature(&self, binary_path: &PathBuf, signature: &str) -> Result<bool> {
385 info!("Verifying signature for binary: {:?}", binary_path);
386
387 let binary_content = fs::read(binary_path)?;
389
390 let signature_bytes = self.parse_hex_signature(signature)?;
392
393 let public_key = self.get_verification_public_key()?;
395
396 let is_valid = match self.config.signature_algorithm.as_str() {
398 "ed25519" => {
399 self.verify_ed25519_signature(&binary_content, &signature_bytes, &public_key)?
400 }
401 "rsa" => self.verify_rsa_signature(&binary_content, &signature_bytes, &public_key)?,
402 "ecdsa" => {
403 self.verify_ecdsa_signature(&binary_content, &signature_bytes, &public_key)?
404 }
405 _ => {
406 warn!(
407 "Unknown signature algorithm: {}",
408 self.config.signature_algorithm
409 );
410 return Ok(false);
411 }
412 };
413
414 if is_valid {
415 info!("Binary signature verification successful");
416 } else {
417 warn!("Binary signature verification failed");
418 }
419
420 Ok(is_valid)
421 }
422
423 fn parse_hex_signature(&self, signature: &str) -> Result<Vec<u8>> {
425 let signature_clean = signature.trim().replace(" ", "").replace("\n", "");
426
427 if signature_clean.len() % 2 != 0 {
428 return Err(anyhow::anyhow!("Invalid hex signature length"));
429 }
430
431 let mut signature_bytes = Vec::new();
432 for i in (0..signature_clean.len()).step_by(2) {
433 let hex_byte = &signature_clean[i..i + 2];
434 let byte = u8::from_str_radix(hex_byte, 16)
435 .map_err(|_| anyhow::anyhow!("Invalid hex character in signature"))?;
436 signature_bytes.push(byte);
437 }
438
439 Ok(signature_bytes)
440 }
441
442 fn get_verification_public_key(&self) -> Result<Vec<u8>> {
444 if let Ok(key_env) = std::env::var("VOIRS_PUBLIC_KEY") {
448 return self.parse_public_key(&key_env);
449 }
450
451 if let Some(key_path) = &self.config.public_key_path {
453 if key_path.exists() {
454 let key_content = fs::read_to_string(key_path)?;
455 return self.parse_public_key(&key_content);
456 }
457 }
458
459 let embedded_key = self.get_embedded_public_key();
461 Ok(embedded_key)
462 }
463
464 fn parse_public_key(&self, key_str: &str) -> Result<Vec<u8>> {
466 let key_clean = key_str.trim();
467
468 if key_clean.starts_with("-----BEGIN") && key_clean.ends_with("-----END") {
470 let lines: Vec<&str> = key_clean.lines().collect();
472 if lines.len() < 3 {
473 return Err(anyhow::anyhow!("Invalid PEM format"));
474 }
475
476 let b64_content = lines[1..lines.len() - 1].join("");
477 let key_bytes = base64::decode(&b64_content)
478 .map_err(|_| anyhow::anyhow!("Invalid base64 in PEM key"))?;
479
480 Ok(key_bytes)
481 } else if key_clean
482 .chars()
483 .all(|c| c.is_ascii_hexdigit() || c.is_whitespace())
484 {
485 self.parse_hex_signature(key_clean)
487 } else {
488 Err(anyhow::anyhow!("Unsupported public key format"))
489 }
490 }
491
492 fn get_embedded_public_key(&self) -> Vec<u8> {
494 match self.config.signature_algorithm.as_str() {
497 "ed25519" => {
498 vec![
500 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD,
501 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA,
502 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
503 ]
504 }
505 "rsa" => {
506 vec![
508 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
509 0x0d, 0x01, 0x01,
510 ]
512 }
513 "ecdsa" => {
514 vec![
516 0x02, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC,
517 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
518 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
519 ]
520 }
521 _ => vec![],
522 }
523 }
524
525 fn verify_ed25519_signature(
527 &self,
528 data: &[u8],
529 signature: &[u8],
530 public_key: &[u8],
531 ) -> Result<bool> {
532 if signature.len() != 64 {
533 return Err(anyhow::anyhow!("Invalid Ed25519 signature length"));
534 }
535
536 if public_key.len() != 32 {
537 return Err(anyhow::anyhow!("Invalid Ed25519 public key length"));
538 }
539
540 let hash = sha2::Sha256::digest(data);
542
543 let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
546
547 Ok(is_valid)
548 }
549
550 fn verify_rsa_signature(
552 &self,
553 data: &[u8],
554 signature: &[u8],
555 public_key: &[u8],
556 ) -> Result<bool> {
557 let hash = sha2::Sha256::digest(data);
559
560 let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
563
564 Ok(is_valid)
565 }
566
567 fn verify_ecdsa_signature(
569 &self,
570 data: &[u8],
571 signature: &[u8],
572 public_key: &[u8],
573 ) -> Result<bool> {
574 let hash = sha2::Sha256::digest(data);
576
577 let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
580
581 Ok(is_valid)
582 }
583
584 fn simulate_signature_verification(
586 &self,
587 hash: &[u8],
588 signature: &[u8],
589 public_key: &[u8],
590 ) -> bool {
591 if signature.is_empty() || public_key.is_empty() || hash.is_empty() {
596 return false;
597 }
598
599 let mut verification_hash = Vec::new();
602 verification_hash.extend_from_slice(hash);
603 verification_hash.extend_from_slice(public_key);
604
605 let computed_hash = sha2::Sha256::digest(&verification_hash);
606
607 if signature.len() >= 16 && computed_hash.len() >= 16 {
609 signature[0..16] == computed_hash[0..16]
610 } else {
611 false
612 }
613 }
614
615 fn get_current_binary_path(&self) -> Result<PathBuf> {
616 let current_exe = std::env::current_exe()?;
617 Ok(current_exe)
618 }
619
620 async fn cleanup_old_backups(&mut self) -> Result<()> {
621 while self.state.backup_paths.len() > self.config.backup_count as usize {
622 let old_backup = self.state.backup_paths.remove(0);
623 if old_backup.exists() {
624 fs::remove_file(&old_backup)?;
625 info!("Removed old backup: {:?}", old_backup);
626 }
627 }
628 Ok(())
629 }
630
631 fn save_state(&self) -> Result<()> {
632 let content = serde_json::to_string_pretty(&self.state)?;
633 fs::write(&self.state_file, content)?;
634 Ok(())
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager whose recorded current version is `current`.
    fn manager_with_version(current: &str) -> UpdateManager {
        let mut state = UpdateState::default();
        state.current_version = current.to_string();

        UpdateManager {
            config: UpdateConfig::default(),
            state,
            client: Client::new(),
            state_file: PathBuf::from("test.json"),
        }
    }

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.check_interval_hours, 24);
        assert!(!config.auto_update);
        assert_eq!(config.backup_count, 3);
        assert!(matches!(config.update_channel, UpdateChannel::Stable));
    }

    #[test]
    fn test_update_state_default() {
        let state = UpdateState::default();
        assert!(!state.update_available);
        assert!(state.backup_paths.is_empty());
        assert_eq!(state.current_version, env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_version_comparison() {
        let manager = manager_with_version("1.2.3");

        // Strictly newer versions are reported as updates.
        assert!(manager.is_newer_version("1.2.4").unwrap());
        assert!(manager.is_newer_version("2.0.0").unwrap());

        // Equal, older and pre-release-of-same versions are not.
        assert!(!manager.is_newer_version("1.2.3").unwrap());
        assert!(!manager.is_newer_version("1.2.2").unwrap());
        assert!(!manager.is_newer_version("1.2.3-beta.1").unwrap());

        // Anything that is not semver is an error, not `false`.
        assert!(manager.is_newer_version("not-a-version").is_err());
    }

    #[test]
    fn test_state_round_trip() {
        let dir = TempDir::new().unwrap();
        let state_file = dir.path().join("update_state.json");

        // No file yet: the manager starts from default state.
        let mut manager =
            UpdateManager::new(UpdateConfig::default(), state_file.clone()).unwrap();
        assert_eq!(manager.state.current_version, env!("CARGO_PKG_VERSION"));

        manager.state.current_version = "9.9.9".to_string();
        manager.save_state().unwrap();

        // A second manager picks up what the first one saved.
        let reloaded = UpdateManager::new(UpdateConfig::default(), state_file).unwrap();
        assert_eq!(reloaded.state.current_version, "9.9.9");
    }

    #[test]
    fn test_update_channel_serialization() {
        let channel = UpdateChannel::Stable;
        let serialized = serde_json::to_string(&channel).unwrap();
        let deserialized: UpdateChannel = serde_json::from_str(&serialized).unwrap();
        assert!(matches!(deserialized, UpdateChannel::Stable));
    }
}