Skip to main content

spider_core/engine/
crawler.rs

//! The core Crawler implementation for the `spider-lib` framework.
//!
//! This module defines the `Crawler` struct, which acts as the central orchestrator
//! for the web scraping process. It ties together the scheduler, downloader,
//! middlewares, spiders, and item pipelines to execute a crawl. The crawler
//! manages the lifecycle of requests and items, handles concurrency, supports
//! checkpointing for fault tolerance, and collects statistics for monitoring.
//!
//! It utilizes a task-based asynchronous model, spawning distinct tasks for
//! handling initial requests, downloading web pages, parsing responses, and
//! processing scraped items.

13use crate::Downloader;
14use crate::config::CrawlerConfig;
15use crate::engine::CrawlerContext;
16use crate::scheduler::Scheduler;
17use crate::spider::Spider;
18use crate::state::CrawlerState;
19use crate::stats::StatCollector;
20use anyhow::Result;
21#[cfg(feature = "live-stats")]
22use crossterm::{
23    cursor::{Hide, MoveToColumn, MoveUp, Show},
24    execute, queue,
25    terminal::{Clear, ClearType, size},
26};
27use futures_util::future::join_all;
28use kanal::{AsyncReceiver, bounded_async};
29use log::{debug, error, info, trace, warn};
30use spider_middleware::middleware::Middleware;
31use spider_pipeline::pipeline::Pipeline;
32use spider_util::error::SpiderError;
33use spider_util::item::ScrapedItem;
34use spider_util::request::Request;
35
36#[cfg(feature = "checkpoint")]
37use crate::checkpoint::save_checkpoint;
38#[cfg(feature = "checkpoint")]
39use crate::config::CheckpointConfig;
40
41#[cfg(feature = "live-stats")]
42use std::io::{IsTerminal, Write};
43use std::sync::Arc;
44use std::time::Duration;
45
46#[cfg(feature = "cookie-store")]
47use tokio::sync::RwLock;
48#[cfg(feature = "live-stats")]
49use tokio::sync::oneshot;
50#[cfg(feature = "live-stats")]
51use tokio::time::MissedTickBehavior;
52
53#[cfg(feature = "cookie-store")]
54use cookie_store::CookieStore;
55
56/// The central orchestrator for the web scraping process, handling requests, responses, items, concurrency, checkpointing, and statistics collection.
57pub struct Crawler<S: Spider, C> {
58    scheduler: Arc<Scheduler>,
59    req_rx: AsyncReceiver<Request>,
60    stats: Arc<StatCollector>,
61    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
62    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
63    spider: Arc<S>,
64    spider_state: Arc<S::State>,
65    pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
66    config: CrawlerConfig,
67    #[cfg(feature = "checkpoint")]
68    checkpoint_config: CheckpointConfig,
69    #[cfg(feature = "cookie-store")]
70    pub cookie_store: Arc<RwLock<CookieStore>>,
71}
72
73impl<S, C> Crawler<S, C>
74where
75    S: Spider + 'static,
76    S::Item: ScrapedItem,
77    C: Send + Sync + Clone + 'static,
78{
79    #[allow(clippy::too_many_arguments)]
80    pub(crate) fn new(
81        scheduler: Arc<Scheduler>,
82        req_rx: AsyncReceiver<Request>,
83        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
84        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
85        spider: S,
86        pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
87        config: CrawlerConfig,
88        #[cfg(feature = "checkpoint")] checkpoint_config: CheckpointConfig,
89        stats: Arc<StatCollector>,
90        #[cfg(feature = "cookie-store")] cookie_store: Arc<tokio::sync::RwLock<CookieStore>>,
91    ) -> Self {
92        Crawler {
93            scheduler,
94            req_rx,
95            stats,
96            downloader,
97            middlewares,
98            spider: Arc::new(spider),
99            spider_state: Arc::new(S::State::default()),
100            pipelines,
101            config,
102            #[cfg(feature = "checkpoint")]
103            checkpoint_config,
104            #[cfg(feature = "cookie-store")]
105            cookie_store,
106        }
107    }
108
109    pub async fn start_crawl(self) -> Result<(), SpiderError> {
110        info!(
111            "Crawler starting crawl with configuration: max_concurrent_downloads={}, parser_workers={}, max_concurrent_pipelines={}",
112            self.config.max_concurrent_downloads,
113            self.config.parser_workers,
114            self.config.max_concurrent_pipelines
115        );
116
117        let Crawler {
118            scheduler,
119            req_rx,
120            stats,
121            downloader,
122            middlewares,
123            spider,
124            spider_state,
125            pipelines,
126            config,
127            #[cfg(feature = "checkpoint")]
128            checkpoint_config,
129            #[cfg(feature = "cookie-store")]
130                cookie_store: _,
131        } = self;
132
133        let state = CrawlerState::new();
134        let pipelines = Arc::new(pipelines);
135
136        // Create aggregated context for efficient sharing across tasks
137        let ctx = CrawlerContext::new(
138            Arc::clone(&scheduler),
139            Arc::clone(&stats),
140            Arc::clone(&spider),
141            Arc::clone(&spider_state),
142            Arc::clone(&pipelines),
143        );
144
145        let channel_capacity = std::cmp::max(
146            config.max_concurrent_downloads * 3,
147            config.parser_workers * config.max_concurrent_pipelines * 2,
148        )
149        .max(config.channel_capacity);
150
151        trace!(
152            "Creating communication channels with capacity: {}",
153            channel_capacity
154        );
155        let (res_tx, res_rx) = bounded_async(channel_capacity);
156        let (item_tx, item_rx) = bounded_async(channel_capacity);
157
158        trace!("Spawning initial requests task");
159        let init_task = spawn_init_task(ctx.clone());
160
161        trace!("Initializing middleware manager");
162        let middlewares = super::SharedMiddlewareManager::new(middlewares);
163
164        trace!("Spawning downloader task");
165        let downloader_handle = super::spawn_downloader_task::<S, C>(
166            Arc::clone(&ctx.scheduler),
167            req_rx,
168            downloader,
169            middlewares,
170            state.clone(),
171            res_tx.clone(),
172            config.max_concurrent_downloads,
173            Arc::clone(&ctx.stats),
174        );
175
176        trace!("Spawning parser task");
177        let parser_handle = super::spawn_parser_task::<S>(
178            Arc::clone(&ctx.scheduler),
179            Arc::clone(&ctx.spider),
180            Arc::clone(&ctx.spider_state),
181            state.clone(),
182            res_rx,
183            item_tx.clone(),
184            config.parser_workers,
185            Arc::clone(&ctx.stats),
186        );
187
188        trace!("Spawning item processor task");
189        let processor_handle = super::spawn_item_processor_task::<S>(
190            state.clone(),
191            item_rx,
192            Arc::clone(&ctx.pipelines),
193            config.max_concurrent_pipelines,
194            Arc::clone(&ctx.stats),
195        );
196
197        #[cfg(feature = "live-stats")]
198        let mut live_stats_task: Option<(
199            oneshot::Sender<()>,
200            tokio::task::JoinHandle<()>,
201        )> = if config.live_stats && std::io::stdout().is_terminal() {
202            let (stop_tx, stop_rx) = oneshot::channel();
203            let stats_for_live = Arc::clone(&ctx.stats);
204            let interval = config.live_stats_interval;
205            let handle = tokio::spawn(async move {
206                run_live_stats(stats_for_live, interval, stop_rx).await;
207            });
208            Some((stop_tx, handle))
209        } else {
210            None
211        };
212        #[cfg(not(feature = "live-stats"))]
213        let mut live_stats_task: Option<((), ())> = None;
214
215        #[cfg(feature = "checkpoint")]
216        {
217            if let (Some(path), Some(interval)) =
218                (&checkpoint_config.path, checkpoint_config.interval)
219            {
220                let scheduler_cp = Arc::clone(&ctx.scheduler);
221                let pipelines_cp = Arc::clone(&ctx.pipelines);
222                let path_cp = path.clone();
223
224                #[cfg(feature = "cookie-store")]
225                let cookie_store_cp = self.cookie_store.clone();
226
227                #[cfg(not(feature = "cookie-store"))]
228                let _cookie_store_cp = ();
229
230                trace!(
231                    "Starting periodic checkpoint task with interval: {:?}",
232                    interval
233                );
234                tokio::spawn(async move {
235                    let mut interval_timer = tokio::time::interval(interval);
236                    interval_timer.tick().await;
237                    loop {
238                        tokio::select! {
239                            _ = interval_timer.tick() => {
240                                trace!("Checkpoint timer ticked, creating snapshot");
241                                if let Ok(scheduler_checkpoint) = scheduler_cp.snapshot().await {
242                                    debug!("Scheduler snapshot created, saving checkpoint to {:?}", path_cp);
243
244                                    #[cfg(feature = "cookie-store")]
245                                    let save_result = save_checkpoint::<S>(&path_cp, scheduler_checkpoint, &pipelines_cp, &cookie_store_cp).await;
246
247                                    #[cfg(not(feature = "cookie-store"))]
248                                    let save_result = save_checkpoint::<S>(&path_cp, scheduler_checkpoint, &pipelines_cp, &()).await;
249
250                                    if let Err(e) = save_result {
251                                        error!("Periodic checkpoint save failed: {}", e);
252                                    } else {
253                                        debug!("Periodic checkpoint saved successfully to {:?}", path_cp);
254                                    }
255                                } else {
256                                    warn!("Failed to create scheduler snapshot for checkpoint");
257                                }
258                            }
259                        }
260                    }
261                });
262            }
263        }
264
265        tokio::select! {
266            _ = tokio::signal::ctrl_c() => {
267                info!("Ctrl-C received, initiating graceful shutdown.");
268            }
269            _ = async {
270                loop {
271                    if scheduler.is_idle() && state.is_idle() {
272                        tokio::time::sleep(Duration::from_millis(50)).await;
273                        if scheduler.is_idle() && state.is_idle() {
274                            break;
275                        }
276                    }
277                    tokio::time::sleep(Duration::from_millis(100)).await;
278                }
279            } => {
280                info!("Crawl has become idle, initiating shutdown.");
281            }
282        };
283
284        trace!("Closing communication channels");
285        drop(res_tx);
286        drop(item_tx);
287
288        if let Err(e) = scheduler.shutdown().await {
289            error!("Error during scheduler shutdown: {}", e);
290        } else {
291            debug!("Scheduler shutdown initiated successfully");
292        }
293
294        let timeout_duration = Duration::from_secs(30);
295
296        let mut tasks = tokio::task::JoinSet::new();
297        tasks.spawn(processor_handle);
298        tasks.spawn(parser_handle);
299        tasks.spawn(downloader_handle);
300        tasks.spawn(init_task);
301
302        let results = tokio::time::timeout(timeout_duration, async {
303            let mut results = Vec::new();
304            while let Some(result) = tasks.join_next().await {
305                results.push(result);
306            }
307            results
308        })
309        .await;
310
311        let results = match results {
312            Ok(results) => {
313                trace!("All tasks completed during shutdown");
314                results
315            }
316            Err(_) => {
317                warn!(
318                    "Tasks did not complete within timeout ({}s), aborting remaining tasks and continuing with shutdown...",
319                    timeout_duration.as_secs()
320                );
321                tasks.abort_all();
322
323                tokio::time::sleep(Duration::from_millis(100)).await;
324
325                Vec::new()
326            }
327        };
328
329        for result in results {
330            if let Err(e) = result {
331                error!("Task failed during shutdown: {}", e);
332            } else {
333                trace!("Task completed successfully during shutdown");
334            }
335        }
336
337        #[cfg(feature = "live-stats")]
338        if let Some((stop_tx, handle)) = live_stats_task.take() {
339            let _ = stop_tx.send(());
340            let _ = handle.await;
341        }
342        #[cfg(not(feature = "live-stats"))]
343        let _ = live_stats_task.take();
344
345        #[cfg(feature = "checkpoint")]
346        {
347            if let Some(path) = &checkpoint_config.path {
348                debug!("Creating final checkpoint at {:?}", path);
349                let scheduler_checkpoint = scheduler.snapshot().await?;
350
351                #[cfg(feature = "cookie-store")]
352                let result = save_checkpoint::<S>(
353                    path,
354                    scheduler_checkpoint,
355                    &pipelines,
356                    &self.cookie_store,
357                )
358                .await;
359
360                #[cfg(not(feature = "cookie-store"))]
361                let result =
362                    save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines, &()).await;
363
364                if let Err(e) = result {
365                    error!("Final checkpoint save failed: {}", e);
366                } else {
367                    info!("Final checkpoint saved successfully to {:?}", path);
368                }
369            }
370        }
371
372        info!("Closing item pipelines...");
373        let futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
374        join_all(futures).await;
375        debug!("All item pipelines closed");
376
377        if config.live_stats {
378            println!("{}\n", stats.to_live_report_string());
379        } else {
380            info!("Crawl finished successfully\n{}", stats);
381        }
382        Ok(())
383    }
384
385    /// Returns a shared handle to crawler runtime statistics.
386    pub fn stats(&self) -> Arc<StatCollector> {
387        Arc::clone(&self.stats)
388    }
389
390    /// Returns a reference to the spider state.
391    pub fn state(&self) -> &S::State {
392        &self.spider_state
393    }
394
395    /// Returns an Arc clone of the spider state.
396    pub fn state_arc(&self) -> Arc<S::State> {
397        Arc::clone(&self.spider_state)
398    }
399}
400
401fn spawn_init_task<S, I>(ctx: CrawlerContext<S, I>) -> tokio::task::JoinHandle<()>
402where
403    S: Spider<Item = I> + 'static,
404    I: ScrapedItem,
405{
406    tokio::spawn(async move {
407        match ctx.spider.start_requests() {
408            Ok(requests) => {
409                for mut req in requests {
410                    req.url.set_fragment(None);
411                    match ctx.scheduler.enqueue_request(req).await {
412                        Ok(_) => {
413                            ctx.stats.increment_requests_enqueued();
414                        }
415                        Err(e) => {
416                            error!("Failed to enqueue initial request: {}", e);
417                        }
418                    }
419                }
420            }
421            Err(e) => error!("Failed to create start requests: {}", e),
422        }
423    })
424}
425
426#[cfg(feature = "live-stats")]
427struct LiveStatsRenderer {
428    previous_lines: Vec<String>,
429}
430
#[cfg(feature = "live-stats")]
impl LiveStatsRenderer {
    /// Hides the cursor and prints one blank line for the first `render`
    /// call to start overwriting. Terminal errors are deliberately ignored.
    fn new() -> Self {
        let mut out = std::io::stdout();
        let _ = execute!(out, Hide);
        let _ = writeln!(out);
        let _ = out.flush();
        Self {
            previous_lines: Vec::new(),
        }
    }

