1use crate::Downloader;
43use crate::ReqwestClientDownloader;
44use crate::scheduler::Scheduler;
45use crate::spider::Spider;
46use num_cpus;
47use spider_middleware::middleware::Middleware;
48use spider_pipeline::pipeline::Pipeline;
49use spider_util::error::SpiderError;
50use std::marker::PhantomData;
51use std::path::{Path, PathBuf};
52use std::sync::Arc;
53use std::time::Duration;
54
55use super::Crawler;
56use crate::stats::StatCollector;
57#[cfg(feature = "checkpoint")]
58use log::{debug, warn};
59
60#[cfg(feature = "checkpoint")]
61use crate::SchedulerCheckpoint;
62#[cfg(feature = "checkpoint")]
63use rmp_serde;
64#[cfg(feature = "checkpoint")]
65use std::fs;
66
/// Concurrency and buffering settings applied when constructing a [`Crawler`].
pub struct CrawlerConfig {
    /// Upper bound on download tasks running at the same time.
    pub max_concurrent_downloads: usize,
    /// Number of worker tasks that parse downloaded responses.
    pub parser_workers: usize,
    /// Upper bound on pipeline item-processing tasks running at the same time.
    pub max_concurrent_pipelines: usize,
    /// Capacity of the internal bounded channel(s) between crawl stages.
    // NOTE(review): presumed to size the scheduler/parser channels — confirm in Crawler::new.
    pub channel_capacity: usize,
}
78
79impl Default for CrawlerConfig {
80 fn default() -> Self {
81 CrawlerConfig {
82 max_concurrent_downloads: num_cpus::get().max(16),
83 parser_workers: num_cpus::get().clamp(4, 16),
84 max_concurrent_pipelines: num_cpus::get().min(8),
85 channel_capacity: 1000,
86 }
87 }
88}
89
/// Fluent builder assembling a [`Crawler`] from a spider, a downloader,
/// middlewares, pipelines, checkpointing options and concurrency settings.
pub struct CrawlerBuilder<S: Spider, D>
where
    D: Downloader,
{
    // Concurrency/buffering knobs; see `CrawlerConfig`.
    config: CrawlerConfig,
    // Transport implementation used to fetch requests.
    downloader: D,
    // Set by `new`; required and taken out when `build` runs.
    spider: Option<S>,
    // Middlewares parameterized over the downloader's client type.
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    // Item pipelines; `build` adds a console pipeline if this stays empty.
    pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    // File used to persist/restore crawl state (checkpoint feature only).
    checkpoint_path: Option<PathBuf>,
    // How often checkpoints are written; forwarded to `Crawler::new`.
    checkpoint_interval: Option<Duration>,
    // NOTE(review): `S` already appears in `spider` and `pipelines`, so this
    // marker looks redundant — removal would touch every constructor site.
    _phantom: PhantomData<S>,
}
103
104impl<S: Spider> Default for CrawlerBuilder<S, ReqwestClientDownloader> {
105 fn default() -> Self {
106 Self {
107 config: CrawlerConfig::default(),
108 downloader: ReqwestClientDownloader::default(),
109 spider: None,
110 middlewares: Vec::new(),
111 pipelines: Vec::new(),
112 checkpoint_path: None,
113 checkpoint_interval: None,
114 _phantom: PhantomData,
115 }
116 }
117}
118
119impl<S: Spider> CrawlerBuilder<S, ReqwestClientDownloader> {
120 pub fn new(spider: S) -> Self {
122 Self {
123 spider: Some(spider),
124 ..Default::default()
125 }
126 }
127}
128
129impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
130 pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
131 self.config.max_concurrent_downloads = limit;
132 self
133 }
134
135 pub fn max_parser_workers(mut self, limit: usize) -> Self {
136 self.config.parser_workers = limit;
137 self
138 }
139
140 pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
141 self.config.max_concurrent_pipelines = limit;
142 self
143 }
144
145 pub fn channel_capacity(mut self, capacity: usize) -> Self {
146 self.config.channel_capacity = capacity;
147 self
148 }
149
150 pub fn downloader(mut self, downloader: D) -> Self {
151 self.downloader = downloader;
152 self
153 }
154
155 pub fn add_middleware<M>(mut self, middleware: M) -> Self
156 where
157 M: Middleware<D::Client> + Send + Sync + 'static,
158 {
159 self.middlewares.push(Box::new(middleware));
160 self
161 }
162
163 pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
164 where
165 P: Pipeline<S::Item> + 'static,
166 {
167 self.pipelines.push(Box::new(pipeline));
168 self
169 }
170
171 pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
172 self.checkpoint_path = Some(path.as_ref().to_path_buf());
173 self
174 }
175
176 pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
177 self.checkpoint_interval = Some(interval);
178 self
179 }
180
    /// Consumes the builder and assembles a ready-to-run [`Crawler`].
    ///
    /// Exactly one of the four feature-gated branches below is compiled,
    /// depending on the `checkpoint` x `cookie-store` feature matrix; each
    /// branch calls the `Crawler::new` arity matching the enabled features,
    /// so argument lists intentionally differ between branches.
    ///
    /// # Errors
    /// * configuration errors from `take_spider` (missing spider, zero
    ///   worker counts);
    /// * errors propagated by `restore_checkpoint` (pipeline state
    ///   restoration).
    #[allow(unused_variables)] // some bindings are unused under certain feature combinations
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync + Clone,
        S::Item: Send + Sync + 'static,
    {
        // Validate configuration and extract the spider up front.
        let spider = self.take_spider()?;
        // Fall back to a console pipeline when none were registered.
        self.init_default_pipeline();

        // Branch 1: restore both scheduler state and cookies from the checkpoint.
        #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
        {
            let (scheduler_state, cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
                // A checkpoint without cookies falls back to an empty store.
                Arc::new(tokio::sync::RwLock::new(
                    cookie_store.unwrap_or_default(),
                )),
            );
            Ok(crawler)
        }

        // Branch 2: checkpoint restore without cookie persistence.
        #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
        {
            let (scheduler_state, _cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(scheduler_state);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                self.checkpoint_path.take(),
                self.checkpoint_interval,
                stats,
            );
            Ok(crawler)
        }

        // Branch 3: cookies only; no scheduler state to restore.
        #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
        {
            let (_scheduler_state, cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                stats,
                Arc::new(tokio::sync::RwLock::new(
                    cookie_store.unwrap_or_default(),
                )),
            );
            Ok(crawler)
        }

        // Branch 4: neither feature — fresh scheduler, no cookies.
        #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
        {
            let (_scheduler_state, _cookie_store) = self.restore_checkpoint().await?;
            let (scheduler_arc, req_rx) = Scheduler::new(None::<()>);
            let downloader_arc = Arc::new(self.downloader);
            let stats = Arc::new(StatCollector::new());
            let crawler = Crawler::new(
                scheduler_arc,
                req_rx,
                downloader_arc,
                self.middlewares,
                spider,
                self.pipelines,
                self.config.max_concurrent_downloads,
                self.config.parser_workers,
                self.config.max_concurrent_pipelines,
                self.config.channel_capacity,
                stats,
            );
            Ok(crawler)
        }
    }
