// spider_core/builder.rs
//! # Builder Module
//!
//! Provides the `CrawlerBuilder`, a fluent API for constructing and configuring
//! `Crawler` instances with customizable settings and components.
//!
//! ## Overview
//!
//! The `CrawlerBuilder` simplifies the process of assembling various `spider-core`
//! components into a fully configured web crawler. It provides a flexible,
//! ergonomic interface for setting up all aspects of the crawling process.
//!
//! ## Key Features
//!
//! - **Concurrency Configuration**: Control the number of concurrent downloads,
//!   parsing workers, and pipeline processors
//! - **Component Registration**: Attach custom downloaders, middlewares, and pipelines
//! - **Checkpoint Management**: Configure automatic saving and loading of crawl state (feature: `core-checkpoint`)
//! - **Statistics Integration**: Initialize and connect the `StatCollector`
//! - **Default Handling**: Automatic addition of essential middlewares when needed
//!
//! ## Example
//!
//! ```rust,ignore
//! use spider_core::CrawlerBuilder;
//! use spider_middleware::rate_limit::RateLimitMiddleware;
//! use spider_pipeline::console::ConsolePipeline;
//! use spider_util::error::SpiderError;
//!
//! async fn setup_crawler() -> Result<(), SpiderError> {
//!     let crawler = CrawlerBuilder::new(MySpider)
//!         .max_concurrent_downloads(10)
//!         .max_parser_workers(4)
//!         .add_middleware(RateLimitMiddleware::default())
//!         .add_pipeline(ConsolePipeline::new())
//!         .with_checkpoint_path("./crawl.checkpoint")
//!         .build()
//!         .await?;
//!
//!     crawler.start_crawl().await
//! }
//! ```

42use crate::Downloader;
43use crate::ReqwestClientDownloader;
44use crate::scheduler::Scheduler;
45use crate::spider::Spider;
46use num_cpus;
47use spider_middleware::middleware::Middleware;
48use spider_pipeline::pipeline::Pipeline;
49use spider_util::error::SpiderError;
50use std::marker::PhantomData;
51use std::path::{Path, PathBuf};
52use std::sync::Arc;
53use std::time::Duration;
54
55use super::Crawler;
56use crate::stats::StatCollector;
57#[cfg(feature = "checkpoint")]
58use log::{debug, warn};
59
60#[cfg(feature = "checkpoint")]
61use crate::SchedulerCheckpoint;
62#[cfg(feature = "checkpoint")]
63use rmp_serde;
64#[cfg(feature = "checkpoint")]
65use std::fs;
66
/// Configuration for the crawler's concurrency settings.
#[derive(Debug, Clone)]
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
    /// The capacity of communication channels between components.
    pub channel_capacity: usize,
}