    /// Redraws the stats block in place: moves the cursor back over the
    /// previously printed lines, clears and rewrites each one. Lines are
    /// pre-trimmed to the terminal width so wrapping cannot break the
    /// cursor arithmetic.
    ///
    /// NOTE(review): after rendering, the cursor sits on the last of
    /// `max(previous, next)` lines while only `next` lines are remembered;
    /// if the report ever shrinks between ticks the next move-up may fall
    /// short. Appears safe if the report has a fixed line count — confirm.
    fn render(&mut self, content: &str) {
        let mut out = std::io::stdout();
        let terminal_width = Self::terminal_width();
        let next_lines: Vec<String> = content
            .lines()
            .map(|line| Self::trim_to_width(line, terminal_width))
            .collect();
        let previous_len = self.previous_lines.len();
        let next_len = next_lines.len();
        let max_len = previous_len.max(next_len);

        // Cursor currently rests on the last previously printed line; move
        // up to the first one (MoveUp(0) would be a no-op anyway).
        if previous_len > 1 {
            let _ = queue!(out, MoveUp((previous_len - 1) as u16));
        }
        let _ = queue!(out, MoveToColumn(0));

        for line_idx in 0..max_len {
            let _ = queue!(out, MoveToColumn(0), Clear(ClearType::CurrentLine));

            // Indices past `next_len` only clear stale lines from the
            // previous, longer frame.
            if let Some(line) = next_lines.get(line_idx) {
                let _ = write!(out, "{}", line);
            }

            // No trailing newline after the last line: the cursor parks there
            // so the next render can move up over the block.
            if line_idx + 1 < max_len {
                let _ = writeln!(out);
            }
        }

        let _ = out.flush();
        self.previous_lines = next_lines;
    }

