
spider_lib/builder.rs

//! Builder for constructing and configuring the `Crawler` instance.
//!
//! This module provides the `CrawlerBuilder`, a fluent API for
//! setting up and customizing a web crawler. It simplifies the process of
//! assembling various `spider-lib` components, including:
//! - Defining concurrency settings for downloads, parsing, and pipelines.
//! - Attaching custom `Downloader` implementations.
//! - Registering `Middleware`s to process requests and responses.
//! - Adding `Pipeline`s to process scraped items.
//! - Configuring checkpointing for persistence and fault tolerance.
//!
//! The builder handles default configurations (e.g., adding a default User-Agent
//! middleware if none is specified) and loading existing checkpoints.
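//!
//! # Example
//!
//! A minimal usage sketch; `MySpider` is a hypothetical type implementing
//! `Spider`, and the import path is illustrative:
//!
//! ```rust,ignore
//! use spider_lib::crawler::CrawlerBuilder;
//!
//! // Inside an async context; `MySpider::new()` is assumed to exist.
//! let crawler = CrawlerBuilder::new(MySpider::new())
//!     .max_concurrent_downloads(8)
//!     .max_parser_workers(4)
//!     .max_concurrent_pipelines(4)
//!     .build()
//!     .await?;
//! ```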

use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

use super::Crawler;

/// Configuration for the crawler's concurrency settings.
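///
/// A small sketch of the defaults (mirroring the `Default` impl below):
///
/// ```rust,ignore
/// let config = CrawlerConfig::default();
/// // 5 concurrent downloads, one parser worker per CPU core, 5 concurrent pipelines.
/// assert_eq!(config.max_concurrent_downloads, 5);
/// assert_eq!(config.max_concurrent_pipelines, 5);
/// ```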
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

/// A fluent builder for assembling a [`Crawler`] from a [`Spider`], a
/// downloader, middlewares, and item pipelines.
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

    /// Sets the maximum number of concurrent downloads.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
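    ///
    /// A minimal sketch of enabling checkpointing (requires the `checkpoint`
    /// feature; the file name, interval, and `MySpider` are illustrative):
    ///
    /// ```rust,ignore
    /// use std::time::Duration;
    ///
    /// let builder = CrawlerBuilder::new(MySpider::new())
    ///     // Persist crawler state to this file and reload it on the next run.
    ///     .with_checkpoint_path("crawl_state.checkpoint")
    ///     // Write a checkpoint roughly once a minute while crawling.
    ///     .with_checkpoint_interval(Duration::from_secs(60));
    /// ```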
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builds the `Crawler` instance.
    ///
    /// If no pipelines were added, a `ConsoleWriterPipeline` is attached by
    /// default, and a default `UserAgentMiddleware` (derived from the crate
    /// name and version) is prepended if no user-agent middleware was
    /// registered. With the `checkpoint` feature enabled, any existing
    /// checkpoint at the configured path is loaded to restore scheduler and
    /// pipeline state.
    ///
    /// # Errors
    ///
    /// Returns a `SpiderError::ConfigurationError` if no spider was provided
    /// or if `max_concurrent_downloads` or `parser_workers` is zero.
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        // Fall back to printing items to the console when no pipeline was registered.
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Restore pipeline states now that the pipelines are built.
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        // Prepend a default User-Agent middleware (crate name and version) if
        // none was registered, so every request carries a User-Agent header.
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}