// spider_lib/crawler.rs
1use crate::checkpoint::{Checkpoint, SchedulerCheckpoint};
2use crate::downloader::Downloader;
3use crate::error::SpiderError;
4use crate::item::{ParseOutput, ScrapedItem};
5use crate::middleware::{Middleware, MiddlewareAction};
6use crate::pipeline::Pipeline;
7use crate::request::Request;
8use crate::response::Response;
9use crate::scheduler::Scheduler;
10use crate::spider::Spider;
11use crate::state::CrawlerState;
12use anyhow::Result;
13use futures_util::future::join_all;
14use kanal::{AsyncReceiver, AsyncSender, bounded_async};
15use std::collections::{HashMap, VecDeque};
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::sync::atomic::Ordering;
20use std::time::Duration;
21use tokio::sync::Mutex;
22use tokio::sync::Semaphore;
23use tokio::task::JoinSet;
24use tracing::{debug, error, info, warn};
25
/// Orchestrates a full crawl: scheduling, downloading, middleware processing,
/// parsing, and item-pipeline handling, with optional periodic checkpointing.
///
/// `S` is the user-provided spider implementation; `C` is the client type
/// shared by the downloader and the middlewares.
pub struct Crawler<S: Spider, C> {
    // Request scheduler shared by all stages; also the source of checkpoint snapshots.
    scheduler: Arc<Scheduler>,
    // Receiving end of the scheduler's request channel, consumed by the downloader actor.
    req_rx: AsyncReceiver<Request>,
    // Performs the actual downloads and owns the client handed to middlewares.
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    // Request/response middlewares, applied in order on requests and in reverse on responses.
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    // The spider, shared behind a mutex so parser workers can call `parse` on it.
    spider: Arc<Mutex<S>>,
    // Item pipelines, run in order over every scraped item.
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    // Cap on simultaneous downloads (semaphore size in the downloader actor).
    max_concurrent_downloads: usize,
    // Number of parser worker tasks to spawn.
    parser_workers: usize,
    // Cap on simultaneous pipeline executions (semaphore size in the item processor).
    max_concurrent_pipelines: usize,
    // Where to write checkpoints; `None` disables checkpointing entirely.
    checkpoint_path: Option<PathBuf>,
    // How often to save periodic checkpoints; `None` means final-checkpoint only.
    checkpoint_interval: Option<Duration>,
}
39
impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    /// Assembles a `Crawler` from pre-built components.
    ///
    /// Crate-internal: intended to be called by this crate's builder. The
    /// spider is wrapped in `Arc<Mutex<_>>` here so it can be shared with the
    /// parser worker tasks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        checkpoint_path: Option<PathBuf>,
        checkpoint_interval: Option<Duration>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        }
    }

    /// Starts the crawl.
    ///
    /// Spawns the four stage actors (initial requests, downloader, parser,
    /// item processor), runs until Ctrl-C is received or the crawl becomes
    /// idle, then shuts the stages down in order, saves a final checkpoint
    /// (if configured), and closes the item pipelines.
    ///
    /// # Errors
    /// Returns a `SpiderError` if the scheduler snapshot or shutdown fails,
    /// or if any of the stage tasks panics.
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        // Destructure so each component can be moved into its own task.
        let Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        } = self;

        // Shared progress counters (in-flight downloads, parsing, pipeline
        // items) cloned into every stage; drives the idle-detection loop below.
        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
        // Bounded channels sized at 2x the download concurrency to provide
        // some buffering between stages without unbounded memory growth.
        let channel_capacity = max_concurrent_downloads * 2;

        // downloader -> parser, and parser -> item processor channels.
        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        // Requests that failed to enqueue are parked here so checkpoints can
        // fold them back into the persisted request queue.
        let (salvaged_requests_tx, salvaged_requests_rx) = bounded_async(channel_capacity);

        let initial_requests_task = spawn_initial_requests_task::<S>(
            scheduler.clone(),
            spider.clone(),
            salvaged_requests_tx.clone(),
        );

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
            salvaged_requests_tx.clone(),
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
            salvaged_requests_tx.clone(),
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
        );

        // Periodic checkpointing runs only when both a path and an interval
        // are configured. The task is detached: it dies with the runtime.
        if let (Some(path), Some(interval)) = (&checkpoint_path, checkpoint_interval) {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();
            let salvaged_requests_rx_clone = salvaged_requests_rx.clone();
            tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
                // Consume the immediate first tick so the first save happens
                // after one full interval, not at startup.
                interval_timer.tick().await;
                loop {
                    tokio::select! {
                                    _ = interval_timer.tick() => {
                                        if let Ok(scheduler_checkpoint) = scheduler_clone.snapshot().await &&
                                            let Err(e) = save_checkpoint::<S>(&path_clone, scheduler_checkpoint, &pipelines_clone, salvaged_requests_rx_clone.clone()).await {
                                                error!("Periodic checkpoint save failed: {}", e);
                                        }
                                    }
                    }
                }
            });
        }

        // Run until either Ctrl-C or the crawl drains completely.
        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
                loop {
                    // Double-check idleness after a short pause so we don't
                    // shut down during a momentary gap between stages (e.g. a
                    // response handed off but not yet counted as parsing).
                    if scheduler.is_idle() && state.is_idle() {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        // Snapshot before tearing anything down so the final checkpoint
        // reflects the queue state at the moment of shutdown.
        let scheduler_checkpoint = scheduler.snapshot().await?;

        // Dropping our sender clones closes the inter-stage channels once the
        // actors drop theirs, letting each stage's recv loop terminate.
        drop(res_tx);
        drop(item_tx);
        drop(salvaged_requests_tx);

        scheduler.shutdown().await?;

        // Await the stage tasks; each join error is surfaced as a SpiderError.
        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        // Save final checkpoint BEFORE closing pipelines
        if let Some(path) = &checkpoint_path
            && let Err(e) =
                save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines, salvaged_requests_rx)
                    .await
        {
            error!("Final checkpoint save failed: {}", e);
        }

        // Close all pipelines
        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}
