spider_lib/builder.rs

//! Builder for constructing and configuring the `Crawler` instance.
//!
//! This module provides the `CrawlerBuilder`, a fluent API for
//! setting up and customizing a web crawler. It simplifies the process of
//! assembling various `spider-lib` components, including:
//! - Defining concurrency settings for downloads, parsing, and pipelines.
//! - Attaching custom `Downloader` implementations.
//! - Registering `Middleware`s to process requests and responses.
//! - Adding `Pipeline`s to process scraped items.
//! - Configuring checkpointing for persistence and fault tolerance.
//!
//! The builder handles default configurations (e.g., adding a default User-Agent
//! middleware if none is specified) and loading existing checkpoints.
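//!
//! # Example
//!
//! A minimal usage sketch. `MySpider` is a placeholder for any user-defined
//! [`Spider`] implementation, and the import paths assume these items are
//! re-exported at the crate root:
//!
//! ```ignore
//! use spider_lib::CrawlerBuilder;
//! use std::time::Duration;
//!
//! # async fn run() -> Result<(), spider_lib::SpiderError> {
//! // The type annotation selects the default downloader (`ReqwestClientDownloader`);
//! // any other `Downloader` can be supplied via `.downloader()`.
//! let builder: CrawlerBuilder<MySpider> = CrawlerBuilder::new(MySpider::default());
//! let crawler = builder
//!     .max_concurrent_downloads(10)
//!     .with_checkpoint_path("crawl.checkpoint")
//!     .with_checkpoint_interval(Duration::from_secs(60))
//!     .build()
//!     .await?;
//! # Ok(())
//! # }
//! ```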

use crate::ConsoleWriterPipeline;
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};

use super::Crawler;

/// Configuration for the crawler's concurrency settings.
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

/// A fluent builder that assembles a [`Crawler`] from a `Spider`, a
/// `Downloader`, registered middlewares, item pipelines, and optional
/// checkpointing settings.
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            checkpoint_path: None,
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

    /// Sets the maximum number of concurrent downloads (default: 5).
    ///
    /// Must be greater than 0; [`build`](Self::build) rejects a value of 0.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers (default: the
    /// number of logical CPUs).
    ///
    /// Must be greater than 0; [`build`](Self::build) rejects a value of 0.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent item pipelines (default: 5).
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
    ///
    /// When this is not called, the downloader is created with `D::default()`
    /// (a [`ReqwestClientDownloader`] unless another type is specified).
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
    ///
    /// If no `UserAgentMiddleware` is registered, a default one identifying the
    /// crawler as `<package>/<version>` is prepended when [`build`](Self::build)
    /// is called.
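    ///
    /// # Example
    ///
    /// A sketch of registering the bundled user-agent middleware explicitly,
    /// mirroring how `build` constructs its default. The builder variable, the
    /// user-agent string, and the import path (which assumes the `middlewares`
    /// module is public) are illustrative:
    ///
    /// ```ignore
    /// use spider_lib::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
    ///
    /// let ua = UserAgentMiddleware::builder()
    ///     .source(UserAgentSource::List(vec!["my-crawler/0.1".to_string()]))
    ///     .fallback_user_agent("my-crawler/0.1".to_string())
    ///     .build()?;
    /// let builder = builder.add_middleware(ua);
    /// ```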
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
    ///
    /// If no pipeline is added, a [`ConsoleWriterPipeline`] is installed by
    /// default when [`build`](Self::build) is called.
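    ///
    /// # Example
    ///
    /// A sketch of adding the bundled console pipeline explicitly (the builder
    /// variable is illustrative):
    ///
    /// ```ignore
    /// use spider_lib::ConsoleWriterPipeline;
    ///
    /// let builder = builder.add_pipeline(ConsoleWriterPipeline::new());
    /// ```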
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
    ///
    /// If a checkpoint file already exists at this path, [`build`](Self::build)
    /// attempts to restore the crawl state from it, logging a warning and
    /// starting fresh if the file cannot be read or deserialized.
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builds the `Crawler` instance, applying defaults and restoring any
    /// existing checkpoint state.
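    ///
    /// If no pipeline was added, a [`ConsoleWriterPipeline`] is installed; if no
    /// `UserAgentMiddleware` was registered, a default one is prepended.
    ///
    /// # Errors
    ///
    /// Returns [`SpiderError::ConfigurationError`] if no spider was provided or
    /// if `max_concurrent_downloads` or `parser_workers` is 0. Errors from
    /// restoring pipeline state or building the default user-agent middleware
    /// are also propagated.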
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;

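        // If a checkpoint path was configured, try to restore the previous crawl
        // state; failures are only logged so the crawl can start fresh.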
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        // Restore pipeline states now that pipelines are built
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

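        // Prepend a default User-Agent middleware (identifying the crawler as
        // "<package>/<version>") when the user has not registered one.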
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            self.checkpoint_path.take(),
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}