spider_lib/builder.rs

//! Builder for constructing and configuring the `Crawler` instance.
//!
//! This module provides the `CrawlerBuilder`, a fluent API for
//! setting up and customizing a web crawler. It simplifies the process of
//! assembling various `spider-lib` components, including:
//! - Defining concurrency settings for downloads, parsing, and pipelines.
//! - Attaching custom `Downloader` implementations.
//! - Registering `Middleware`s to process requests and responses.
//! - Adding `Pipeline`s to process scraped items.
//! - Configuring checkpointing for persistence and fault tolerance.
//! - Initializing and integrating a `StatCollector` for gathering crawl statistics.
//!
//! The builder handles default configurations (e.g., adding a default User-Agent
//! middleware if none is specified) and loading existing checkpoints.
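//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest): `MySpider` is a
//! hypothetical `Spider` implementation and the import paths are assumptions.
//!
//! ```ignore
//! use spider_lib::{ConsoleWriterPipeline, CrawlerBuilder};
//!
//! let crawler = CrawlerBuilder::new(MySpider::new())
//!     .max_concurrent_downloads(8)
//!     .max_parser_workers(4)
//!     .max_concurrent_pipelines(4)
//!     .add_pipeline(ConsoleWriterPipeline::new())
//!     .build()
//!     .await?;
//! ```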

use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

use super::Crawler;
use crate::stats::StatCollector;

/// Configuration for the crawler's concurrency settings.
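///
/// All fields have defaults (see the `Default` impl below); individual values
/// can be overridden with struct-update syntax, e.g. (a sketch):
///
/// ```ignore
/// let config = CrawlerConfig {
///     max_concurrent_downloads: 16,
///     ..Default::default()
/// };
/// ```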
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

/// A fluent builder that assembles a `Crawler` from a `Spider`, a `Downloader`,
/// middlewares, item pipelines, and optional checkpointing settings.
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
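    ///
    /// A sketch of attaching the crate's `UserAgentMiddleware`, mirroring the
    /// default that `build` installs; `MySpider` and the user-agent string are
    /// placeholder values:
    ///
    /// ```ignore
    /// let crawler = CrawlerBuilder::new(MySpider::new())
    ///     .add_middleware(
    ///         UserAgentMiddleware::builder()
    ///             .source(UserAgentSource::List(vec!["my-crawler/0.1".to_string()]))
    ///             .fallback_user_agent("my-crawler/0.1".to_string())
    ///             .build()?,
    ///     )
    ///     .build()
    ///     .await?;
    /// ```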
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
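    ///
    /// A minimal sketch of enabling periodic checkpointing; `MySpider`, the
    /// file name, and the interval are arbitrary examples:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let crawler = CrawlerBuilder::new(MySpider::new())
    ///     .with_checkpoint_path("crawl_checkpoint.bin")
    ///     .with_checkpoint_interval(Duration::from_secs(60))
    ///     .build()
    ///     .await?;
    /// ```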
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builds the `Crawler` from the configured components.
    ///
    /// Validates the concurrency settings, falls back to a `ConsoleWriterPipeline`
    /// and a default User-Agent middleware when none are registered, loads an
    /// existing checkpoint if one is configured, and initializes the
    /// `StatCollector` before assembling the crawler.
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Restore pipeline states now that the pipelines are registered.
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let stats = Arc::new(StatCollector::new());
        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
            stats,
        );

        Ok(crawler)
    }
}