Skip to main content

upstream_rs/application/operations/
upgrade_operation.rs

1use crate::{
2    models::common::enums::{Channel, Provider},
3    providers::provider_manager::ProviderManager,
4    services::{
5        packaging::{PackageChecker, PackageInstaller, PackageRemover, PackageUpgrader},
6        storage::package_storage::PackageStorage,
7    },
8    utils::static_paths::UpstreamPaths,
9};
10
11use anyhow::{Context, Result, anyhow};
12use futures_util::stream::{self, StreamExt};
13use indicatif::HumanBytes;
14use std::collections::BTreeMap;
15use std::sync::{Arc, Mutex};
16use tokio::time::{self, Duration};
17
/// Upper bound on package upgrades running at the same time in bulk mode.
const UPGRADE_CONCURRENCY: usize = 4;
/// Upper bound on update checks running at the same time.
const CHECK_CONCURRENCY: usize = 8;
20type ProgressEntry = (Channel, Provider, u64, u64);
21type ProgressState = Arc<Mutex<BTreeMap<String, ProgressEntry>>>;
22
/// Formats a message and hands it to an optional `FnMut(&str)` callback.
///
/// `$cb` must support `.as_mut()` yielding `Option<&mut impl FnMut(&str)>`
/// (e.g. an `Option<F>` or a `&mut Option<F>`). The format arguments are only
/// evaluated when a callback is present.
macro_rules! message {
    ($cb:expr, $($fmt:tt)*) => {{
        if let ::std::option::Option::Some(callback) = $cb.as_mut() {
            let text = ::std::format!($($fmt)*);
            callback(text.as_str());
        }
    }};
}
30
31pub struct UpgradeOperation<'a> {
32    upgrader: PackageUpgrader<'a>,
33    checker: PackageChecker<'a>,
34    package_storage: &'a mut PackageStorage,
35}
36
/// Outcome of checking one package for a newer version.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so reports can be printed, copied and
/// compared (e.g. in tests) — every payload is a plain `String`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpdateCheckStatus {
    /// A newer version was reported; `current` and `latest` are the version strings
    /// returned by the update check.
    UpdateAvailable { current: String, latest: String },
    /// No newer version was reported; `current` is the stored package version.
    UpToDate { current: String },
    /// The check itself failed; `error` is the error's display message.
    Failed { error: String },
    /// The requested package is not installed, so it was not checked.
    NotInstalled,
}
43
44pub struct UpdateCheckRow {
45    pub name: String,
46    pub channel: Option<Channel>,
47    pub provider: Option<Provider>,
48    pub status: UpdateCheckStatus,
49}
50
51impl<'a> UpgradeOperation<'a> {
52    fn truncate_error(value: &str, max: usize) -> String {
53        let char_count = value.chars().count();
54        if char_count <= max {
55            return value.to_string();
56        }
57
58        let mut out = String::new();
59        for ch in value.chars().take(max.saturating_sub(3)) {
60            out.push(ch);
61        }
62        out.push_str("...");
63        out
64    }
65
66    fn format_transfer(downloaded: u64, total: u64) -> String {
67        if total > 0 {
68            format!("{} / {}", HumanBytes(downloaded), HumanBytes(total))
69        } else if downloaded > 0 {
70            format!("{}", HumanBytes(downloaded))
71        } else {
72            "-".to_string()
73        }
74    }
75
76    fn render_progress_row(
77        name: &str,
78        channel: &Channel,
79        provider: &Provider,
80        downloaded: u64,
81        total: u64,
82    ) -> String {
83        format!(
84            " {:<28} {:<10} {:<3} {:<10} {}",
85            name,
86            channel.to_string().to_lowercase(),
87            "u",
88            provider.to_string(),
89            Self::format_transfer(downloaded, total)
90        )
91    }
92
93    async fn check_packages_parallel(
94        &self,
95        packages: Vec<crate::models::upstream::Package>,
96    ) -> Vec<(
97        crate::models::upstream::Package,
98        Result<Option<(String, String)>>,
99    )> {
100        let mut checked = stream::iter(packages.into_iter().enumerate().map(
101            |(idx, pkg)| async move {
102                let result = self.checker.check_one(&pkg).await;
103                (idx, pkg, result)
104            },
105        ))
106        .buffer_unordered(CHECK_CONCURRENCY)
107        .collect::<Vec<_>>()
108        .await;
109
110        checked.sort_by_key(|(idx, _, _)| *idx);
111        checked
112            .into_iter()
113            .map(|(_, pkg, result)| (pkg, result))
114            .collect()
115    }
116
117    async fn check_installed_packages_detailed(
118        &self,
119        packages: Vec<crate::models::upstream::Package>,
120    ) -> Vec<UpdateCheckRow> {
121        self.check_packages_parallel(packages)
122            .await
123            .into_iter()
124            .map(|(pkg, result)| match result {
125                Ok(Some((current, latest))) => UpdateCheckRow {
126                    name: pkg.name,
127                    channel: Some(pkg.channel),
128                    provider: Some(pkg.provider),
129                    status: UpdateCheckStatus::UpdateAvailable { current, latest },
130                },
131                Ok(None) => UpdateCheckRow {
132                    name: pkg.name,
133                    channel: Some(pkg.channel),
134                    provider: Some(pkg.provider),
135                    status: UpdateCheckStatus::UpToDate {
136                        current: pkg.version.to_string(),
137                    },
138                },
139                Err(error) => UpdateCheckRow {
140                    name: pkg.name,
141                    channel: Some(pkg.channel),
142                    provider: Some(pkg.provider),
143                    status: UpdateCheckStatus::Failed {
144                        error: error.to_string(),
145                    },
146                },
147            })
148            .collect()
149    }
150
151    pub fn new(
152        provider_manager: &'a ProviderManager,
153        package_storage: &'a mut PackageStorage,
154        paths: &'a UpstreamPaths,
155    ) -> Result<Self> {
156        let installer = PackageInstaller::new(provider_manager, paths)?;
157        let remover = PackageRemover::new(paths);
158
159        let upgrader = PackageUpgrader::new(provider_manager, installer, remover, paths);
160
161        let checker = PackageChecker::new(provider_manager);
162
163        Ok(Self {
164            upgrader,
165            checker,
166            package_storage,
167        })
168    }
169
170    pub async fn upgrade_all<F, G, H>(
171        &mut self,
172        force_option: &bool,
173        ignore_checksums: bool,
174        download_progress: &mut Option<F>,
175        overall_progress: &mut Option<G>,
176        message_callback: &mut Option<H>,
177    ) -> Result<()>
178    where
179        F: FnMut(u64, u64),
180        G: FnMut(u32, u32),
181        H: FnMut(&str),
182    {
183        let names: Vec<String> = self
184            .package_storage
185            .get_all_packages()
186            .iter()
187            .map(|p| p.name.clone())
188            .collect();
189
190        self.upgrade_bulk(
191            &names,
192            force_option,
193            ignore_checksums,
194            download_progress,
195            overall_progress,
196            message_callback,
197        )
198        .await
199    }
200
201    pub async fn upgrade_bulk<F, G, H>(
202        &mut self,
203        names: &[String],
204        force_option: &bool,
205        ignore_checksums: bool,
206        download_progress: &mut Option<F>,
207        overall_progress: &mut Option<G>,
208        message_callback: &mut Option<H>,
209    ) -> Result<()>
210    where
211        F: FnMut(u64, u64),
212        G: FnMut(u32, u32),
213        H: FnMut(&str),
214    {
215        let total = names.len() as u32;
216        let mut completed = 0;
217        let mut failures = 0;
218        let mut upgraded = 0;
219        let force = *force_option;
220        let upgrader = &self.upgrader;
221        let progress_state: ProgressState = Arc::new(Mutex::new(BTreeMap::new()));
222        let mut last_progress_render: BTreeMap<String, String> = BTreeMap::new();
223
224        let packages: Vec<_> = names
225            .iter()
226            .map(|name| {
227                self.package_storage
228                    .get_package_by_name(name)
229                    .ok_or_else(|| anyhow!("Package '{}' is not installed", name))
230                    .cloned()
231            })
232            .collect::<Result<Vec<_>>>()?;
233
234        let mut updated_packages = Vec::new();
235        let mut pending = stream::iter(packages.into_iter().map(|package| {
236            let state_ref = Arc::clone(&progress_state);
237            async move {
238                let name = package.name.clone();
239                let channel = package.channel.clone();
240                let provider = package.provider.clone();
241
242                if let Ok(mut state) = state_ref.lock() {
243                    state.insert(name.clone(), (channel.clone(), provider.clone(), 0, 0));
244                }
245
246                let mut downloaded: u64 = 0;
247                let mut bytes_total: u64 = 0;
248                let mut download_cb = Some(|d: u64, t: u64| {
249                    downloaded = d;
250                    bytes_total = t;
251                    if let Ok(mut state) = state_ref.lock() {
252                        state.insert(name.clone(), (channel.clone(), provider.clone(), d, t));
253                    }
254                });
255                let mut no_messages: Option<fn(&str)> = None;
256
257                let result = upgrader
258                    .upgrade(
259                        &package,
260                        force,
261                        ignore_checksums,
262                        &mut download_cb,
263                        &mut no_messages,
264                    )
265                    .await
266                    .context(format!("Failed to upgrade package '{}'", name));
267                (name, channel, provider, downloaded, bytes_total, result)
268            }
269        }))
270        .buffer_unordered(UPGRADE_CONCURRENCY);
271
272        let mut ticker = time::interval(Duration::from_millis(350));
273        ticker.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
274
275        while completed < total {
276            tokio::select! {
277                maybe_item = pending.next() => {
278                    let Some((name, channel, provider, downloaded, bytes_total, result)) = maybe_item else {
279                        break;
280                    };
281
282                    if let Ok(mut state) = progress_state.lock() {
283                        state.remove(&name);
284                    }
285                    last_progress_render.remove(&name);
286                    message!(message_callback, "__UPGRADE_PROGRESS_DONE__ {}", name);
287
288            let transfer = Self::format_transfer(downloaded, bytes_total);
289            match result {
290                Ok(Some(updated)) => {
291                    updated_packages.push(updated);
292                    message!(
293                        message_callback,
294                        "[✓] {:<28} {:<10} {:<3} {:<10} {}",
295                        name,
296                        channel.to_string().to_lowercase(),
297                        "u",
298                        provider.to_string(),
299                        transfer
300                    );
301                    upgraded += 1;
302                }
303                Ok(None) => {
304                    message!(
305                        message_callback,
306                        "[=] {:<28} {:<10} {:<3} {:<10} {}",
307                        name,
308                        channel.to_string().to_lowercase(),
309                        "-",
310                        provider.to_string(),
311                        transfer
312                    );
313                }
314                Err(e) => {
315                    message!(
316                        message_callback,
317                        "[!] {:<28} {:<10} {:<3} {:<10} {}",
318                        name,
319                        channel.to_string().to_lowercase(),
320                        "!",
321                        provider.to_string(),
322                        Self::truncate_error(&e.to_string(), 36)
323                    );
324                    failures += 1;
325                }
326            }
327
328                    completed += 1;
329                    if let Some(cb) = overall_progress.as_mut() {
330                        cb(completed, total);
331                    }
332                }
333                _ = ticker.tick() => {
334                    if let Ok(state) = progress_state.lock() {
335                        for (name, (channel, provider, downloaded, total_bytes)) in state.iter() {
336                            let row = Self::render_progress_row(
337                                name,
338                                channel,
339                                provider,
340                                *downloaded,
341                                *total_bytes
342                            );
343                            let changed = last_progress_render
344                                .get(name)
345                                .map(|prev| prev != &row)
346                                .unwrap_or(true);
347                            if changed {
348                                message!(message_callback, "__UPGRADE_PROGRESS_ROW__ {}\t{}", name, row);
349                                last_progress_render.insert(name.clone(), row);
350                            }
351                        }
352                    }
353                }
354            }
355        }
356
357        message!(message_callback, "__UPGRADE_PROGRESS_CLEAR__");
358
359        // Save storage updates once parallel workers are done.
360        for updated in updated_packages {
361            self.package_storage.add_or_update_package(updated)?;
362        }
363
364        // Bulk mode uses per-package workers; a single shared download progress bar is noisy.
365        let _ = download_progress;
366
367        self.package_storage
368            .save_packages()
369            .context("Failed to save updated package information")?;
370
371        message!(
372            message_callback,
373            "Completed: {} upgraded, {} up-to-date, {} failed",
374            upgraded,
375            total - upgraded - failures,
376            failures
377        );
378
379        Ok(())
380    }
381
382    pub async fn upgrade_single<F, H>(
383        &mut self,
384        package_name: &str,
385        force_option: &bool,
386        ignore_checksums: bool,
387        download_progress: &mut Option<F>,
388        message_callback: &mut Option<H>,
389    ) -> Result<bool>
390    where
391        F: FnMut(u64, u64),
392        H: FnMut(&str),
393    {
394        let package = self
395            .package_storage
396            .get_package_by_name(package_name)
397            .ok_or_else(|| anyhow!("Package '{}' is not installed", package_name))?
398            .clone();
399
400        let upgraded = self
401            .upgrader
402            .upgrade(
403                &package,
404                *force_option,
405                ignore_checksums,
406                download_progress,
407                message_callback,
408            )
409            .await?;
410
411        if let Some(updated) = upgraded {
412            self.package_storage.add_or_update_package(updated)?;
413            self.package_storage.save_packages()?;
414            Ok(true)
415        } else {
416            Ok(false)
417        }
418    }
419
420    pub async fn check_all_detailed(&self) -> Vec<UpdateCheckRow> {
421        let packages = self.package_storage.get_all_packages().to_vec();
422        self.check_installed_packages_detailed(packages).await
423    }
424
425    pub async fn check_all_machine_readable(&self) -> Vec<(String, String, String)> {
426        let rows = self.check_all_detailed().await;
427        rows.into_iter()
428            .filter_map(|row| match row.status {
429                UpdateCheckStatus::UpdateAvailable { current, latest } => {
430                    Some((row.name, current, latest))
431                }
432                _ => None,
433            })
434            .collect()
435    }
436
437    pub async fn check_selected_detailed(&self, package_names: &[String]) -> Vec<UpdateCheckRow> {
438        let mut rows: Vec<Option<UpdateCheckRow>> =
439            (0..package_names.len()).map(|_| None).collect();
440        let mut selected_packages = Vec::new();
441        let mut selected_indices = Vec::new();
442
443        for (idx, name) in package_names.iter().enumerate() {
444            match self.package_storage.get_package_by_name(name) {
445                Some(package) => {
446                    selected_packages.push(package.clone());
447                    selected_indices.push(idx);
448                }
449                None => {
450                    rows[idx] = Some(UpdateCheckRow {
451                        name: name.clone(),
452                        channel: None,
453                        provider: None,
454                        status: UpdateCheckStatus::NotInstalled,
455                    })
456                }
457            }
458        }
459
460        let checked_rows = self
461            .check_installed_packages_detailed(selected_packages)
462            .await;
463        for (row_idx, checked_row) in selected_indices.into_iter().zip(checked_rows) {
464            rows[row_idx] = Some(checked_row);
465        }
466
467        rows.into_iter().flatten().collect()
468    }
469
470    pub async fn check_selected_machine_readable(
471        &self,
472        package_names: &[String],
473    ) -> Vec<(String, String, String)> {
474        let rows = self.check_selected_detailed(package_names).await;
475        rows.into_iter()
476            .filter_map(|row| match row.status {
477                UpdateCheckStatus::UpdateAvailable { current, latest } => {
478                    Some((row.name, current, latest))
479                }
480                _ => None,
481            })
482            .collect()
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::UpgradeOperation;
    use crate::models::common::enums::{Channel, Provider};

    /// Long errors are cut to the limit and end with an ellipsis.
    #[test]
    fn truncate_error_adds_ellipsis_when_limit_exceeded() {
        let limit = 12;
        let shortened =
            UpgradeOperation::truncate_error("this is a fairly long error string", limit);

        assert!(shortened.ends_with("..."));
        assert!(shortened.chars().count() <= limit);
    }

    /// Covers the three rendering modes: nothing transferred, unknown total, known total.
    #[test]
    fn format_transfer_handles_known_unknown_and_empty_sizes() {
        let nothing = UpgradeOperation::format_transfer(0, 0);
        let unknown_total = UpgradeOperation::format_transfer(42, 0);
        let known_total = UpgradeOperation::format_transfer(1024, 2048);

        assert_eq!(nothing, "-");
        assert!(unknown_total.contains("42"));
        assert!(known_total.contains('/'));
    }

    /// A live progress row carries every column the UI displays.
    #[test]
    fn render_progress_row_includes_package_channel_provider_and_transfer() {
        let row = UpgradeOperation::render_progress_row(
            "ripgrep",
            &Channel::Stable,
            &Provider::Github,
            128,
            256,
        );

        for expected in ["ripgrep", "stable", "github"] {
            assert!(row.contains(expected), "row {row:?} is missing {expected:?}");
        }
        assert!(row.contains('/'));
    }
}