// spider_core/scheduler.rs
1//! # Scheduler Module
2//!
3//! Implements the request scheduler for managing the crawling frontier and duplicate detection.
4//!
5//! ## Overview
6//!
7//! The `Scheduler` is a central component that coordinates the web crawling process
8//! by managing the queue of pending requests and tracking visited URLs to prevent
9//! duplicate processing. It uses an actor-like design pattern with internal message
10//! processing for thread-safe operations.
11//!
12//! ## Key Responsibilities
13//!
14//! - **Request Queue Management**: Maintains a queue of pending requests to be processed
15//! - **Duplicate Detection**: Tracks visited URLs using Bloom Filter and LRU cache for efficiency
16//! - **Request Salvaging**: Handles failed enqueuing attempts to prevent request loss
17//! - **State Snapshots**: Provides checkpointing capabilities for crawl resumption
18//! - **Concurrent Access**: Thread-safe operations for multi-threaded crawling
19//!
20//! ## Architecture
21//!
22//! The scheduler operates asynchronously using an internal message queue to handle
23//! operations like request enqueuing, URL marking, and state snapshots. It combines
24//! a Bloom Filter for fast preliminary duplicate checks with an LRU cache for
25//! definitive tracking, optimizing performance when handling millions of URLs.
26//!
27//! ## Example
28//!
29//! ```rust,ignore
30//! use spider_core::Scheduler;
31//! use spider_util::request::Request;
32//! use url::Url;
33//!
34//! let (scheduler, request_receiver) = Scheduler::new(None);
35//!
36//! // Enqueue a request
37//! let request = Request::new(Url::parse("https://example.com").unwrap());
38//! scheduler.enqueue_request(request).await?;
39//!
40//! // Mark a URL as visited
//! scheduler.mark_visited("unique_fingerprint".to_string()).await?;
42//! ```
43
#[cfg(feature = "checkpoint")]
use crate::SchedulerCheckpoint;
#[cfg(feature = "checkpoint")]
use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;

use crossbeam::queue::SegQueue;
use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
use log::{debug, error, info, trace, warn};
use moka::sync::Cache;
use spider_util::constants::{
    BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS, MAX_PENDING_REQUESTS,
    VISITED_URL_CACHE_CAPACITY, VISITED_URL_CACHE_TTL_SECS,
};
use spider_util::error::SpiderError;
use spider_util::request::Request;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
64
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add a request to the frontier queue (boxed to keep the variant small).
    Enqueue(Box<Request>),
    /// Record a single URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record several URL fingerprints as visited in one message.
    MarkAsVisitedBatch(Vec<String>),
    /// Flush pending state and terminate the run loop.
    Shutdown,
}

use spider_util::bloom::BloomFilter;

use tokio::sync::Notify;

/// Coordinates the crawl frontier: queues pending requests and tracks visited
/// URL fingerprints with a bloom filter plus an LRU cache.
pub struct Scheduler {
    /// Lock-free FIFO of requests waiting to be handed to the crawler.
    queue: SegQueue<Request>,
    /// Definitive recent-visit record (capacity/TTL bounded).
    visited: Cache<String, bool>,
    /// Probabilistic long-term visit record; fed from `buffer` on flush.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints marked visited but not yet flushed into `bloom`.
    buffer: Arc<std::sync::Mutex<Vec<String>>>,
    /// Wakes the background flush task when `buffer` grows large.
    notify: Arc<Notify>,
    /// Sender side of the internal actor-style message channel.
    tx: AsyncSender<SchedulerMessage>,
    /// Count of requests enqueued but not yet delivered to the crawler.
    pending: AtomicUsize,
    /// Requests rescued after a failed enqueue (internal channel closed).
    salvaged: SegQueue<Request>,
    /// Set once shutdown begins; silences errors and stops re-queuing.
    pub(crate) is_shutting_down: AtomicBool,
    /// Backpressure limit on `pending` before enqueues are rejected.
    max_pending: usize,
}

