Skip to main content

sbom_walker/
walker.rs

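//! The SBOM walker: loads metadata and the SBOM index from a [`Source`] and
//! feeds them to a [`DiscoveredVisitor`], either sequentially or concurrently.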
2
3use crate::{
4    discover::{DiscoveredContext, DiscoveredVisitor},
5    source::Source,
6};
7use futures::{StreamExt, TryFutureExt, TryStreamExt, stream};
8use std::{fmt::Debug, sync::Arc};
9use url::ParseError;
10use walker_common::progress::{Progress, ProgressBar};
11
12#[derive(Debug, thiserror::Error)]
13pub enum Error<VE, SE>
14where
15    VE: std::fmt::Display + Debug,
16    SE: std::fmt::Display + Debug,
17{
18    #[error("Source error: {0}")]
19    Source(SE),
20    #[error("URL error: {0}")]
21    Url(#[from] ParseError),
22    #[error("Visitor error: {0}")]
23    Visitor(VE),
24}
25
26pub struct Walker<S: Source, P: Progress> {
27    source: S,
28    progress: P,
29}
30
31impl<S: Source> Walker<S, ()> {
32    pub fn new(source: S) -> Self {
33        Self {
34            source,
35            progress: (),
36        }
37    }
38}
39
40impl<S: Source, P: Progress> Walker<S, P> {
41    pub fn with_progress<U: Progress>(self, progress: U) -> Walker<S, U> {
42        Walker {
43            source: self.source,
44            progress,
45        }
46    }
47
48    pub async fn walk<V>(self, visitor: V) -> Result<(), Error<V::Error, S::Error>>
49    where
50        V: DiscoveredVisitor,
51    {
52        let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
53
54        let context = visitor
55            .visit_context(&DiscoveredContext {
56                metadata: &metadata,
57            })
58            .await
59            .map_err(Error::Visitor)?;
60
61        let index = self.source.load_index().await.map_err(Error::Source)?;
62        let mut progress = self.progress.start(index.len());
63
64        for sbom in index {
65            log::debug!("  Discovered SBOM: {sbom:?}");
66            progress
67                .set_message(
68                    sbom.url
69                        .path()
70                        .rsplit_once('/')
71                        .map(|(_, s)| s)
72                        .unwrap_or(sbom.url.as_str())
73                        .to_string(),
74                )
75                .await;
76            visitor
77                .visit_sbom(&context, sbom)
78                .await
79                .map_err(Error::Visitor)?;
80            progress.tick().await;
81        }
82
83        progress.finish().await;
84
85        Ok(())
86    }
87
88    pub async fn walk_parallel<V>(
89        self,
90        limit: usize,
91        visitor: V,
92    ) -> Result<(), Error<V::Error, S::Error>>
93    where
94        V: DiscoveredVisitor,
95    {
96        log::debug!("Running {limit} workers");
97
98        let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
99        let context = visitor
100            .visit_context(&DiscoveredContext {
101                metadata: &metadata,
102            })
103            .await
104            .map_err(Error::Visitor)?;
105
106        let visitor = Arc::new(visitor);
107        let context = Arc::new(context);
108
109        stream::iter(self.source.load_index().await.map_err(Error::Source)?)
110            .map(Ok)
111            .try_for_each_concurrent(limit, async |sbom| {
112                log::debug!("Discovered advisory: {}", sbom.url);
113
114                visitor
115                    .visit_sbom(&context, sbom)
116                    .map_err(Error::Visitor)
117                    .await
118            })
119            .await?;
120
121        Ok(())
122    }
123}