impl Default for CrawlerConfig {
    /// Derives defaults from the number of available CPU cores.
    ///
    /// Uses the standard library's `std::thread::available_parallelism`
    /// (cgroup/affinity-aware) rather than the external `num_cpus` crate,
    /// falling back to a single core when the count is unavailable.
    fn default() -> Self {
        // `available_parallelism` can fail on exotic platforms; assume 1 core then.
        let cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        CrawlerConfig {
            // Downloads are I/O-bound: allow at least 16 regardless of core count.
            max_concurrent_downloads: cores.max(16),
            // Parsing is CPU-bound: keep between 4 and 16 workers.
            parser_workers: cores.clamp(4, 16),
            // Cap pipeline concurrency at 8 to bound downstream pressure.
            max_concurrent_pipelines: cores.min(8),
            channel_capacity: 1000,
        }
    }
}
89
90pub struct CrawlerBuilder<S: Spider, D>
91where
92    D: Downloader,
93{
94    config: CrawlerConfig,
95    downloader: D,
96    spider: Option<S>,
97    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
98    pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
99    checkpoint_path: Option<PathBuf>,
100    checkpoint_interval: Option<Duration>,
101    _phantom: PhantomData<S>,
102}
103
104impl<S: Spider> Default for CrawlerBuilder<S, ReqwestClientDownloader> {
105    fn default() -> Self {
106        Self {
107            config: CrawlerConfig::default(),
108            downloader: ReqwestClientDownloader::default(),
109            spider: None,
110            middlewares: Vec::new(),
111            pipelines: Vec::new(),
112            checkpoint_path: None,
113            checkpoint_interval: None,
114            _phantom: PhantomData,
115        }
116    }
117}
118
119impl<S: Spider> CrawlerBuilder<S, ReqwestClientDownloader> {
120    /// Creates a new `CrawlerBuilder` for a given spider with the default ReqwestClientDownloader.
121    pub fn new(spider: S) -> Self {
122        Self {
123            spider: Some(spider),
124            ..Default::default()
125        }
126    }
127}
128
impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads (validated > 0 in `build`).
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the number of parser workers (validated > 0 in `build`).
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrently running item pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets the capacity of the internal channels between crawler components.
    pub fn channel_capacity(mut self, capacity: usize) -> Self {
        self.config.channel_capacity = capacity;
        self
    }

    /// Replaces the downloader used to fetch requests.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Appends a middleware; middlewares are kept in registration order.
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Appends an item pipeline; pipelines are kept in registration order.
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.pipelines.push(Box::new(pipeline));
        self
    }

    /// Sets the file path used to load/save crawl checkpoints.
    /// Only meaningful when the `checkpoint` feature is enabled.
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets how often checkpoints are written during a crawl.
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Validates the configuration, restores any checkpoint, and assembles the
    /// `Crawler`.
    ///
    /// Exactly one of the four `#[cfg]` blocks below survives compilation,
    /// selected by the `checkpoint` × `cookie-store` feature matrix; each calls
    /// a `Crawler::new` with a feature-specific arity (checkpoint path/interval
    /// and/or cookie store arguments present or absent).
    ///
    /// # Errors
    /// Returns `SpiderError::ConfigurationError` if no spider was set or a
    /// concurrency limit is zero (see `take_spider`).
    // `unused_variables` allowed because some locals (e.g. the restored
    // scheduler/cookie state) are unused in certain feature permutations.
    #[allow(unused_variables)]
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync + Clone,
        S::Item: Send + Sync + 'static,
    {
        // Validate config and take ownership of the spider before wiring anything.
        let spider = self.take_spider()?;
        // Ensure at least one pipeline exists (falls back to ConsolePipeline).
        self.init_default_pipeline();

        // checkpoint + cookie-store: restore both scheduler state and cookies.
        #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
        {
            let (scheduler_state, cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
                // Fresh cookie store if the checkpoint did not contain one.
                Arc::new(tokio::sync::RwLock::new(
                    cookie_store.unwrap_or_default(),
                )),
            );
            Ok(crawler)
        }

        // checkpoint only: restore scheduler state; no cookie store argument.
        #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
        {
            let (scheduler_state, _cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
            );
            Ok(crawler)
        }

        // cookie-store only: no scheduler state to restore (`None::<()>`),
        // but the crawler still takes a (default) cookie store.
        #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
        {
            let (_scheduler_state, cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                stats,
                Arc::new(tokio::sync::RwLock::new(
                    cookie_store.unwrap_or_default(),
                )),
            );
            Ok(crawler)
        }

        // neither feature: minimal Crawler::new with no checkpoint/cookie args.
        #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
        {
            let (_scheduler_state, _cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                stats,
            );
            Ok(crawler)
        }
    }

    /// Loads a checkpoint (if a path is set), restores per-pipeline state, and
    /// returns the scheduler state and cookie store to seed the crawler with.
    ///
    /// Read/deserialize failures are logged with `warn!` and treated as "no
    /// checkpoint" (best-effort, never fatal); only pipeline `restore_state`
    /// errors propagate.
    #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
    async fn restore_checkpoint(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<crate::CookieStore>), SpiderError> {
        let mut scheduler_state = None;
        let mut pipeline_states = None;
        let mut cookie_store = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            // Checkpoints are MessagePack-encoded `crate::Checkpoint` values.
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        scheduler_state = Some(checkpoint.scheduler);
                        pipeline_states = Some(checkpoint.pipelines);
                        cookie_store = Some(checkpoint.cookie_store);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Hand each saved pipeline state to the pipeline with the matching name;
        // states for unregistered pipelines are logged and dropped.
        if let Some(states) = pipeline_states {
            for (name, state) in states {
                if let Some(pipeline) = self.pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((scheduler_state, cookie_store))
    }

    /// Same as the cookie-store variant above, but the checkpoint's cookie
    /// store (if any) is ignored and `None` is returned in its place.
    #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
    async fn restore_checkpoint(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<()>), SpiderError> {
        let mut scheduler_state = None;
        let mut pipeline_states = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        scheduler_state = Some(checkpoint.scheduler);
                        pipeline_states = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        if let Some(states) = pipeline_states {
            for (name, state) in states {
                if let Some(pipeline) = self.pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((scheduler_state, None))
    }

    /// No-op stub: checkpointing disabled, no cookie store feature.
    #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
    async fn restore_checkpoint(&mut self) -> Result<(Option<()>, Option<()>), SpiderError> {
        Ok((None, None))
    }

    /// Stub for cookie-store-without-checkpoint: nothing to restore, so a
    /// fresh default cookie store is returned.
    #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
    async fn restore_checkpoint(&mut self) -> Result<(Option<()>, Option<crate::CookieStore>), SpiderError> {
        Ok((None, Some(crate::CookieStore::default())))
    }

    /// Validates concurrency settings, then moves the spider out of the builder.
    ///
    /// # Errors
    /// `ConfigurationError` if `max_concurrent_downloads` or `parser_workers`
    /// is zero, or if no spider was provided.
    fn take_spider(&mut self) -> Result<S, SpiderError> {
        if self.config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }
        self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })
    }

    /// Registers a `ConsolePipeline` when no pipeline was configured, so every
    /// crawl has at least one item sink.
    fn init_default_pipeline(&mut self) {
        if self.pipelines.is_empty() {
            use spider_pipeline::console::ConsolePipeline;
            self.pipelines.push(Box::new(ConsolePipeline::new()));
        }
    }
}
393