impl Scheduler {
    /// Creates a new `Scheduler`, optionally restored from a checkpoint, and
    /// returns it together with the receiver the crawler pulls requests from.
    ///
    /// Spawns two detached tasks: the buffer-flush loop (`flush_buffer`) and
    /// the actor-style message loop (`run_loop`).
    #[cfg(feature = "checkpoint")]
    pub fn new(
        initial_state: Option<SchedulerCheckpoint>,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        let (tx, rx_internal) = unbounded_async();
        // Bounded so a slow crawler applies backpressure to the run loop.
        let (tx_out, rx_out) = bounded_async(100);

        // SegQueue::push takes &self, so these can be filled in either branch.
        let queue = SegQueue::new();
        let salvaged = SegQueue::new();
        let visited: Cache<String, bool>;
        let pending: AtomicUsize;

        if let Some(state) = initial_state {
            info!(
                "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                state.request_queue.len(),
                state.visited_urls.len(),
                state.salvaged_requests.len(),
            );
            // Both queued and salvaged requests count as pending work.
            pending = AtomicUsize::new(state.request_queue.len() + state.salvaged_requests.len());
            for request in state.request_queue {
                queue.push(request);
            }

            visited = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {})
                .build();
            for url in state.visited_urls {
                visited.insert(url, true);
            }

            for request in state.salvaged_requests {
                salvaged.push(request);
            }
        } else {
            // NOTE(review): the fresh-start cache uses DEFAULT_VISITED_CACHE_SIZE
            // with no TTL, unlike the restored path above — confirm this
            // asymmetry is intentional.
            visited = Cache::builder().max_capacity(DEFAULT_VISITED_CACHE_SIZE).build();
            pending = AtomicUsize::new(0);
        }

        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            queue,
            visited,
            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            buffer: buffer.clone(),
            notify: notify.clone(),
            tx,
            pending,
            salvaged,
            is_shutting_down: AtomicBool::new(false),
            // Shared constant instead of a hard-coded 30000 so both feature
            // builds apply the same backpressure limit.
            max_pending: MAX_PENDING_REQUESTS,
        });

        let scheduler_bloom = Arc::clone(&scheduler);
        let buffer_clone = buffer.clone();
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
        });

        let scheduler_task = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_task.run_loop(rx_internal, tx_out).await;
        });

        (scheduler, rx_out)
    }

171    #[cfg(not(feature = "checkpoint"))]
172    pub fn new(
173        _initial_state: Option<()>,
174    ) -> (Arc<Self>, AsyncReceiver<Request>) {
175        let (tx, rx_internal) = unbounded_async();
176        let (tx_out, rx_out) = bounded_async(100);
177
178        let queue = SegQueue::new();
179        let visited = Cache::builder()
180            .max_capacity(VISITED_URL_CACHE_CAPACITY)
181            .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
182            .eviction_listener(|_key, _value, _cause| {})
183            .build();
184        let pending = AtomicUsize::new(0);
185        let salvaged = SegQueue::new();
186
187        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
188        let notify = Arc::new(Notify::new());
189
190        let scheduler = Arc::new(Scheduler {
191            queue,
192            visited,
193            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
194                BLOOM_FILTER_CAPACITY,
195                BLOOM_FILTER_HASH_FUNCTIONS,
196            ))),
197            buffer: buffer.clone(),
198            notify: notify.clone(),
199            tx,
200            pending,
201            salvaged,
202            is_shutting_down: AtomicBool::new(false),
203            max_pending: MAX_PENDING_REQUESTS,
204        });
205
206        let scheduler_bloom = Arc::clone(&scheduler);
207        let buffer_clone = buffer.clone();
208        let notify_clone = notify.clone();
209        tokio::spawn(async move {
210            scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
211        });
212
213        let scheduler_task = Arc::clone(&scheduler);
214        tokio::spawn(async move {
215            scheduler_task.run_loop(rx_internal, tx_out).await;
216        });
217
218        (scheduler, rx_out)
219    }
220
    /// Core actor loop: drains internal messages and feeds queued requests to
    /// the crawler until shutdown is requested or the internal channel closes.
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_out: AsyncSender<Request>,
    ) {
        info!(
            "Scheduler run_loop started with max pending: {}",
            self.max_pending
        );
        loop {
            // Drain any already-buffered internal message first so enqueues
            // and visited-marks are never starved by outgoing sends.
            if let Ok(Some(msg)) = rx_internal.try_recv() {
                trace!("Processing pending internal message");
                if !self.handle_message(Ok(msg)).await {
                    break;
                }
                continue;
            }

            // Only pop work when the crawler side can still receive it.
            let request = if !tx_out.is_closed() && !self.is_idle() {
                self.queue.pop()
            } else {
                None
            };

            if let Some(request) = request {
                trace!("Sending request to crawler: {}", request.url);
                // Race the outgoing send against new internal messages so a
                // full output channel does not block message processing.
                tokio::select! {
                    send_res = tx_out.send(request) => {
                        if send_res.is_err() {
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        } else {
                            trace!("Successfully sent request to crawler");
                        }
                        // Decremented even on send failure: the request is
                        // gone either way.
                        self.pending.fetch_sub(1, Ordering::SeqCst);
                    },
                    recv_res = rx_internal.recv() => {
                        // NOTE(review): if this branch wins, the in-flight send
                        // future is dropped, so the popped request appears to be
                        // lost and `pending` is not decremented — confirm whether
                        // that loss is acceptable or needs a re-push.
                        trace!("Received internal message while sending request");
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                        continue;
                    }
                }
            } else {
                // Nothing to send: park until the next internal message.
                trace!("No pending requests, waiting for internal message");
                if !self.handle_message(rx_internal.recv().await).await {
                    break;
                }
            }
        }
        info!(
            "Scheduler run_loop finished with {} pending requests remaining.",
            self.pending.load(Ordering::SeqCst)
        );
    }