    /// Current terminal width in columns; `usize::MAX` is the "unbounded"
    /// sentinel used when the size cannot be queried.
    fn terminal_width() -> usize {
        size()
            .map(|(width, _)| usize::from(width.max(1)))
            .unwrap_or(usize::MAX)
    }

    /// Truncates `line` to at most `width` characters.
    /// NOTE(review): counts `char`s, not display columns, so wide glyphs
    /// could still wrap — presumably fine for ASCII stats output.
    fn trim_to_width(line: &str, width: usize) -> String {
        if width == usize::MAX {
            return line.to_owned();
        }
        line.chars().take(width).collect()
    }

    /// Erases the rendered block and restores the cursor visibility.
    fn finish(self) {
        let mut out = std::io::stdout();
        self.clear_previous(&mut out);
        let _ = execute!(out, MoveToColumn(0), Clear(ClearType::CurrentLine), Show);
        let _ = out.flush();
    }

    /// Clears every previously rendered line and returns the cursor to the
    /// top of the (now blank) block.
    fn clear_previous(&self, out: &mut std::io::Stdout) {
        if self.previous_lines.is_empty() {
            return;
        }
        let lines = self.previous_lines.len();
        let _ = queue!(out, MoveToColumn(0));
        if lines > 1 {
            let _ = queue!(out, MoveUp((lines - 1) as u16));
        }
        for line_idx in 0..lines {
            let _ = queue!(out, MoveToColumn(0), Clear(ClearType::CurrentLine));
            if line_idx + 1 < lines {
                let _ = writeln!(out);
            }
        }
        // Clearing walked the cursor back down; move up to the block's top.
        if lines > 1 {
            let _ = queue!(out, MoveUp((lines - 1) as u16));
        }
    }
}
515
/// Periodically redraws live crawl statistics until the stop signal fires,
/// then tears the display down.
#[cfg(feature = "live-stats")]
async fn run_live_stats(
    stats: Arc<StatCollector>,
    interval: Duration,
    mut stop_rx: oneshot::Receiver<()>,
) {
    let mut display = LiveStatsRenderer::new();
    let mut timer = tokio::time::interval(interval);
    // If a tick is missed (slow terminal, busy runtime), skip it rather than
    // bursting to catch up.
    timer.set_missed_tick_behavior(MissedTickBehavior::Skip);

    let mut running = true;
    while running {
        tokio::select! {
            _ = timer.tick() => display.render(&stats.to_live_report_string()),
            _ = &mut stop_rx => running = false,
        }
    }

    display.finish();
}