1use crate::{
4 discover::{DiscoveredContext, DiscoveredVisitor},
5 source::Source,
6};
7use futures::{StreamExt, TryFutureExt, TryStreamExt, stream};
8use std::{fmt::Debug, sync::Arc};
9use url::ParseError;
10use walker_common::progress::{Progress, ProgressBar};
11
12#[derive(Debug, thiserror::Error)]
13pub enum Error<VE, SE>
14where
15 VE: std::fmt::Display + Debug,
16 SE: std::fmt::Display + Debug,
17{
18 #[error("Source error: {0}")]
19 Source(SE),
20 #[error("URL error: {0}")]
21 Url(#[from] ParseError),
22 #[error("Visitor error: {0}")]
23 Visitor(VE),
24}
25
26pub struct Walker<S: Source, P: Progress> {
27 source: S,
28 progress: P,
29}
30
31impl<S: Source> Walker<S, ()> {
32 pub fn new(source: S) -> Self {
33 Self {
34 source,
35 progress: (),
36 }
37 }
38}
39
40impl<S: Source, P: Progress> Walker<S, P> {
41 pub fn with_progress<U: Progress>(self, progress: U) -> Walker<S, U> {
42 Walker {
43 source: self.source,
44 progress,
45 }
46 }
47
48 pub async fn walk<V>(self, visitor: V) -> Result<(), Error<V::Error, S::Error>>
49 where
50 V: DiscoveredVisitor,
51 {
52 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
53
54 let context = visitor
55 .visit_context(&DiscoveredContext {
56 metadata: &metadata,
57 })
58 .await
59 .map_err(Error::Visitor)?;
60
61 let index = self.source.load_index().await.map_err(Error::Source)?;
62 let mut progress = self.progress.start(index.len());
63
64 for sbom in index {
65 log::debug!(" Discovered SBOM: {sbom:?}");
66 progress
67 .set_message(
68 sbom.url
69 .path()
70 .rsplit_once('/')
71 .map(|(_, s)| s)
72 .unwrap_or(sbom.url.as_str())
73 .to_string(),
74 )
75 .await;
76 visitor
77 .visit_sbom(&context, sbom)
78 .await
79 .map_err(Error::Visitor)?;
80 progress.tick().await;
81 }
82
83 progress.finish().await;
84
85 Ok(())
86 }
87
88 pub async fn walk_parallel<V>(
89 self,
90 limit: usize,
91 visitor: V,
92 ) -> Result<(), Error<V::Error, S::Error>>
93 where
94 V: DiscoveredVisitor,
95 {
96 log::debug!("Running {limit} workers");
97
98 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
99 let context = visitor
100 .visit_context(&DiscoveredContext {
101 metadata: &metadata,
102 })
103 .await
104 .map_err(Error::Visitor)?;
105
106 let visitor = Arc::new(visitor);
107 let context = Arc::new(context);
108
109 stream::iter(self.source.load_index().await.map_err(Error::Source)?)
110 .map(Ok)
111 .try_for_each_concurrent(limit, async |sbom| {
112 log::debug!("Discovered advisory: {}", sbom.url);
113
114 visitor
115 .visit_sbom(&context, sbom)
116 .map_err(Error::Visitor)
117 .await
118 })
119 .await?;
120
121 Ok(())
122 }
123}