289
    /// Best-effort restoration of a previously saved crawl checkpoint
    /// (variant compiled with both `checkpoint` and `cookie-store`).
    ///
    /// When a `checkpoint_path` is configured, reads the file and decodes a
    /// MessagePack [`crate::Checkpoint`], then hands each pipeline its saved
    /// state (matched by `name()`). Read or deserialization failures are only
    /// logged via `warn!` so the crawl starts fresh instead of aborting.
    ///
    /// # Errors
    /// Propagates errors from `Pipeline::restore_state` only.
    #[cfg(all(feature = "checkpoint", feature = "cookie-store"))]
    async fn restore_checkpoint(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<crate::CookieStore>), SpiderError> {
        let mut scheduler_state = None;
        let mut pipeline_states = None;
        let mut cookie_store = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        scheduler_state = Some(checkpoint.scheduler);
                        pipeline_states = Some(checkpoint.pipelines);
                        cookie_store = Some(checkpoint.cookie_store);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Re-inject saved state into the currently registered pipelines.
        if let Some(states) = pipeline_states {
            for (name, state) in states {
                if let Some(pipeline) = self.pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    // State for a pipeline that is no longer registered is dropped.
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((scheduler_state, cookie_store))
    }
325
    /// Best-effort restoration of a previously saved crawl checkpoint
    /// (variant compiled with `checkpoint` but without `cookie-store`).
    ///
    /// Identical to the cookie-store variant except the checkpoint's cookie
    /// data is ignored and `None` is returned in its place. Read or
    /// deserialization failures are only logged via `warn!` so the crawl
    /// starts fresh instead of aborting.
    ///
    /// # Errors
    /// Propagates errors from `Pipeline::restore_state` only.
    #[cfg(all(feature = "checkpoint", not(feature = "cookie-store")))]
    async fn restore_checkpoint(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<()>), SpiderError> {
        let mut scheduler_state = None;
        let mut pipeline_states = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<crate::Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        scheduler_state = Some(checkpoint.scheduler);
                        pipeline_states = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Re-inject saved state into the currently registered pipelines.
        if let Some(states) = pipeline_states {
            for (name, state) in states {
                if let Some(pipeline) = self.pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    // State for a pipeline that is no longer registered is dropped.
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        Ok((scheduler_state, None))
    }
359
    /// No-op stand-in compiled when neither `checkpoint` nor `cookie-store`
    /// is enabled: there is never any state to restore.
    #[cfg(all(not(feature = "checkpoint"), not(feature = "cookie-store")))]
    async fn restore_checkpoint(&mut self) -> Result<(Option<()>, Option<()>), SpiderError> {
        Ok((None, None))
    }
364
    /// Stand-in compiled with `cookie-store` but without `checkpoint`: there
    /// is no checkpoint file, so a fresh default cookie store is handed back.
    /// (Observably equivalent to returning `None`, since `build` applies
    /// `unwrap_or_default` to this value.)
    #[cfg(all(not(feature = "checkpoint"), feature = "cookie-store"))]
    async fn restore_checkpoint(&mut self) -> Result<(Option<()>, Option<crate::CookieStore>), SpiderError> {
        Ok((None, Some(crate::CookieStore::default())))
    }
369
370 fn take_spider(&mut self) -> Result<S, SpiderError> {
371 if self.config.max_concurrent_downloads == 0 {
372 return Err(SpiderError::ConfigurationError(
373 "max_concurrent_downloads must be greater than 0.".to_string(),
374 ));
375 }
376 if self.config.parser_workers == 0 {
377 return Err(SpiderError::ConfigurationError(
378 "parser_workers must be greater than 0.".to_string(),
379 ));
380 }
381 self.spider.take().ok_or_else(|| {
382 SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
383 })
384 }
385
386 fn init_default_pipeline(&mut self) {
387 if self.pipelines.is_empty() {
388 use spider_pipeline::console::ConsolePipeline;
389 self.pipelines.push(Box::new(ConsolePipeline::new()));
390 }
391 }
392}
393