1use crate::{
2 models::common::enums::{Channel, Provider},
3 providers::provider_manager::ProviderManager,
4 services::{
5 packaging::{PackageChecker, PackageInstaller, PackageRemover, PackageUpgrader},
6 storage::package_storage::PackageStorage,
7 },
8 utils::static_paths::UpstreamPaths,
9};
10
11use anyhow::{Context, Result, anyhow};
12use futures_util::stream::{self, StreamExt};
13use indicatif::HumanBytes;
14use std::collections::BTreeMap;
15use std::sync::{Arc, Mutex};
16use tokio::time::{self, Duration};
17
/// Upper bound on update checks that run concurrently (`buffer_unordered` width).
const CHECK_CONCURRENCY: usize = 8;
/// Upper bound on package upgrades that run concurrently (`buffer_unordered` width).
const UPGRADE_CONCURRENCY: usize = 4;
20type ProgressEntry = (Channel, Provider, u64, u64);
21type ProgressState = Arc<Mutex<BTreeMap<String, ProgressEntry>>>;
22
/// Formats a message and passes it to an optional `FnMut(&str)` callback.
///
/// `$cb` is any expression whose `.as_mut()` yields `Option<&mut F>`, e.g. a
/// `&mut Option<F>`. When the callback is absent, the format arguments are not
/// evaluated at all.
macro_rules! message {
    ($cb:expr, $($arg:tt)*) => {{
        if let Some(callback) = ($cb).as_mut() {
            let text = format!($($arg)*);
            callback(&text);
        }
    }};
}
30
31pub struct UpgradeOperation<'a> {
32 upgrader: PackageUpgrader<'a>,
33 checker: PackageChecker<'a>,
34 package_storage: &'a mut PackageStorage,
35}
36
/// Outcome of checking one package for an available update.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and compare
/// check results (all payloads are plain `String`s).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpdateCheckStatus {
    /// The checker reported a newer version: `current` → `latest`.
    UpdateAvailable { current: String, latest: String },
    /// The checker reported no newer version; `current` is the stored version.
    UpToDate { current: String },
    /// The check itself failed; `error` is the error's display text.
    Failed { error: String },
    /// The requested name has no record in package storage.
    NotInstalled,
}
43
44pub struct UpdateCheckRow {
45 pub name: String,
46 pub channel: Option<Channel>,
47 pub provider: Option<Provider>,
48 pub status: UpdateCheckStatus,
49}
50
51impl<'a> UpgradeOperation<'a> {
52 fn truncate_error(value: &str, max: usize) -> String {
53 let char_count = value.chars().count();
54 if char_count <= max {
55 return value.to_string();
56 }
57
58 let mut out = String::new();
59 for ch in value.chars().take(max.saturating_sub(3)) {
60 out.push(ch);
61 }
62 out.push_str("...");
63 out
64 }
65
66 fn format_transfer(downloaded: u64, total: u64) -> String {
67 if total > 0 {
68 format!("{} / {}", HumanBytes(downloaded), HumanBytes(total))
69 } else if downloaded > 0 {
70 format!("{}", HumanBytes(downloaded))
71 } else {
72 "-".to_string()
73 }
74 }
75
76 fn render_progress_row(
77 name: &str,
78 channel: &Channel,
79 provider: &Provider,
80 downloaded: u64,
81 total: u64,
82 ) -> String {
83 format!(
84 " {:<28} {:<10} {:<3} {:<10} {}",
85 name,
86 channel.to_string().to_lowercase(),
87 "u",
88 provider.to_string(),
89 Self::format_transfer(downloaded, total)
90 )
91 }
92
93 async fn check_packages_parallel(
94 &self,
95 packages: Vec<crate::models::upstream::Package>,
96 ) -> Vec<(
97 crate::models::upstream::Package,
98 Result<Option<(String, String)>>,
99 )> {
100 let mut checked = stream::iter(packages.into_iter().enumerate().map(
101 |(idx, pkg)| async move {
102 let result = self.checker.check_one(&pkg).await;
103 (idx, pkg, result)
104 },
105 ))
106 .buffer_unordered(CHECK_CONCURRENCY)
107 .collect::<Vec<_>>()
108 .await;
109
110 checked.sort_by_key(|(idx, _, _)| *idx);
111 checked
112 .into_iter()
113 .map(|(_, pkg, result)| (pkg, result))
114 .collect()
115 }
116
117 async fn check_installed_packages_detailed(
118 &self,
119 packages: Vec<crate::models::upstream::Package>,
120 ) -> Vec<UpdateCheckRow> {
121 self.check_packages_parallel(packages)
122 .await
123 .into_iter()
124 .map(|(pkg, result)| match result {
125 Ok(Some((current, latest))) => UpdateCheckRow {
126 name: pkg.name,
127 channel: Some(pkg.channel),
128 provider: Some(pkg.provider),
129 status: UpdateCheckStatus::UpdateAvailable { current, latest },
130 },
131 Ok(None) => UpdateCheckRow {
132 name: pkg.name,
133 channel: Some(pkg.channel),
134 provider: Some(pkg.provider),
135 status: UpdateCheckStatus::UpToDate {
136 current: pkg.version.to_string(),
137 },
138 },
139 Err(error) => UpdateCheckRow {
140 name: pkg.name,
141 channel: Some(pkg.channel),
142 provider: Some(pkg.provider),
143 status: UpdateCheckStatus::Failed {
144 error: error.to_string(),
145 },
146 },
147 })
148 .collect()
149 }
150
151 pub fn new(
152 provider_manager: &'a ProviderManager,
153 package_storage: &'a mut PackageStorage,
154 paths: &'a UpstreamPaths,
155 ) -> Result<Self> {
156 let installer = PackageInstaller::new(provider_manager, paths)?;
157 let remover = PackageRemover::new(paths);
158
159 let upgrader = PackageUpgrader::new(provider_manager, installer, remover, paths);
160
161 let checker = PackageChecker::new(provider_manager);
162
163 Ok(Self {
164 upgrader,
165 checker,
166 package_storage,
167 })
168 }
169
170 pub async fn upgrade_all<F, G, H>(
171 &mut self,
172 force_option: &bool,
173 ignore_checksums: bool,
174 download_progress: &mut Option<F>,
175 overall_progress: &mut Option<G>,
176 message_callback: &mut Option<H>,
177 ) -> Result<()>
178 where
179 F: FnMut(u64, u64),
180 G: FnMut(u32, u32),
181 H: FnMut(&str),
182 {
183 let names: Vec<String> = self
184 .package_storage
185 .get_all_packages()
186 .iter()
187 .map(|p| p.name.clone())
188 .collect();
189
190 self.upgrade_bulk(
191 &names,
192 force_option,
193 ignore_checksums,
194 download_progress,
195 overall_progress,
196 message_callback,
197 )
198 .await
199 }
200
201 pub async fn upgrade_bulk<F, G, H>(
202 &mut self,
203 names: &[String],
204 force_option: &bool,
205 ignore_checksums: bool,
206 download_progress: &mut Option<F>,
207 overall_progress: &mut Option<G>,
208 message_callback: &mut Option<H>,
209 ) -> Result<()>
210 where
211 F: FnMut(u64, u64),
212 G: FnMut(u32, u32),
213 H: FnMut(&str),
214 {
215 let total = names.len() as u32;
216 let mut completed = 0;
217 let mut failures = 0;
218 let mut upgraded = 0;
219 let force = *force_option;
220 let upgrader = &self.upgrader;
221 let progress_state: ProgressState = Arc::new(Mutex::new(BTreeMap::new()));
222 let mut last_progress_render: BTreeMap<String, String> = BTreeMap::new();
223
224 let packages: Vec<_> = names
225 .iter()
226 .map(|name| {
227 self.package_storage
228 .get_package_by_name(name)
229 .ok_or_else(|| anyhow!("Package '{}' is not installed", name))
230 .cloned()
231 })
232 .collect::<Result<Vec<_>>>()?;
233
234 let mut updated_packages = Vec::new();
235 let mut pending = stream::iter(packages.into_iter().map(|package| {
236 let state_ref = Arc::clone(&progress_state);
237 async move {
238 let name = package.name.clone();
239 let channel = package.channel.clone();
240 let provider = package.provider.clone();
241
242 if let Ok(mut state) = state_ref.lock() {
243 state.insert(name.clone(), (channel.clone(), provider.clone(), 0, 0));
244 }
245
246 let mut downloaded: u64 = 0;
247 let mut bytes_total: u64 = 0;
248 let mut download_cb = Some(|d: u64, t: u64| {
249 downloaded = d;
250 bytes_total = t;
251 if let Ok(mut state) = state_ref.lock() {
252 state.insert(name.clone(), (channel.clone(), provider.clone(), d, t));
253 }
254 });
255 let mut no_messages: Option<fn(&str)> = None;
256
257 let result = upgrader
258 .upgrade(
259 &package,
260 force,
261 ignore_checksums,
262 &mut download_cb,
263 &mut no_messages,
264 )
265 .await
266 .context(format!("Failed to upgrade package '{}'", name));
267 (name, channel, provider, downloaded, bytes_total, result)
268 }
269 }))
270 .buffer_unordered(UPGRADE_CONCURRENCY);
271
272 let mut ticker = time::interval(Duration::from_millis(350));
273 ticker.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
274
275 while completed < total {
276 tokio::select! {
277 maybe_item = pending.next() => {
278 let Some((name, channel, provider, downloaded, bytes_total, result)) = maybe_item else {
279 break;
280 };
281
282 if let Ok(mut state) = progress_state.lock() {
283 state.remove(&name);
284 }
285 last_progress_render.remove(&name);
286 message!(message_callback, "__UPGRADE_PROGRESS_DONE__ {}", name);
287
288 let transfer = Self::format_transfer(downloaded, bytes_total);
289 match result {
290 Ok(Some(updated)) => {
291 updated_packages.push(updated);
292 message!(
293 message_callback,
294 "[✓] {:<28} {:<10} {:<3} {:<10} {}",
295 name,
296 channel.to_string().to_lowercase(),
297 "u",
298 provider.to_string(),
299 transfer
300 );
301 upgraded += 1;
302 }
303 Ok(None) => {
304 message!(
305 message_callback,
306 "[=] {:<28} {:<10} {:<3} {:<10} {}",
307 name,
308 channel.to_string().to_lowercase(),
309 "-",
310 provider.to_string(),
311 transfer
312 );
313 }
314 Err(e) => {
315 message!(
316 message_callback,
317 "[!] {:<28} {:<10} {:<3} {:<10} {}",
318 name,
319 channel.to_string().to_lowercase(),
320 "!",
321 provider.to_string(),
322 Self::truncate_error(&e.to_string(), 36)
323 );
324 failures += 1;
325 }
326 }
327
328 completed += 1;
329 if let Some(cb) = overall_progress.as_mut() {
330 cb(completed, total);
331 }
332 }
333 _ = ticker.tick() => {
334 if let Ok(state) = progress_state.lock() {
335 for (name, (channel, provider, downloaded, total_bytes)) in state.iter() {
336 let row = Self::render_progress_row(
337 name,
338 channel,
339 provider,
340 *downloaded,
341 *total_bytes
342 );
343 let changed = last_progress_render
344 .get(name)
345 .map(|prev| prev != &row)
346 .unwrap_or(true);
347 if changed {
348 message!(message_callback, "__UPGRADE_PROGRESS_ROW__ {}\t{}", name, row);
349 last_progress_render.insert(name.clone(), row);
350 }
351 }
352 }
353 }
354 }
355 }
356
357 message!(message_callback, "__UPGRADE_PROGRESS_CLEAR__");
358
359 for updated in updated_packages {
361 self.package_storage.add_or_update_package(updated)?;
362 }
363
364 let _ = download_progress;
366
367 self.package_storage
368 .save_packages()
369 .context("Failed to save updated package information")?;
370
371 message!(
372 message_callback,
373 "Completed: {} upgraded, {} up-to-date, {} failed",
374 upgraded,
375 total - upgraded - failures,
376 failures
377 );
378
379 Ok(())
380 }
381
382 pub async fn upgrade_single<F, H>(
383 &mut self,
384 package_name: &str,
385 force_option: &bool,
386 ignore_checksums: bool,
387 download_progress: &mut Option<F>,
388 message_callback: &mut Option<H>,
389 ) -> Result<bool>
390 where
391 F: FnMut(u64, u64),
392 H: FnMut(&str),
393 {
394 let package = self
395 .package_storage
396 .get_package_by_name(package_name)
397 .ok_or_else(|| anyhow!("Package '{}' is not installed", package_name))?
398 .clone();
399
400 let upgraded = self
401 .upgrader
402 .upgrade(
403 &package,
404 *force_option,
405 ignore_checksums,
406 download_progress,
407 message_callback,
408 )
409 .await?;
410
411 if let Some(updated) = upgraded {
412 self.package_storage.add_or_update_package(updated)?;
413 self.package_storage.save_packages()?;
414 Ok(true)
415 } else {
416 Ok(false)
417 }
418 }
419
420 pub async fn check_all_detailed(&self) -> Vec<UpdateCheckRow> {
421 let packages = self.package_storage.get_all_packages().to_vec();
422 self.check_installed_packages_detailed(packages).await
423 }
424
425 pub async fn check_all_machine_readable(&self) -> Vec<(String, String, String)> {
426 let rows = self.check_all_detailed().await;
427 rows.into_iter()
428 .filter_map(|row| match row.status {
429 UpdateCheckStatus::UpdateAvailable { current, latest } => {
430 Some((row.name, current, latest))
431 }
432 _ => None,
433 })
434 .collect()
435 }
436
437 pub async fn check_selected_detailed(&self, package_names: &[String]) -> Vec<UpdateCheckRow> {
438 let mut rows: Vec<Option<UpdateCheckRow>> =
439 (0..package_names.len()).map(|_| None).collect();
440 let mut selected_packages = Vec::new();
441 let mut selected_indices = Vec::new();
442
443 for (idx, name) in package_names.iter().enumerate() {
444 match self.package_storage.get_package_by_name(name) {
445 Some(package) => {
446 selected_packages.push(package.clone());
447 selected_indices.push(idx);
448 }
449 None => {
450 rows[idx] = Some(UpdateCheckRow {
451 name: name.clone(),
452 channel: None,
453 provider: None,
454 status: UpdateCheckStatus::NotInstalled,
455 })
456 }
457 }
458 }
459
460 let checked_rows = self
461 .check_installed_packages_detailed(selected_packages)
462 .await;
463 for (row_idx, checked_row) in selected_indices.into_iter().zip(checked_rows) {
464 rows[row_idx] = Some(checked_row);
465 }
466
467 rows.into_iter().flatten().collect()
468 }
469
470 pub async fn check_selected_machine_readable(
471 &self,
472 package_names: &[String],
473 ) -> Vec<(String, String, String)> {
474 let rows = self.check_selected_detailed(package_names).await;
475 rows.into_iter()
476 .filter_map(|row| match row.status {
477 UpdateCheckStatus::UpdateAvailable { current, latest } => {
478 Some((row.name, current, latest))
479 }
480 _ => None,
481 })
482 .collect()
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::UpgradeOperation;
    use crate::models::common::enums::{Channel, Provider};

    #[test]
    fn truncate_error_adds_ellipsis_when_limit_exceeded() {
        let limit = 12;
        let shortened =
            UpgradeOperation::truncate_error("this is a fairly long error string", limit);

        assert!(shortened.ends_with("..."));
        assert!(shortened.chars().count() <= limit);
    }

    #[test]
    fn format_transfer_handles_known_unknown_and_empty_sizes() {
        let nothing_yet = UpgradeOperation::format_transfer(0, 0);
        let unknown_total = UpgradeOperation::format_transfer(42, 0);
        let known_total = UpgradeOperation::format_transfer(1024, 2048);

        assert_eq!(nothing_yet, "-");
        assert!(unknown_total.contains("42"));
        assert!(known_total.contains('/'));
    }

    #[test]
    fn render_progress_row_includes_package_channel_provider_and_transfer() {
        let row = UpgradeOperation::render_progress_row(
            "ripgrep",
            &Channel::Stable,
            &Provider::Github,
            128,
            256,
        );

        for expected in ["ripgrep", "stable", "github"] {
            assert!(row.contains(expected), "row {row:?} is missing {expected:?}");
        }
        assert!(row.contains('/'));
    }
}