use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{ParseOutput, ScrapedItem};
use crate::middleware::{Middleware, MiddlewareAction};
use crate::pipeline::Pipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::CrawlerState;
use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};

#[cfg(feature = "checkpoint")]
use crate::checkpoint::save_checkpoint;
#[cfg(feature = "checkpoint")]
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};
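/// The core crawl engine: coordinates the scheduler, downloader,
/// spider parsing workers, and item pipelines as cooperating actors
/// connected by bounded channels.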
pub struct Crawler<S: Spider, C> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    spider: Arc<Mutex<S>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    parser_workers: usize,
    max_concurrent_pipelines: usize,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
}

impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
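    /// Assembles a `Crawler` from its pre-built components; the constructor
    /// is `pub(crate)` so crawlers are wired up inside the crate.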
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        #[cfg(feature = "checkpoint")] checkpoint_path: Option<PathBuf>,
        #[cfg(feature = "checkpoint")] checkpoint_interval: Option<Duration>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
        }
    }
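    /// Runs the crawl to completion. Spawns the downloader, parser, and
    /// item-processor actors, then waits until either Ctrl-C arrives or the
    /// crawl goes idle, and finally shuts the actors down in order.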
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        let Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
        } = self;

        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
        // Twice the download concurrency gives the channels some slack.
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        let initial_requests_task =
            spawn_initial_requests_task::<S>(scheduler.clone(), spider.clone());

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
        );

        // Periodically snapshot the scheduler so an interrupted crawl can be
        // resumed. The handle is kept so the task can be aborted before the
        // final checkpoint is written, instead of racing with it.
        #[cfg(feature = "checkpoint")]
        let mut periodic_checkpoint_task = None;
        #[cfg(feature = "checkpoint")]
        if let (Some(path), Some(interval)) = (&checkpoint_path, checkpoint_interval) {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();

            periodic_checkpoint_task = Some(tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
                // Consume the first tick, which completes immediately.
                interval_timer.tick().await;
                loop {
                    interval_timer.tick().await;
                    if let Ok(scheduler_checkpoint) = scheduler_clone.snapshot().await
                        && let Err(e) =
                            save_checkpoint::<S>(&path_clone, scheduler_checkpoint, &pipelines_clone)
                                .await
                    {
                        error!("Periodic checkpoint save failed: {}", e);
                    }
                }
            }));
        }

        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
                loop {
                    // Re-check idleness after a short pause so a momentary
                    // lull between actors does not trigger a premature shutdown.
                    if scheduler.is_idle() && state.is_idle() {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        #[cfg(feature = "checkpoint")]
        if let Some(task) = periodic_checkpoint_task {
            // Stop periodic checkpointing before taking the final snapshot.
            task.abort();
        }

        #[cfg(feature = "checkpoint")]
        let scheduler_checkpoint = scheduler.snapshot().await?;

        // Closing these channels lets the parser and item-processor loops
        // drain their remaining work and exit.
        drop(res_tx);
        drop(item_tx);

        scheduler.shutdown().await?;

        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &checkpoint_path
            && let Err(e) = save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines).await
        {
            error!("Final checkpoint save failed: {}", e);
        }

        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}
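/// Feeds the spider's start requests into the scheduler, stripping URL
/// fragments from each request first.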
fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        match spider.lock().await.start_requests() {
            Ok(requests) => {
                for mut req in requests {
                    // Fragments are client-side only; drop them up front.
                    req.url.set_fragment(None);
                    if let Err(e) = scheduler.enqueue_request(req).await {
                        error!("Failed to enqueue initial request: {}", e);
                    }
                }
            }
            Err(e) => error!("Failed to create start requests: {}", e),
        }
    })
}
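/// Spawns the downloader actor: pulls requests off the scheduler channel,
/// runs them through the request middlewares, downloads, runs the response
/// middlewares in reverse registration order, and forwards responses to the
/// parser. Concurrency is capped by a semaphore of
/// `max_concurrent_downloads` permits.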
#[allow(clippy::too_many_arguments)]
fn spawn_downloader_task<S, C>(
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
    state: Arc<CrawlerState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(permit) => permit,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = Arc::clone(&downloader);
            let middlewares_clone = Arc::clone(&middlewares);
            let res_tx_clone = res_tx.clone();
            let state_clone = Arc::clone(&state);
            let scheduler_clone = Arc::clone(&scheduler);

            tasks.spawn(async move {
                // The inner block funnels every early exit (drop, retry,
                // middleware or download error) through the in-flight counter
                // decrement below, keeping the accounting balanced so idle
                // detection cannot stall.
                let maybe_response: Option<Response> = async {
                    let mut processed_request = request;
                    let mut early_returned_response: Option<Response> = None;

                    for mw in middlewares_clone.lock().await.iter_mut() {
                        match mw
                            .process_request(downloader_clone.client(), processed_request.clone())
                            .await
                        {
                            Ok(MiddlewareAction::Continue(req)) => {
                                processed_request = req;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return None;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Request dropped by middleware.");
                                return None;
                            }
                            Ok(MiddlewareAction::ReturnResponse(resp)) => {
                                early_returned_response = Some(resp);
                                break;
                            }
                            Err(e) => {
                                error!("Request middleware error: {:?}", e);
                                return None;
                            }
                        }
                    }

                    let mut response = match early_returned_response {
                        Some(resp) => resp,
                        None => match downloader_clone.download(processed_request).await {
                            Ok(resp) => resp,
                            Err(e) => {
                                error!("Download error: {:?}", e);
                                return None;
                            }
                        },
                    };

                    for mw in middlewares_clone.lock().await.iter_mut().rev() {
                        match mw.process_response(response.clone()).await {
                            Ok(MiddlewareAction::Continue(res)) => {
                                response = res;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return None;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Response dropped by middleware.");
                                return None;
                            }
                            Ok(MiddlewareAction::ReturnResponse(_)) => {
                                warn!("ReturnResponse is not meaningful in process_response; ignoring.");
                                continue;
                            }
                            Err(e) => {
                                error!("Response middleware error: {:?}", e);
                                return None;
                            }
                        }
                    }

                    Some(response)
                }
                .await;

                if let Some(response) = maybe_response
                    && res_tx_clone.send(response).await.is_err()
                {
                    error!("Response channel closed, cannot send parsed response.");
                }

                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        // Wait for any still-running downloads before the actor exits.
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}
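/// Spawns the parser actor: fans responses out to `parser_workers` worker
/// tasks, each of which calls `Spider::parse` and hands the resulting
/// requests and items to `process_crawl_outputs`.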
fn spawn_parser_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    state: Arc<CrawlerState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let (internal_parse_tx, internal_parse_rx) = bounded_async::<Response>(parser_workers * 2);

    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = Arc::clone(&spider);
        let scheduler_clone = Arc::clone(&scheduler);
        let item_tx_clone = item_tx.clone();
        let state_clone = Arc::clone(&state);

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                match spider_clone.lock().await.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }

    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        }

        // Closing the internal channel lets the workers drain and exit.
        drop(internal_parse_tx);

        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}
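/// Routes a spider's parse output: new requests go back to the scheduler,
/// scraped items go to the item-processor channel. The `processing_items`
/// counter is bumped per item (and rolled back on send failure) so idle
/// detection sees in-flight items.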
async fn process_crawl_outputs<S>(
    outputs: ParseOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<CrawlerState>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let (items, requests) = outputs.into_parts();
    info!(
        "Spider output produced {} requests and {} items.",
        requests.len(),
        items.len()
    );

    let mut request_error_total = 0;
    for request in requests {
        if scheduler.enqueue_request(request).await.is_err() {
            request_error_total += 1;
        }
    }
    if request_error_total > 0 {
        error!(
            "Failed to enqueue {} requests: scheduler channel closed.",
            request_error_total
        );
    }

    let mut item_error_total = 0;
    for item in items {
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            item_error_total += 1;
            state.processing_items.fetch_sub(1, Ordering::SeqCst);
        }
    }
    if item_error_total > 0 {
        error!(
            "Failed to send {} scraped items: channel closed.",
            item_error_total
        );
    }
}
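/// Spawns the item-processor actor: each received item flows through the
/// pipelines in order until one drops it (`Ok(None)`) or errors, with
/// concurrency capped by `max_concurrent_pipelines` permits.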
fn spawn_item_processor_task<S>(
    state: Arc<CrawlerState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    max_concurrent_pipelines: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down item processor actor.");
                    break;
                }
            };

            let state_clone = Arc::clone(&state);
            let pipelines_clone = Arc::clone(&pipelines);
            tasks.spawn(async move {
                // Thread the item through each pipeline in order; a pipeline
                // may transform it, consume it (`Ok(None)`), or fail.
                let mut item_to_process = Some(item);
                for pipeline in pipelines_clone.iter() {
                    if let Some(item) = item_to_process.take() {
                        match pipeline.process_item(item).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => break,
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}