220
221fn spawn_initial_requests_task<S>(
222    scheduler: Arc<Scheduler>,
223    spider: Arc<Mutex<S>>,
224    salvaged_requests_tx: AsyncSender<Request>,
225) -> tokio::task::JoinHandle<()>
226where
227    S: Spider + 'static,
228    S::Item: ScrapedItem,
229{
230    tokio::spawn(async move {
231        match spider.lock().await.start_requests() {
232            Ok(requests) => {
233                for mut req in requests {
234                    req.url.set_fragment(None);
235                    match scheduler.enqueue_request(req).await {
236                        Ok(_) => {}
237                        Err((req, e)) => {
238                            error!("Failed to enqueue initial request: {}", e);
239                            if salvaged_requests_tx.send(req).await.is_err() {
240                                error!("Failed to send salvaged request to channel.");
241                            }
242                        }
243                    }
244                }
245            }
246            Err(e) => error!("Failed to create start requests: {}", e),
247        }
248    })
249}
250
251async fn save_checkpoint<S: Spider>(
252    path: &Path,
253    mut scheduler_checkpoint: SchedulerCheckpoint,
254    pipelines: &Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
255    salvaged_requests_rx: AsyncReceiver<Request>,
256) -> Result<(), SpiderError>
257where
258    S::Item: ScrapedItem,
259{
260    info!("Saving checkpoint to {:?}", path);
261
262    let mut pipelines_checkpoint_map = HashMap::new();
263    for pipeline in pipelines.iter() {
264        if let Some(state) = pipeline.get_state().await? {
265            pipelines_checkpoint_map.insert(pipeline.name().to_string(), state);
266        }
267    }
268
269    let mut salvaged_requests_vec: VecDeque<Request> = VecDeque::new();
270    while let Ok(req_option) = salvaged_requests_rx.try_recv() {
271        if let Some(req) = req_option {
272            salvaged_requests_vec.push_back(req);
273        }
274    }
275    scheduler_checkpoint
276        .request_queue
277        .extend(salvaged_requests_vec.drain(..));
278    if !salvaged_requests_vec.is_empty() {
279        warn!(
280            "Found {} salvaged requests during checkpoint. These have been added to the request queue.",
281            salvaged_requests_vec.len()
282        );
283    }
284
285    let checkpoint = Checkpoint {
286        scheduler: scheduler_checkpoint,
287        pipelines: pipelines_checkpoint_map,
288    };
289
290    // Write to a temporary file then rename for atomicity
291    let tmp_path = path.with_extension("tmp");
292    let encoded = rmp_serde::to_vec(&checkpoint)
293        .map_err(|e| SpiderError::GeneralError(format!("Failed to serialize checkpoint: {}", e)))?;
294    fs::write(&tmp_path, encoded).map_err(|e| {
295        SpiderError::GeneralError(format!(
296            "Failed to write checkpoint to temporary file: {}",
297            e
298        ))
299    })?;
300    fs::rename(&tmp_path, path).map_err(|e| {
301        SpiderError::GeneralError(format!("Failed to rename temporary checkpoint file: {}", e))
302    })?;
303
304    info!("Checkpoint saved successfully.");
305    Ok(())
306}
307
308#[allow(clippy::too_many_arguments)]
309fn spawn_downloader_task<S, C>(
310    scheduler: Arc<Scheduler>,
311    req_rx: AsyncReceiver<Request>,
312    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
313    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
314    state: Arc<CrawlerState>,
315    res_tx: AsyncSender<Response>,
316    max_concurrent_downloads: usize,
317    salvaged_requests_tx: AsyncSender<Request>,
318) -> tokio::task::JoinHandle<()>
319where
320    S: Spider + 'static,
321    S::Item: ScrapedItem,
322    C: Send + Sync + 'static,
323{
324    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
325    let mut tasks = JoinSet::new();
326
327    tokio::spawn(async move {
328        while let Ok(request) = req_rx.recv().await {
329            let permit = match semaphore.clone().acquire_owned().await {
330                Ok(permit) => permit,
331                Err(_) => {
332                    warn!("Semaphore closed, shutting down downloader actor.");
333                    break;
334                }
335            };
336
337            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
338            let downloader_clone = Arc::clone(&downloader);
339            let middlewares_clone = Arc::clone(&middlewares);
340            let res_tx_clone = res_tx.clone();
341            let state_clone = Arc::clone(&state);
342            let scheduler_clone = Arc::clone(&scheduler);
343            let _salvaged_requests_tx_clone = salvaged_requests_tx.clone();
344
345            tasks.spawn(async move {
346                let mut processed_request = request;
347                let mut early_returned_response: Option<Response> = None;
348
349                // Process request middlewares
350                for mw in middlewares_clone.lock().await.iter_mut() {
351                    match mw.process_request(downloader_clone.client(), processed_request.clone()).await {
352                        Ok(MiddlewareAction::Continue(req)) => {
353                            processed_request = req;
354                        }
355                        Ok(MiddlewareAction::Retry(req, delay)) => {
356                            tokio::time::sleep(delay).await;
357                            if scheduler_clone.enqueue_request(*req).await.is_err() {
358                                error!("Failed to re-enqueue retried request.");
359                            }
360                            return;
361                        }
362                        Ok(MiddlewareAction::Drop) => {
363                            debug!("Request dropped by middleware.");
364                            return;
365                        }
366                        Ok(MiddlewareAction::ReturnResponse(resp)) => {
367                            early_returned_response = Some(resp);
368                            break;
369                        }
370                        Err(e) => {
371                            error!("Request middleware error: {:?}", e);
372                            return;
373                        }
374                    }
375                }
376
377                // Download or use early response
378                let mut response = match early_returned_response {
379                    Some(resp) => resp,
380                    None => match downloader_clone.download(processed_request).await {
381                        Ok(resp) => resp,
382                        Err(e) => {
383                            error!("Download error: {:?}", e);
384                            return;
385                        }
386                    },
387                };
388
389                // Process response middlewares
390                for mw in middlewares_clone.lock().await.iter_mut().rev() {
391                    match mw.process_response(response.clone()).await {
392                        Ok(MiddlewareAction::Continue(res)) => {
393                            response = res;
394                        }
395                        Ok(MiddlewareAction::Retry(req, delay)) => {
396                            tokio::time::sleep(delay).await;
397                            if scheduler_clone.enqueue_request(*req).await.is_err() {
398                                error!("Failed to re-enqueue retried request.");
399                            }
400                            return;
401                        }
402                        Ok(MiddlewareAction::Drop) => {
403                            debug!("Response dropped by middleware.");
404                            return;
405                        }
406                        Ok(MiddlewareAction::ReturnResponse(_)) => {
407                            debug!("ReturnResponse action encountered in process_response; this is unexpected.");
408                            continue;
409                        }
410                        Err(e) => {
411                            error!("Response middleware error: {:?}", e);
412                            return;
413                        }
414                    }
415                }
416
417                if res_tx_clone.send(response).await.is_err() {
418                    error!("Response channel closed, cannot send parsed response.");
419                }
420
421                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
422                drop(permit);
423            });
424        }
425        while let Some(res) = tasks.join_next().await {
426            if let Err(e) = res {
427                error!("A download task failed: {:?}", e);
428            }
429        }
430    })
431}
432
433fn spawn_parser_task<S>(
434    scheduler: Arc<Scheduler>,
435    spider: Arc<Mutex<S>>,
436    state: Arc<CrawlerState>,
437    res_rx: AsyncReceiver<Response>,
438    item_tx: AsyncSender<S::Item>,
439    parser_workers: usize,
440    salvaged_requests_tx: AsyncSender<Request>,
441) -> tokio::task::JoinHandle<()>
442where
443    S: Spider + 'static,
444    S::Item: ScrapedItem,
445{
446    let mut tasks = JoinSet::new();
447    let internal_parse_tx: AsyncSender<Response>;
448    let internal_parse_rx: AsyncReceiver<Response>;
449    (internal_parse_tx, internal_parse_rx) = bounded_async(parser_workers * 2);
450
451    // Spawn N parsing worker tasks
452    for _ in 0..parser_workers {
453        let internal_parse_rx_clone = internal_parse_rx.clone();
454        let spider_clone = Arc::clone(&spider);
455        let scheduler_clone = Arc::clone(&scheduler);
456        let item_tx_clone = item_tx.clone();
457        let state_clone = Arc::clone(&state);
458        let salvaged_requests_tx_clone = salvaged_requests_tx.clone();
459
460        tasks.spawn(async move {
461            while let Ok(response) = internal_parse_rx_clone.recv().await {
462                debug!("Parsing response from {}", response.url);
463                match spider_clone.lock().await.parse(response).await {
464                    Ok(outputs) => {
465                        process_crawl_outputs::<S>(
466                            outputs,
467                            scheduler_clone.clone(),
468                            item_tx_clone.clone(),
469                            state_clone.clone(),
470                            salvaged_requests_tx_clone.clone(),
471                        )
472                        .await;
473                    }
474                    Err(e) => error!("Spider parsing error: {:?}", e),
475                }
476                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
477            }
478        });
479    }
480
481    tokio::spawn(async move {
482        while let Ok(response) = res_rx.recv().await {
483            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
484            if internal_parse_tx.send(response).await.is_err() {
485                error!("Internal parse channel closed, cannot send response to parser worker.");
486                state.parsing_responses.fetch_sub(1, Ordering::SeqCst); // Decrement if send fails
487            }
488        }
489        // Drop the sender to signal worker tasks to shut down
490        drop(internal_parse_tx);
491
492        // Wait for all parsing tasks to complete
493        while let Some(res) = tasks.join_next().await {
494            if let Err(e) = res {
495                error!("A parsing worker task failed: {:?}", e);
496            }
497        }
498    })
499}
500
501async fn process_crawl_outputs<S>(
502    outputs: ParseOutput<S::Item>,
503    scheduler: Arc<Scheduler>,
504    item_tx: AsyncSender<S::Item>,
505    state: Arc<CrawlerState>,
506    salvaged_requests_tx: AsyncSender<Request>,
507) where
508    S: Spider + 'static,
509    S::Item: ScrapedItem,
510{
511    let (items, requests) = outputs.into_parts();
512    info!(
513        "Processed {} requests and {} items from spider output.",
514        requests.len(),
515        items.len()
516    );
517
518    let mut request_error_total = 0;
519    for request in requests {
520        match scheduler.enqueue_request(request).await {
521            Ok(_) => {}
522            Err((req, _)) => {
523                request_error_total += 1;
524                if salvaged_requests_tx.send(req).await.is_err() {
525                    error!(
526                        "Failed to send salvaged request to channel from process_crawl_outputs."
527                    );
528                }
529            }
530        }
531    }
532    if request_error_total > 0 {
533        error!(
534            "Failed to enqueue {} requests: scheduler channel closed.",
535            request_error_total
536        );
537    }
538
539    let mut item_error_total = 0;
540    for item in items {
541        state.processing_items.fetch_add(1, Ordering::SeqCst);
542        if item_tx.send(item).await.is_err() {
543            item_error_total += 1;
544            state.processing_items.fetch_sub(1, Ordering::SeqCst);
545        }
546    }
547    if item_error_total > 0 {
548        error!(
549            "Failed to send {} scraped items: channel closed.",
550            item_error_total
551        );
552    }
553}
554
555fn spawn_item_processor_task<S>(
556    state: Arc<CrawlerState>,
557    item_rx: AsyncReceiver<S::Item>,
558    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
559    max_concurrent_pipelines: usize,
560) -> tokio::task::JoinHandle<()>
561where
562    S: Spider + 'static,
563    S::Item: ScrapedItem,
564{
565    let mut tasks = JoinSet::new();
566    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
567    tokio::spawn(async move {
568        while let Ok(item) = item_rx.recv().await {
569            let permit = match semaphore.clone().acquire_owned().await {
570                Ok(p) => p,
571                Err(_) => {
572                    warn!("Semaphore closed, shutting down item processor actor.");
573                    break;
574                }
575            };
576
577            let state_clone = Arc::clone(&state);
578            let pipelines_clone = Arc::clone(&pipelines);
579            tasks.spawn(async move {
580                let mut item_to_process = Some(item);
581                for pipeline in pipelines_clone.iter() {
582                    if let Some(item) = item_to_process.take() {
583                        match pipeline.process_item(item).await {
584                            Ok(Some(next_item)) => item_to_process = Some(next_item),
585                            Ok(None) => break, // Item was dropped
586                            Err(e) => {
587                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
588                                break;
589                            }
590                        }
591                    } else {
592                        break;
593                    }
594                }
595                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
596                drop(permit);
597            });
598        }
599        while let Some(res) = tasks.join_next().await {
600            if let Err(e) = res {
601                error!("An item processing task failed: {:?}", e);
602            }
603        }
604    })
605}