277    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
278        match msg {
279            Ok(SchedulerMessage::Enqueue(boxed_request)) => {
280                let request = *boxed_request;
281                trace!("Enqueuing request: {}", request.url);
282                self.queue.push(request);
283                self.pending.fetch_add(1, Ordering::SeqCst);
284                true
285            }
286            Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
287                trace!("Marking URL fingerprint as visited: {}", fingerprint);
288
289                // Insert into visited cache first (clone needed for cache)
290                self.visited.insert(fingerprint.clone(), true);
291
292                // Log before moving fingerprint
293                debug!("Marked URL as visited: {}", fingerprint);
294
295                // Then move fingerprint into buffer (no clone needed)
296                {
297                    let mut buffer = self.buffer.lock().unwrap();
298                    buffer.push(fingerprint);
299                    if buffer.len() >= 100 {
300                        self.notify.notify_one();
301                    }
302                }
303
304                true
305            }
306            Ok(SchedulerMessage::MarkAsVisitedBatch(mut fingerprints)) => {
307                let count = fingerprints.len();
308                trace!("Marking {} URL fingerprints as visited in batch", count);
309                
310                // Insert all fingerprints into visited cache
311                for fingerprint in &fingerprints {
312                    self.visited.insert(fingerprint.clone(), true);
313                }
314
315                // Then extend buffer with the fingerprints (no clone needed)
316                {
317                    let mut buffer = self.buffer.lock().unwrap();
318                    buffer.append(&mut fingerprints);
319                    if buffer.len() >= 100 {
320                        self.notify.notify_one();
321                    }
322                }
323
324                debug!("Marked {} URLs as visited in batch", count);
325                true
326            }
327            Ok(SchedulerMessage::Shutdown) => {
328                info!("Scheduler received shutdown signal. Exiting run_loop.");
329                self.is_shutting_down.store(true, Ordering::SeqCst);
330                self.flush_buffer_now();
331                false
332            }
333            Err(_) => {
334                warn!("Scheduler internal message channel closed. Exiting run_loop.");
335                self.is_shutting_down.store(true, Ordering::SeqCst);
336                false
337            }
338        }
339    }
340
    /// Produces a serializable snapshot of the scheduler state for checkpointing.
    ///
    /// Both queues are drained and (unless shutdown is in progress) re-filled;
    /// concurrent consumers may briefly observe them empty while this runs.
    /// During shutdown the requests are intentionally NOT re-queued, so the
    /// snapshot becomes the sole owner of the remaining work.
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        // Copy every cached fingerprint. Entries already evicted from the
        // cache (but still present in the bloom filter) are not captured here.
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        // SegQueue has no non-destructive iterator, so drain it fully first.
        while let Some(request) = self.queue.pop() {
            temp_requests.push(request);
        }

        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            // Restore the request unless we are shutting down for good.
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.queue.push(request);
            }
        }

        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        // Same drain-and-restore dance for the salvaged queue.
        while let Some(request) = self.salvaged.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }

    /// No-op snapshot used when the `checkpoint` feature is disabled; keeps
    /// callers feature-agnostic.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }

389    pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
390        if !self.should_enqueue(&request) {
391            trace!("Request already visited, skipping: {}", request.url);
392            return Ok(());
393        }
394
395        let pending = self.pending.load(Ordering::SeqCst);
396        if pending >= self.max_pending {
397            warn!(
398                "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
399                self.max_pending, request.url
400            );
401            return Err(SpiderError::GeneralError(
402                "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
403            ));
404        }
405
406        trace!("Enqueuing request: {}", request.url);
407        if self
408            .tx
409            .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
410            .await
411            .is_err()
412        {
413            if !self.is_shutting_down.load(Ordering::SeqCst) {
414                error!(
415                    "Scheduler internal message channel is closed. Salvaging request: {}",
416                    request.url
417                );
418            }
419            self.salvaged.push(request);
420            return Err(SpiderError::GeneralError(
421                "Scheduler internal channel closed, request salvaged.".into(),
422            ));
423        }
424
425        trace!("Successfully enqueued request: {}", request.url);
426        Ok(())
427    }
428
429    pub async fn shutdown(&self) -> Result<(), SpiderError> {
430        self.is_shutting_down.store(true, Ordering::SeqCst);
431
432        if !self.tx.is_closed() {
433            self.tx
434                .send(SchedulerMessage::Shutdown)
435                .await
436                .map_err(|e| {
437                    SpiderError::GeneralError(format!(
438                        "Scheduler: Failed to send shutdown signal: {}",
439                        e
440                    ))
441                })
442        } else {
443            debug!("Scheduler internal channel already closed, skipping shutdown signal");
444            Ok(())
445        }
446    }
447
448    pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
449        trace!(
450            "Sending MarkAsVisited message for fingerprint: {}",
451            fingerprint
452        );
453        self.tx
454            .send(SchedulerMessage::MarkAsVisited(fingerprint))
455            .await
456            .map_err(|e| {
457                if !self.is_shutting_down.load(Ordering::SeqCst) {
458                    error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
459                }
460                SpiderError::GeneralError(format!(
461                    "Scheduler: Failed to send MarkAsVisited message: {}",
462                    e
463                ))
464            })
465    }
466
467    pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
468        if fingerprints.is_empty() {
469            return Ok(());
470        }
471
472        trace!(
473            "Sending MarkAsVisitedBatch message for {} fingerprints",
474            fingerprints.len()
475        );
476        self.tx
477            .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
478            .await
479            .map_err(|e| {
480                if !self.is_shutting_down.load(Ordering::SeqCst) {
481                    error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
482                }
483                SpiderError::GeneralError(format!(
484                    "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
485                    e
486                ))
487            })
488    }
489
490    pub fn is_visited(&self, fingerprint: &str) -> bool {
491        if !self.bloom.read().might_contain(fingerprint) {
492            return false;
493        }
494
495        {
496            let buffer = self.buffer.lock().unwrap();
497            if buffer.iter().any(|item| item == fingerprint) {
498                return true;
499            }
500        }
501
502        self.visited.contains_key(fingerprint)
503    }
504
505    fn flush_buffer_now(&self) {
506        let mut buffer = self.buffer.lock().unwrap();
507        if !buffer.is_empty() {
508            let items: Vec<String> = buffer.drain(..).collect();
509            drop(buffer);
510
511            let mut bloom = self.bloom.write();
512            for item in items {
513                bloom.add(&item);
514            }
515        }
516    }
517
518    async fn flush_buffer(
519        &self,
520        _buffer: Arc<std::sync::Mutex<Vec<String>>>,
521        notify: Arc<Notify>,
522    ) {
523        loop {
524            tokio::select! {
525                _ = notify.notified() => {
526                    self.flush_buffer_now();
527                }
528                _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
529                    self.flush_buffer_now();
530                }
531            }
532
533            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
534        }
535    }
536
537    pub fn should_enqueue(&self, request: &Request) -> bool {
538        let fingerprint = request.fingerprint();
539        !self.is_visited(&fingerprint)
540    }
541
    /// Number of requests enqueued but not yet delivered to the crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::SeqCst)
    }

    /// `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Alias for `is_empty`; the run loop treats an empty scheduler as idle.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
}