use spider_util::error::SpiderError;
use spider_middleware::middleware::Middleware;
use spider_pipeline::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::Downloader;
use crate::ReqwestClientDownloader;
use num_cpus;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

use super::Crawler;
use crate::stats::StatCollector;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

#[cfg(feature = "checkpoint")]
use crate::SchedulerCheckpoint;
#[cfg(feature = "checkpoint")]
use std::fs;
#[cfg(feature = "checkpoint")]
use rmp_serde;

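/// Concurrency and channel-capacity settings used when assembling a [`Crawler`].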
pub struct CrawlerConfig {
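    /// Maximum number of requests downloaded concurrently (default: 5).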
    pub max_concurrent_downloads: usize,
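    /// Number of parser worker tasks (default: the number of logical CPUs).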
    pub parser_workers: usize,
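    /// Maximum number of pipeline executions running concurrently (default: 5).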
    pub max_concurrent_pipelines: usize,
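    /// Capacity of the crawler's internal channels (default: 200).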
    pub channel_capacity: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
            channel_capacity: 200,
        }
    }
}

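/// Builder that assembles a [`Crawler`] from a [`Spider`], a [`Downloader`],
/// middlewares, item pipelines, and optional checkpoint settings.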
pub struct CrawlerBuilder<S: Spider, D>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider> Default for CrawlerBuilder<S, ReqwestClientDownloader> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: ReqwestClientDownloader::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            checkpoint_path: None,
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S, ReqwestClientDownloader> {
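    /// Creates a builder for the given spider, using the default
    /// [`ReqwestClientDownloader`] and [`CrawlerConfig`] settings.
    ///
    /// A minimal usage sketch (`MySpider` stands in for any type implementing
    /// [`Spider`]):
    ///
    /// ```ignore
    /// let crawler = CrawlerBuilder::new(MySpider::new())
    ///     .max_concurrent_downloads(10)
    ///     .build()
    ///     .await?;
    /// ```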
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
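    /// Sets the maximum number of concurrent downloads.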
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

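    /// Sets the number of parser worker tasks.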
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

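    /// Sets the maximum number of concurrently running pipeline executions.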
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

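    /// Sets the capacity of the crawler's internal channels.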
    pub fn channel_capacity(mut self, capacity: usize) -> Self {
        self.crawler_config.channel_capacity = capacity;
        self
    }

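    /// Replaces the downloader used to fetch requests.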
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

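    /// Appends a middleware to the middleware chain.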
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

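    /// Appends an item pipeline. If no pipeline is added, `build` falls back
    /// to a `ConsoleWriterPipeline`.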
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

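    /// Sets the path of the checkpoint file. Only used when the `checkpoint`
    /// feature is enabled.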
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

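    /// Sets the interval between checkpoint saves. Only used when the
    /// `checkpoint` feature is enabled.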
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

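    /// Validates the configuration and assembles the [`Crawler`], restoring
    /// scheduler, pipeline, and cookie state from a checkpoint when the
    /// corresponding features are enabled.
    ///
    /// Returns a configuration error if no spider is set or a concurrency
    /// setting is zero.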
    #[allow(unused_variables)]
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync + Clone,
        S::Item: Send + Sync + 'static,
    {
        let spider = self.validate_and_get_spider()?;

        if self.item_pipelines.is_empty() {
            use spider_pipeline::console_writer::ConsoleWriterPipeline;
            self.item_pipelines.push(Box::new(ConsoleWriterPipeline::new()));
        }

        #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
        {
            let (initial_scheduler_state, loaded_cookie_store) =
                self.load_and_restore_checkpoint_state().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.item_pipelines,
                self.crawler_config.max_concurrent_downloads,
                self.crawler_config.parser_workers,
                self.crawler_config.max_concurrent_pipelines,
                self.crawler_config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
                Arc::new(tokio::sync::RwLock::new(loaded_cookie_store.unwrap_or_default())),
            );
            return Ok(crawler);
        }

        #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
        {
            let (initial_scheduler_state, _loaded_cookie_store) =
                self.load_and_restore_checkpoint_state().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.item_pipelines,
                self.crawler_config.max_concurrent_downloads,
                self.crawler_config.parser_workers,
                self.crawler_config.max_concurrent_pipelines,
                self.crawler_config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
            );
            return Ok(crawler);
        }

        #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
        {
            let (_initial_scheduler_state, loaded_cookie_store) =
                self.load_and_restore_checkpoint_state().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.item_pipelines,
                self.crawler_config.max_concurrent_downloads,
                self.crawler_config.parser_workers,
                self.crawler_config.max_concurrent_pipelines,
                self.crawler_config.channel_capacity,
                stats,
                Arc::new(tokio::sync::RwLock::new(loaded_cookie_store.unwrap_or_default())),
            );
            return Ok(crawler);
        }

        #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
        {
            let (_initial_scheduler_state, _loaded_cookie_store) =
                self.load_and_restore_checkpoint_state().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.item_pipelines,
                self.crawler_config.max_concurrent_downloads,
                self.crawler_config.parser_workers,
                self.crawler_config.max_concurrent_pipelines,
                self.crawler_config.channel_capacity,
                stats,
            );
            return Ok(crawler);
        }
    }

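    /// Reads the checkpoint file (if a path was configured), restores each
    /// pipeline's saved state, and returns the scheduler state and cookie
    /// store to resume from.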
    #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<crate::CookieStore>), SpiderError> {
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;
        let mut loaded_cookie_store = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);

                        loaded_cookie_store = Some(checkpoint.cookie_store);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((initial_scheduler_state, loaded_cookie_store))
    }

    #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<()>), SpiderError> {
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((initial_scheduler_state, None))
    }

    #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<()>, Option<()>), SpiderError> {
        Ok((None, None))
    }

    #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<()>, Option<crate::CookieStore>), SpiderError> {
        Ok((None, Some(crate::CookieStore::default())))
    }

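    /// Checks that the concurrency settings are non-zero and takes the spider
    /// out of the builder.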
    fn validate_and_get_spider(&mut self) -> Result<S, SpiderError> {
        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }
        self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })
    }
}