
spider_lib/crawler.rs

//! The core Crawler implementation for the `spider-lib` framework.
//!
//! This module defines the `Crawler` struct, which acts as the central orchestrator
//! for the web scraping process. It ties together the scheduler, downloader,
//! middlewares, spiders, and item pipelines to execute a crawl. The crawler
//! manages the lifecycle of requests and items, handles concurrency, supports
//! checkpointing for fault tolerance, and collects statistics for monitoring.
//!
//! It utilizes a task-based asynchronous model, spawning distinct tasks for
//! handling initial requests, downloading web pages, parsing responses, and
//! processing scraped items.
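//!
//! # Example
//!
//! A minimal sketch of driving a crawl, assuming a `Crawler` has already been
//! assembled elsewhere in the crate (construction is `pub(crate)` and not shown
//! here); the `crawler` variable is illustrative, not a real API entry point.
//!
//! ```ignore
//! // `crawler` is an already-configured Crawler instance (hypothetical).
//! let stats = crawler.get_stats();
//! crawler.start_crawl().await?;
//! // `stats` remains available for inspection after the crawl completes.
//! ```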

use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{ParseOutput, ScrapedItem};
use crate::middleware::{Middleware, MiddlewareAction};
use crate::pipeline::Pipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::CrawlerState;
use crate::stats::StatCollector;
use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};

#[cfg(feature = "checkpoint")]
use crate::checkpoint::save_checkpoint;
#[cfg(feature = "checkpoint")]
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

#[cfg(feature = "middleware-cookies")]
use cookie_store::CookieStore;

/// The central orchestrator for the web scraping process.
///
/// Handles requests, responses, items, concurrency, checkpointing, and
/// statistics collection.
pub struct Crawler<S: Spider, C> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    stats: Arc<StatCollector>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    spider: Arc<Mutex<S>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    parser_workers: usize,
    max_concurrent_pipelines: usize,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    #[cfg(feature = "middleware-cookies")]
    pub cookie_store: Arc<Mutex<CookieStore>>,
}

impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    /// Creates a new `Crawler` instance with the given components and configuration.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        #[cfg(feature = "checkpoint")] checkpoint_path: Option<PathBuf>,
        #[cfg(feature = "checkpoint")] checkpoint_interval: Option<Duration>,
        stats: Arc<StatCollector>,
        #[cfg(feature = "middleware-cookies")] cookie_store: Arc<Mutex<CookieStore>>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            stats,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
            #[cfg(feature = "middleware-cookies")]
            cookie_store,
        }
    }

    /// Starts the crawl, orchestrating the scraping process.
    ///
    /// Manages the worker tasks, handles shutdown and checkpointing, and logs statistics.
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        let Crawler {
            scheduler,
            req_rx,
            stats,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
            #[cfg(feature = "middleware-cookies")]
            cookie_store,
        } = self;

        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
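        // Size the hand-off channels with some headroom over the download
        // concurrency so producers are unlikely to block on a full channel.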
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        let initial_requests_task =
            spawn_initial_requests_task::<S>(scheduler.clone(), spider.clone(), stats.clone());

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
            stats.clone(),
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
            stats.clone(),
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
            stats.clone(),
        );

        #[cfg(feature = "checkpoint")]
        if let (Some(path), Some(interval)) = (&checkpoint_path, checkpoint_interval) {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();
            #[cfg(feature = "middleware-cookies")]
            let cookie_store_clone = cookie_store.clone();

            tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
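                // The first tick of `interval` completes immediately; consume it so
                // the first checkpoint is written only after a full interval elapses.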
                interval_timer.tick().await;
                loop {
                    tokio::select! {
                        _ = interval_timer.tick() => {
                            if let Ok(scheduler_checkpoint) = scheduler_clone.snapshot().await {
                                #[cfg(not(feature = "middleware-cookies"))]
                                let save_result = save_checkpoint::<S>(&path_clone, scheduler_checkpoint, &pipelines_clone).await;
                                #[cfg(feature = "middleware-cookies")]
                                let save_result = save_checkpoint::<S>(&path_clone, scheduler_checkpoint, &pipelines_clone, &cookie_store_clone).await;

                                if let Err(e) = save_result {
                                    error!("Periodic checkpoint save failed: {}", e);
                                }
                            }
                        }
                    }
                }
            });
        }

        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
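                // Idle detection: require two idle observations 50ms apart so a brief
                // hand-off between tasks is not mistaken for the end of the crawl.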
                loop {
                    if scheduler.is_idle() && state.is_idle() {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        #[cfg(feature = "checkpoint")]
        let scheduler_checkpoint = scheduler.snapshot().await?;

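        // Dropping the senders closes the response and item channels, letting the
        // parser and item-processor tasks drain their queues and exit.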
        drop(res_tx);
        drop(item_tx);

        scheduler.shutdown().await?;

        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &checkpoint_path {
            #[cfg(not(feature = "middleware-cookies"))]
            let result = save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines).await;
            #[cfg(feature = "middleware-cookies")]
            let result = save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines, &cookie_store).await;

            if let Err(e) = result {
                error!("Final checkpoint save failed: {}", e);
            }
        }

        // Close all pipelines
        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }

    /// Returns a cloned Arc to the `StatCollector` instance used by this crawler.
    ///
    /// This allows programmatic access to the collected statistics at any time during or after the crawl.
    pub fn get_stats(&self) -> Arc<StatCollector> {
        Arc::clone(&self.stats)
    }
}

fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    stats: Arc<StatCollector>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        match spider.lock().await.start_requests() {
            Ok(requests) => {
                for mut req in requests {
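                    // Normalize away URL fragments before the request is scheduled.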
                    req.url.set_fragment(None);
                    match scheduler.enqueue_request(req).await {
                        Ok(_) => {
                            stats.increment_requests_enqueued();
                        }
                        Err(e) => {
                            error!("Failed to enqueue initial request: {}", e);
                        }
                    }
                }
            }
            Err(e) => error!("Failed to create start requests: {}", e),
        }
    })
}

#[allow(clippy::too_many_arguments)]
fn spawn_downloader_task<S, C>(
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
    state: Arc<CrawlerState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
    stats: Arc<StatCollector>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
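    // The semaphore caps the number of downloads in flight; each spawned download
    // task holds a permit until it finishes.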
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(permit) => permit,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = Arc::clone(&downloader);
            let middlewares_clone = Arc::clone(&middlewares);
            let res_tx_clone = res_tx.clone();
            let state_clone = Arc::clone(&state);
            let scheduler_clone = Arc::clone(&scheduler);
            let stats_clone = Arc::clone(&stats);

            tasks.spawn(async move {
                let mut early_returned_response: Option<Response> = None;

                // Process request middlewares
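                // Each middleware takes ownership of the request and may pass it on,
                // schedule a retry, drop it, or short-circuit with a ready response.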
                let mut processed_request_opt = Some(request);
                for mw in middlewares_clone.lock().await.iter_mut() {
                    let req_to_process = processed_request_opt.take().expect("Request should be present before middleware processing");
                    match mw.process_request(downloader_clone.client(), req_to_process).await {
                        Ok(MiddlewareAction::Continue(req)) => {
                            processed_request_opt = Some(req);
                        }
                        Ok(MiddlewareAction::Retry(req, delay)) => {
                            stats_clone.increment_requests_retried();
                            tokio::time::sleep(delay).await;
                            if scheduler_clone.enqueue_request(*req).await.is_err() {
                                error!("Failed to re-enqueue retried request.");
                            }
                            return;
                        }
                        Ok(MiddlewareAction::Drop) => {
                            stats_clone.increment_requests_dropped();
                            debug!("Request dropped by middleware.");
                            return;
                        }
                        Ok(MiddlewareAction::ReturnResponse(resp)) => {
                            early_returned_response = Some(resp);
                            break;
                        }
                        Err(e) => {
                            error!("Request middleware error: {:?}", e);
                            return;
                        }
                    }
                }

                // Download or use early response
                // If early_returned_response is Some, request was consumed by a middleware
                // If early_returned_response is None, processed_request_opt must contain the request
                let response = match early_returned_response {
                    Some(resp) => {
                        if resp.cached {
                            stats_clone.increment_responses_from_cache();
                        }
                        stats_clone.increment_requests_succeeded();
                        stats_clone.increment_responses_received();
                        stats_clone.record_response_status(resp.status.as_u16());
                        resp
                    },
                    None => {
                        let request_for_download = processed_request_opt.expect("Request must be available for download if not handled by middleware or early returned response");
                        stats_clone.increment_requests_sent();
                        match downloader_clone.download(request_for_download).await {
                            Ok(resp) => {
                                stats_clone.increment_requests_succeeded();
                                stats_clone.increment_responses_received();
                                stats_clone.record_response_status(resp.status.as_u16());
                                // `resp.body` is `Bytes`, so its length is the number of bytes downloaded.
                                stats_clone.add_bytes_downloaded(resp.body.len());
                                resp
                            },
                            Err(e) => {
                                stats_clone.increment_requests_failed();
                                error!("Download error: {:?}", e);
                                return;
                            }
                        }
                    },
                };

                // Process response middlewares
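                // Response middlewares run in reverse registration order, mirroring
                // the order used for request processing.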
                let mut processed_response_opt = Some(response);
                for mw in middlewares_clone.lock().await.iter_mut().rev() {
                    let res_to_process = processed_response_opt.take().expect("Response should be present before middleware processing"); // Take ownership for current middleware
                    match mw.process_response(res_to_process).await {
                        Ok(MiddlewareAction::Continue(res)) => {
                            processed_response_opt = Some(res); // Reassign for next middleware
                        }
                        Ok(MiddlewareAction::Retry(req, delay)) => {
                            stats_clone.increment_requests_retried();
                            tokio::time::sleep(delay).await;
                            if scheduler_clone.enqueue_request(*req).await.is_err() {
                                error!("Failed to re-enqueue retried request.");
                            }
                            return;
                        }
                        Ok(MiddlewareAction::Drop) => {
                            stats_clone.increment_requests_dropped();
                            debug!("Response dropped by middleware.");
                            return;
                        }
                        Ok(MiddlewareAction::ReturnResponse(_)) => {
                            // The middleware has fully handled or consumed the response,
                            // so it is dropped from further processing by this chain.
                            debug!("Unexpected ReturnResponse from process_response; dropping the response from further processing.");
                            processed_response_opt = None;
                            break;
                        }
                        Err(e) => {
                            error!("Response middleware error: {:?}", e);
                            return;
                        }
                    }
                }

                // Send the final processed response, if it still exists
                if let Some(final_response) = processed_response_opt
                    && res_tx_clone.send(final_response).await.is_err() {
                    error!("Response channel closed, cannot send response to the parser.");
                }

                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}

fn spawn_parser_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    state: Arc<CrawlerState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
    stats: Arc<StatCollector>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
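    // Bounded hand-off channel used to fan incoming responses out to the fixed
    // pool of parser workers spawned below.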
    let (internal_parse_tx, internal_parse_rx) = bounded_async(parser_workers * 2);

    // Spawn N parsing worker tasks
    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = Arc::clone(&spider);
        let scheduler_clone = Arc::clone(&scheduler);
        let item_tx_clone = item_tx.clone();
        let state_clone = Arc::clone(&state);
        let stats_clone = Arc::clone(&stats);

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                match spider_clone.lock().await.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                            stats_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }

    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        }

        drop(internal_parse_tx);

        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}

async fn process_crawl_outputs<S>(
    outputs: ParseOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<CrawlerState>,
    stats: Arc<StatCollector>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let (items, requests) = outputs.into_parts();
    info!(
        "Processed {} requests and {} items from spider output.",
        requests.len(),
        items.len()
    );

    stats.increment_items_scraped();

    let mut request_error_total = 0;
    for request in requests {
        match scheduler.enqueue_request(request).await {
            Ok(_) => {
                // Stat: requests_enqueued
                stats.increment_requests_enqueued();
            }
            Err(_) => {
                request_error_total += 1;
            }
        }
    }
    if request_error_total > 0 {
        error!(
            "Failed to enqueue {} requests: scheduler channel closed.",
            request_error_total
        );
    }

    let mut item_error_total = 0;
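    // Bump `processing_items` before sending so idle detection counts the item as
    // in flight; roll the counter back if the channel is already closed.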
    for item in items {
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            item_error_total += 1;
            state.processing_items.fetch_sub(1, Ordering::SeqCst);
        }
    }
    if item_error_total > 0 {
        error!(
            "Failed to send {} scraped items: channel closed.",
            item_error_total
        );
    }
}

fn spawn_item_processor_task<S>(
    state: Arc<CrawlerState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    max_concurrent_pipelines: usize,
    stats: Arc<StatCollector>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down item processor actor.");
                    break;
                }
            };

            let state_clone = Arc::clone(&state);
            let pipelines_clone = Arc::clone(&pipelines);
            let stats_clone = Arc::clone(&stats);
            tasks.spawn(async move {
                let mut item_to_process = Some(item);
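                // Run the item through each pipeline in order; a pipeline may pass it
                // on (Ok(Some)), drop it (Ok(None)), or fail, which also drops it.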
                for pipeline in pipelines_clone.iter() {
                    if let Some(item) = item_to_process.take() {
                        match pipeline.process_item(item).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => {
                                stats_clone.increment_items_dropped_by_pipeline();
                                break;
                            }
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                stats_clone.increment_items_dropped_by_pipeline();
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                // If item survived all pipelines, it's processed
                if item_to_process.is_some() {
                    stats_clone.increment_items_processed();
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}