spider_lib/builder.rs

//! Builder for constructing and configuring the `Crawler` instance.
//!
//! This module provides the `CrawlerBuilder`, a fluent API for
//! setting up and customizing a web crawler. It simplifies the process of
//! assembling various `spider-lib` components, including:
//! - Defining concurrency settings for downloads, parsing, and pipelines.
//! - Attaching custom `Downloader` implementations.
//! - Registering `Middleware`s to process requests and responses.
//! - Adding `Pipeline`s to process scraped items.
//! - Configuring checkpointing for persistence and fault tolerance.
//! - Initializing and integrating a `StatCollector` for gathering crawl statistics.
//!
//! The builder handles default configurations (e.g., adding a default User-Agent
//! middleware if none is specified) and loading existing checkpoints.
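//!
//! # Example
//!
//! A minimal sketch of a typical build; `MySpider` is a hypothetical type
//! implementing `Spider`, and the tuning value is illustrative:
//!
//! ```ignore
//! let crawler = CrawlerBuilder::new(MySpider::default())
//!     .max_concurrent_downloads(10)
//!     .build()
//!     .await?;
//! ```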

use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "middleware-cookies")]
use std::any::Any;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, info, warn};

#[cfg(feature = "middleware-cookies")]
use crate::middlewares::cookies::CookieMiddleware;
#[cfg(feature = "middleware-cookies")]
use cookie_store::CookieStore;
#[cfg(feature = "middleware-cookies")]
use tokio::sync::Mutex;

use super::Crawler;
use crate::stats::StatCollector;

/// Configuration for the crawler's concurrency settings.
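///
/// All fields are public, so a partly customized configuration can use
/// struct-update syntax. A sketch (the value shown is illustrative):
///
/// ```ignore
/// let config = CrawlerConfig {
///     max_concurrent_downloads: 10,
///     ..Default::default()
/// };
/// ```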
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

/// Fluent builder for assembling a `Crawler` from a `Spider`: configure a
/// `Downloader`, middlewares, item pipelines, and optional checkpointing.
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads.
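    ///
    /// A sketch tuning the three concurrency settings together (the values
    /// are illustrative):
    ///
    /// ```ignore
    /// let builder = builder
    ///     .max_concurrent_downloads(10)
    ///     .max_parser_workers(4)
    ///     .max_concurrent_pipelines(8);
    /// ```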
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
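    ///
    /// A sketch passing an explicitly constructed default downloader; any
    /// `Downloader` implementation can be substituted:
    ///
    /// ```ignore
    /// let builder = builder.downloader(ReqwestClientDownloader::default());
    /// ```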
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
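    ///
    /// A sketch using the crate's `UserAgentMiddleware` (the agent string is
    /// illustrative):
    ///
    /// ```ignore
    /// let builder = builder.add_middleware(
    ///     UserAgentMiddleware::builder()
    ///         .source(UserAgentSource::List(vec!["my-crawler/1.0".to_string()]))
    ///         .fallback_user_agent("my-crawler/1.0".to_string())
    ///         .build()?,
    /// );
    /// ```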
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
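    ///
    /// If no pipeline is added, `build` falls back to a
    /// `ConsoleWriterPipeline`. A sketch of registering one explicitly:
    ///
    /// ```ignore
    /// let builder = builder.add_pipeline(ConsoleWriterPipeline::new());
    /// ```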
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
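    ///
    /// A sketch pairing an interval with a checkpoint path (both values are
    /// illustrative):
    ///
    /// ```ignore
    /// let builder = builder
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(60));
    /// ```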
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builds the `Crawler` from the configured components.
    ///
    /// Applies defaults where needed (a `ConsoleWriterPipeline` when no
    /// pipeline was added, a `<crate>/<version>` User-Agent middleware when
    /// none was registered), validates the concurrency settings, restores any
    /// existing checkpoint, and wires in a fresh `StatCollector`.
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;
        // Declare the cookie-store slot whenever the cookies feature is on,
        // so the cookie-aware branch below compiles even without `checkpoint`.
        #[cfg(all(feature = "checkpoint", feature = "middleware-cookies"))]
        let mut loaded_cookie_store: Option<CookieStore> = None;
        #[cfg(all(not(feature = "checkpoint"), feature = "middleware-cookies"))]
        let loaded_cookie_store: Option<CookieStore> = None;

        // Attempt to restore state saved by a previous run, if a checkpoint
        // file exists at the configured path.
        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);

                        #[cfg(feature = "middleware-cookies")]
                        {
                            info!("Checkpoint loaded, restoring cookie store data.");
                            loaded_cookie_store = Some(checkpoint.cookie_store);
                        }
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Restore pipeline states now that the pipelines are registered.
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        // Resume from the checkpointed scheduler state when one was loaded.
        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        // Fall back to a "<crate name>/<version>" User-Agent when none was
        // registered explicitly.
        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);
        let stats = Arc::new(StatCollector::new());

        let crawler = {
            #[cfg(not(feature = "middleware-cookies"))]
            {
                Crawler::new(
                    scheduler_arc,
                    req_rx,
                    downloader_arc,
                    self.middlewares,
                    spider,
                    self.item_pipelines,
                    self.crawler_config.max_concurrent_downloads,
                    self.crawler_config.parser_workers,
                    self.crawler_config.max_concurrent_pipelines,
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_path.take(),
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_interval,
                    stats,
                )
            }

            #[cfg(feature = "middleware-cookies")]
            {
                let mut final_cookie_store =
                    Arc::new(Mutex::new(loaded_cookie_store.unwrap_or_default()));

                // Detect a CookieMiddleware and use its store, overriding any
                // checkpointed or default store.
                for mw_box in &self.middlewares {
                    if let Some(cookie_mw) =
                        (mw_box.as_ref() as &dyn Any).downcast_ref::<CookieMiddleware>()
                    {
                        info!(
                            "Found CookieMiddleware, using its cookie store for Crawler. This overrides any checkpointed store."
                        );
                        final_cookie_store = cookie_mw.store.clone();
                        break;
                    }
                }

                Crawler::new(
                    scheduler_arc,
                    req_rx,
                    downloader_arc,
                    self.middlewares,
                    spider,
                    self.item_pipelines,
                    self.crawler_config.max_concurrent_downloads,
                    self.crawler_config.parser_workers,
                    self.crawler_config.max_concurrent_pipelines,
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_path.take(),
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_interval,
                    stats,
                    final_cookie_store,
                )
            }
        };

        Ok(crawler)
    }
}