1#[cfg(feature = "checkpoint")]
26use crate::SchedulerCheckpoint;
27#[cfg(feature = "checkpoint")]
28use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
29
30use crossbeam::queue::SegQueue;
31use kanal::{AsyncReceiver, AsyncSender, bounded_async};
32use log::{debug, error, info, trace, warn};
33use moka::sync::Cache;
34use spider_util::constants::{
35 BLOOM_BUFFER_FLUSH_SIZE, BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
36 BLOOM_FLUSH_INTERVAL_MS, MAX_PENDING_REQUESTS, VISITED_URL_CACHE_CAPACITY,
37 VISITED_URL_CACHE_TTL_SECS,
38};
39use spider_util::error::SpiderError;
40use spider_util::request::Request;
41use std::cmp::Ordering as CmpOrdering;
42use std::collections::{BinaryHeap, HashSet};
43use std::sync::Arc;
44use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
45
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add a single request to the priority queue (consumes one `pending` slot).
    Enqueue(Arc<Request>),
    /// Add several requests at once (consumes `pending` slots for the whole batch).
    EnqueueBatch(Vec<Arc<Request>>),
    /// Put a request back on the queue without touching the `pending` counter.
    Requeue(Arc<Request>),
    /// Record one URL fingerprint in the visited cache and bloom staging buffer.
    MarkAsVisited(String),
    /// Record many URL fingerprints in a single message.
    MarkAsVisitedBatch(Vec<String>),
    /// Ask the run loop to flush the bloom buffer and exit.
    Shutdown,
}
61
62use spider_util::bloom::BloomFilter;
63
64use parking_lot::Mutex;
65use tokio::sync::Notify;
66
/// A request plus the metadata needed to order it inside the `BinaryHeap`:
/// its priority (snapshotted at insertion) and a monotonically increasing
/// sequence number used as a FIFO tie-breaker within one priority level.
#[derive(Debug, Clone)]
struct ScheduledRequest {
    request: Request,
    // Snapshot of `request.priority()` taken at construction time.
    priority: i32,
    // Lower sequence == inserted earlier; see the `Ord` impl.
    sequence: u64,
}
73
74impl ScheduledRequest {
75 fn new(request: Request, sequence: u64) -> Self {
76 Self {
77 priority: request.priority(),
78 request,
79 sequence,
80 }
81 }
82}
83
84impl PartialEq for ScheduledRequest {
85 fn eq(&self, other: &Self) -> bool {
86 self.priority == other.priority && self.sequence == other.sequence
87 }
88}
89
90impl Eq for ScheduledRequest {}
91
92impl PartialOrd for ScheduledRequest {
93 fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
94 Some(self.cmp(other))
95 }
96}
97
98impl Ord for ScheduledRequest {
99 fn cmp(&self, other: &Self) -> CmpOrdering {
100 self.priority
101 .cmp(&other.priority)
102 .then_with(|| other.sequence.cmp(&self.sequence))
103 }
104}
105
/// Priority-based request scheduler with visited-URL deduplication.
///
/// Requests flow in through an internal kanal channel, are ordered in a
/// binary heap, and are handed to the crawler through a bounded output
/// channel. Deduplication uses a staging buffer that is periodically flushed
/// into a bloom filter, backed by a TTL-bounded visited cache.
pub struct Scheduler {
    /// Max-heap of pending requests, ordered by priority then FIFO sequence.
    queue: Mutex<BinaryHeap<ScheduledRequest>>,
    /// TTL/capacity-bounded cache of visited URL fingerprints.
    visited: Cache<String, bool>,
    /// Bloom filter fed from `buffer` by the background flush task.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints staged for the next bloom-filter flush.
    buffer: Arc<Mutex<HashSet<String>>>,
    /// Wakes the bloom-flush task early when the buffer grows large.
    notify: Arc<Notify>,
    /// Wakes producers blocked on the `max_pending` capacity limit.
    capacity_notify: Arc<Notify>,
    /// Sender for internal `SchedulerMessage`s processed by `run_loop`.
    tx: AsyncSender<SchedulerMessage>,
    /// Requests enqueued but not yet reported done via `complete_request`.
    pending: AtomicUsize,
    /// Requests rescued when the internal channel was closed mid-send;
    /// persisted by `snapshot` when the `checkpoint` feature is enabled.
    salvaged: SegQueue<Request>,
    /// Monotonic counter supplying `ScheduledRequest::sequence` values.
    sequence: AtomicU64,
    /// Set once shutdown begins; checked by producers and the run loop.
    pub(crate) is_shutting_down: AtomicBool,
    /// Upper bound on `pending` before enqueue operations block.
    max_pending: usize,
}
153
154impl Scheduler {
    /// Creates a scheduler and returns it together with the receiver the
    /// crawler consumes requests from.
    ///
    /// `max_pending_requests` is clamped to `1..=MAX_PENDING_REQUESTS` and
    /// bounds how many requests may be in flight before producers block.
    /// With the `checkpoint` feature enabled, `initial_state` restores the
    /// request queue, visited set and salvaged requests from a prior run.
    ///
    /// Spawns two background tasks: the bloom-buffer flusher and the main
    /// `run_loop` feeding the returned receiver.
    pub fn new(
        #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
        #[cfg(not(feature = "checkpoint"))] _initial_state: Option<()>,
        max_pending_requests: usize,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        let max_pending = max_pending_requests.clamp(1, MAX_PENDING_REQUESTS);
        // Internal channel gets 2x headroom over the pending cap so control
        // messages rarely block producers.
        let (tx, rx_internal) = bounded_async(max_pending.saturating_mul(2).max(1));
        // Output channel is kept small to maintain tight backpressure.
        let output_capacity = (max_pending / 8).clamp(256, 2048);
        let (tx_out, rx_out) = bounded_async(output_capacity);

        let queue: Mutex<BinaryHeap<ScheduledRequest>>;
        let visited: Cache<String, bool>;
        let pending: AtomicUsize;
        let salvaged: SegQueue<Request>;
        let sequence: AtomicU64;

        #[cfg(feature = "checkpoint")]
        {
            if let Some(state) = initial_state {
                info!(
                    "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                    state.request_queue.len(),
                    state.visited_urls.len(),
                    state.salvaged_requests.len(),
                );
                // Both queued and salvaged requests count as outstanding work.
                let pend = state.request_queue.len() + state.salvaged_requests.len();
                let mut restored_queue = BinaryHeap::with_capacity(state.request_queue.len());
                let mut next_sequence = 0_u64;
                for request in state.request_queue {
                    restored_queue.push(ScheduledRequest::new(request, next_sequence));
                    next_sequence += 1;
                }
                queue = Mutex::new(restored_queue);

                visited = Cache::builder()
                    .max_capacity(VISITED_URL_CACHE_CAPACITY)
                    .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                    .eviction_listener(|_key, _value, _cause| {})
                    .build();
                // NOTE(review): restored URLs enter the cache only — not the
                // bloom filter or its staging buffer — so bloom-based lookups
                // will not see them; confirm deduplication still covers them.
                for url in state.visited_urls {
                    visited.insert(url, true);
                }

                pending = AtomicUsize::new(pend);
                salvaged = SegQueue::new();
                for request in state.salvaged_requests {
                    salvaged.push(request);
                }
                // Continue numbering after the restored entries so FIFO order
                // within a priority level is preserved across restarts.
                sequence = AtomicU64::new(next_sequence);
            } else {
                queue = Mutex::new(BinaryHeap::new());
                // NOTE(review): this path uses DEFAULT_VISITED_CACHE_SIZE and
                // no TTL, unlike the other two construction paths — confirm
                // the asymmetry is intentional.
                visited = Cache::builder()
                    .max_capacity(DEFAULT_VISITED_CACHE_SIZE)
                    .build();
                pending = AtomicUsize::new(0);
                salvaged = SegQueue::new();
                sequence = AtomicU64::new(0);
            }
        }

        #[cfg(not(feature = "checkpoint"))]
        {
            queue = Mutex::new(BinaryHeap::new());
            visited = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {})
                .build();
            pending = AtomicUsize::new(0);
            salvaged = SegQueue::new();
            sequence = AtomicU64::new(0);
        }

        let buffer = Arc::new(Mutex::new(HashSet::new()));
        let notify = Arc::new(Notify::new());
        let capacity_notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            queue,
            visited,
            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            buffer: buffer.clone(),
            notify: notify.clone(),
            capacity_notify: Arc::clone(&capacity_notify),
            tx,
            pending,
            salvaged,
            sequence,
            is_shutting_down: AtomicBool::new(false),
            max_pending,
        });

        // Background task: drain the fingerprint buffer into the bloom filter.
        let scheduler_bloom = Arc::clone(&scheduler);
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            scheduler_bloom.flush_buffer(notify_clone).await;
        });

        // Background task: main message-processing / dispatch loop.
        let scheduler_task = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_task.run_loop(rx_internal, tx_out).await;
        });

        (scheduler, rx_out)
    }
291
    /// Number of requests currently counted as in flight (enqueued but not
    /// yet released via `complete_request`).
    pub fn pending_count(&self) -> usize {
        // Acquire pairs with the AcqRel updates performed on `pending`.
        self.pending.load(Ordering::Acquire)
    }
296
    /// Main dispatch loop: drains internal control messages and feeds queued
    /// requests to the crawler via `tx_out`.
    ///
    /// Control messages are biased (checked with `try_recv` every iteration)
    /// so bookkeeping never starves behind request hand-off. Exits when
    /// `handle_message` returns `false` (shutdown or closed channel).
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_out: AsyncSender<Request>,
    ) {
        info!(
            "Scheduler run_loop started with max pending: {}",
            self.max_pending
        );
        loop {
            // Drain any ready internal message before dispatching work.
            if let Ok(Some(msg)) = rx_internal.try_recv() {
                trace!("Processing pending internal message");
                if !self.handle_message(Ok(msg)).await {
                    break;
                }
                continue;
            }

            // Only pop when the consumer is alive and work is pending.
            let request = if !tx_out.is_closed() && !self.is_idle() {
                self.pop_request()
            } else {
                None
            };

            if let Some(request) = request {
                trace!("Sending request to crawler: {}", request.url);
                // NOTE(review): if the recv branch wins this select, the
                // in-progress send future is dropped together with the popped
                // request — verify kanal's send future preserves the message
                // on cancellation, otherwise that request is silently lost.
                tokio::select! {
                    send_res = tx_out.send(request) => {
                        if send_res.is_err() {
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        } else {
                            trace!("Successfully sent request to crawler");
                        }
                    },
                    recv_res = rx_internal.recv() => {
                        trace!("Received internal message while sending request");
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                        continue;
                    }
                }
            } else {
                // Nothing to dispatch: park until the next control message.
                trace!("No pending requests, waiting for internal message");
                if !self.handle_message(rx_internal.recv().await).await {
                    break;
                }
            }
        }
        info!(
            "Scheduler run_loop finished with {} pending requests remaining.",
            self.pending.load(Ordering::SeqCst)
        );
    }
351
    /// Applies one internal control message to scheduler state.
    ///
    /// Returns `true` to keep `run_loop` running; `false` on `Shutdown` or
    /// when the channel is closed, which makes the loop exit.
    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
        match msg {
            Ok(SchedulerMessage::Enqueue(arc_request)) => {
                // Avoids a clone when this is the last Arc reference.
                let request = Arc::unwrap_or_clone(arc_request);
                trace!("Enqueuing request: {}", request.url);
                self.push_request(request);
                // A fresh enqueue consumes one capacity slot; it is released
                // later by `complete_request`.
                self.pending.fetch_add(1, Ordering::AcqRel);
                true
            }
            Ok(SchedulerMessage::EnqueueBatch(requests)) => {
                let count = requests.len();
                for request in requests {
                    self.push_request(Arc::unwrap_or_clone(request));
                }
                self.pending.fetch_add(count, Ordering::AcqRel);
                true
            }
            Ok(SchedulerMessage::Requeue(arc_request)) => {
                // Deliberately no `pending` increment: the slot claimed by the
                // original enqueue is still considered held (see
                // `requeue_request`).
                let request = Arc::unwrap_or_clone(arc_request);
                trace!("Re-enqueuing request: {}", request.url);
                self.push_request(request);
                true
            }
            Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
                trace!("Marking URL fingerprint as visited: {}", fingerprint);

                // The cache is updated immediately...
                self.visited.insert(fingerprint.clone(), true);

                debug!("Marked URL as visited: {}", fingerprint);

                // ...while the bloom filter is fed lazily through the staging
                // buffer; the flush task is nudged once the buffer is full.
                {
                    let mut buffer = self.buffer.lock();
                    buffer.insert(fingerprint);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        self.notify.notify_one();
                    }
                }

                true
            }
            Ok(SchedulerMessage::MarkAsVisitedBatch(fingerprints)) => {
                let count = fingerprints.len();
                trace!("Marking {} URL fingerprints as visited in batch", count);

                for fingerprint in &fingerprints {
                    self.visited.insert(fingerprint.clone(), true);
                }

                {
                    let mut buffer = self.buffer.lock();
                    buffer.extend(fingerprints);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        self.notify.notify_one();
                    }
                }

                debug!("Marked {} URLs as visited in batch", count);
                true
            }
            Ok(SchedulerMessage::Shutdown) => {
                info!("Scheduler received shutdown signal. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                // Persist any staged fingerprints before the loop exits.
                self.flush_buffer_now();
                false
            }
            Err(_) => {
                warn!("Scheduler internal message channel closed. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                false
            }
        }
    }
430
    /// Captures a restorable snapshot of the queue, visited set and salvaged
    /// requests.
    ///
    /// While the scheduler is running, the drained queue and salvage entries
    /// are pushed back so the snapshot is non-destructive; during shutdown
    /// they are intentionally left drained (the checkpoint becomes the final
    /// hand-off).
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        let mut request_queue = std::collections::VecDeque::new();
        let mut queue = self.queue.lock();
        let mut temp_requests = Vec::with_capacity(queue.len());

        // Drain the heap so requests are recorded in priority order.
        while let Some(scheduled) = queue.pop() {
            request_queue.push_back(scheduled.request.clone());
            temp_requests.push(scheduled);
        }

        // Restore the heap unless we are shutting down.
        if !self.is_shutting_down.load(Ordering::SeqCst) {
            for scheduled in temp_requests {
                queue.push(scheduled);
            }
        }
        drop(queue);

        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        // SegQueue has no iterator; pop everything, then push back (in the
        // original order) if the scheduler keeps running.
        while let Some(request) = self.salvaged.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
475
    /// No-op snapshot used when the `checkpoint` feature is disabled; always
    /// succeeds so callers can invoke it unconditionally.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
480
481 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
482 if !self.should_enqueue(&request) {
483 trace!("Request already visited, skipping: {}", request.url);
484 return Ok(());
485 }
486
487 loop {
488 let pending = self.pending.load(Ordering::SeqCst);
489 if pending < self.max_pending {
490 break;
491 }
492
493 if self.is_shutting_down.load(Ordering::SeqCst) {
494 return Err(SpiderError::GeneralError(
495 "Scheduler is shutting down.".into(),
496 ));
497 }
498
499 trace!(
500 "Scheduler capacity reached ({} pending), waiting to enqueue: {}",
501 self.max_pending, request.url
502 );
503 self.capacity_notify.notified().await;
504 }
505
506 trace!("Enqueuing request: {}", request.url);
507 let request_arc = Arc::new(request);
508 if self
509 .tx
510 .send(SchedulerMessage::Enqueue(Arc::clone(&request_arc)))
511 .await
512 .is_err()
513 {
514 if !self.is_shutting_down.load(Ordering::SeqCst) {
515 error!(
516 "Scheduler internal message channel is closed. Salvaging request: {}",
517 request_arc.url
518 );
519 }
520 let salvaged_request =
521 Arc::try_unwrap(request_arc).unwrap_or_else(|shared| shared.as_ref().clone());
522 self.salvaged.push(salvaged_request);
523 return Err(SpiderError::GeneralError(
524 "Scheduler internal channel closed, request salvaged.".into(),
525 ));
526 }
527
528 trace!("Successfully enqueued request: {}", request_arc.url);
529 Ok(())
530 }
531
532 pub async fn enqueue_requests_batch(
534 &self,
535 requests: Vec<Request>,
536 ) -> Result<usize, SpiderError> {
537 if requests.is_empty() {
538 return Ok(0);
539 }
540
541 let mut filtered = Vec::with_capacity(requests.len());
542 let mut seen_fingerprints = HashSet::with_capacity(requests.len());
543 for request in requests {
544 let fingerprint = request.fingerprint();
545 if seen_fingerprints.insert(fingerprint) && self.should_enqueue(&request) {
546 filtered.push(Arc::new(request));
547 }
548 }
549
550 if filtered.is_empty() {
551 return Ok(0);
552 }
553
554 let batch_len = filtered.len();
555 loop {
556 let pending = self.pending.load(Ordering::SeqCst);
557 if pending.saturating_add(batch_len) <= self.max_pending {
558 break;
559 }
560
561 if self.is_shutting_down.load(Ordering::SeqCst) {
562 return Err(SpiderError::GeneralError(
563 "Scheduler is shutting down.".into(),
564 ));
565 }
566
567 self.capacity_notify.notified().await;
568 }
569
570 if self
571 .tx
572 .send(SchedulerMessage::EnqueueBatch(filtered.clone()))
573 .await
574 .is_err()
575 {
576 if !self.is_shutting_down.load(Ordering::SeqCst) {
577 error!(
578 "Scheduler internal message channel is closed. Salvaging batch request set."
579 );
580 }
581 for request in filtered {
582 let salvaged =
583 Arc::try_unwrap(request).unwrap_or_else(|shared| shared.as_ref().clone());
584 self.salvaged.push(salvaged);
585 }
586 return Err(SpiderError::GeneralError(
587 "Scheduler internal channel closed, request batch salvaged.".into(),
588 ));
589 }
590
591 Ok(batch_len)
592 }
593
    /// Puts a request back on the queue without consuming a new capacity
    /// slot: the `Requeue` message deliberately does not bump `pending`
    /// (see `handle_message`), since the original enqueue already did.
    pub async fn requeue_request(&self, request: Request) -> Result<(), SpiderError> {
        if !self.should_enqueue(&request) {
            trace!(
                "Request already visited during requeue, skipping: {}",
                request.url
            );
            return Ok(());
        }

        // If the counter has already dropped to zero, claim one slot so the
        // scheduler does not look idle while this requeue is in flight —
        // presumably guarding against the caller having already called
        // `complete_request` (TODO confirm). The claimed slot is released
        // below if the send fails.
        let reserved_slot = self
            .pending
            .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();

        trace!(
            "Re-enqueuing request without changing pending count: {}",
            request.url
        );
        let request_arc = Arc::new(request);
        if self
            .tx
            .send(SchedulerMessage::Requeue(Arc::clone(&request_arc)))
            .await
            .is_err()
        {
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                error!(
                    "Scheduler internal message channel is closed. Salvaging re-queued request: {}",
                    request_arc.url
                );
            }
            // Keep the request recoverable via the next checkpoint.
            let salvaged_request =
                Arc::try_unwrap(request_arc).unwrap_or_else(|shared| shared.as_ref().clone());
            self.salvaged.push(salvaged_request);
            if reserved_slot {
                self.complete_request();
            }
            return Err(SpiderError::GeneralError(
                "Scheduler internal channel closed, request salvaged.".into(),
            ));
        }

        Ok(())
    }
642
643 pub fn complete_request(&self) {
645 let mut current = self.pending.load(Ordering::Acquire);
646 loop {
647 if current == 0 {
648 warn!("Scheduler pending request counter underflow prevented.");
649 return;
650 }
651
652 match self.pending.compare_exchange_weak(
653 current,
654 current - 1,
655 Ordering::AcqRel,
656 Ordering::Acquire,
657 ) {
658 Ok(_) => break,
659 Err(actual) => current = actual,
660 }
661 }
662
663 self.capacity_notify.notify_waiters();
664 }
665
666 pub async fn shutdown(&self) -> Result<(), SpiderError> {
672 self.is_shutting_down.store(true, Ordering::SeqCst);
673
674 if !self.tx.is_closed() {
675 self.tx.send(SchedulerMessage::Shutdown).await.map_err(|e| {
676 SpiderError::GeneralError(format!(
677 "Scheduler: Failed to send shutdown signal: {}",
678 e
679 ))
680 })
681 } else {
682 info!("Scheduler internal channel already closed, skipping shutdown signal");
683 Ok(())
684 }
685 }
686
687 pub async fn mark_visited(&self, fingerprint: impl Into<String>) -> Result<(), SpiderError> {
693 let fingerprint = fingerprint.into();
694 trace!(
695 "Sending MarkAsVisited message for fingerprint: {}",
696 fingerprint
697 );
698 self.tx
699 .send(SchedulerMessage::MarkAsVisited(fingerprint))
700 .await
701 .map_err(|e| {
702 if !self.is_shutting_down.load(Ordering::SeqCst) {
703 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
704 }
705 SpiderError::GeneralError(format!(
706 "Scheduler: Failed to send MarkAsVisited message: {}",
707 e
708 ))
709 })
710 }
711
712 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
720 if fingerprints.is_empty() {
721 return Ok(());
722 }
723
724 trace!(
725 "Sending MarkAsVisitedBatch message for {} fingerprints",
726 fingerprints.len()
727 );
728 self.tx
729 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
730 .await
731 .map_err(|e| {
732 if !self.is_shutting_down.load(Ordering::SeqCst) {
733 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
734 }
735 SpiderError::GeneralError(format!(
736 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
737 e
738 ))
739 })
740 }
741
742 pub fn is_visited(&self, fingerprint: &str) -> bool {
744 if !self.bloom.read().might_contain(fingerprint) {
745 return false;
746 }
747
748 {
749 let buffer = self.buffer.lock();
750 if buffer.contains(fingerprint) {
751 return true;
752 }
753 }
754
755 self.visited.contains_key(fingerprint)
756 }
757
758 fn flush_buffer_now(&self) {
759 let mut buffer = self.buffer.lock();
760 if !buffer.is_empty() {
761 let items: Vec<String> = buffer.drain().collect();
762 drop(buffer);
763
764 let mut bloom = self.bloom.write();
765 for item in items {
766 bloom.add(&item);
767 }
768 }
769 }
770
771 async fn flush_buffer(&self, notify: Arc<Notify>) {
772 loop {
773 tokio::select! {
774 _ = notify.notified() => {
775 self.flush_buffer_now();
776 }
777 _ = tokio::time::sleep(tokio::time::Duration::from_millis(BLOOM_FLUSH_INTERVAL_MS)) => {
778 self.flush_buffer_now();
779 }
780 }
781 }
782 }
783
784 pub fn should_enqueue(&self, request: &Request) -> bool {
786 let fingerprint = request.fingerprint();
787 !self.is_visited(&fingerprint)
788 }
789
    /// Current pending-request count (same value as `pending_count`).
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::Acquire)
    }

    /// True when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Alias for `is_empty`; `run_loop` uses this to decide whether to pop
    /// from the queue.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
807}
808
809impl Scheduler {
810 fn next_sequence(&self) -> u64 {
811 self.sequence.fetch_add(1, Ordering::AcqRel)
812 }
813
814 fn push_request(&self, request: Request) {
815 let scheduled = ScheduledRequest::new(request, self.next_sequence());
816 self.queue.lock().push(scheduled);
817 }
818
819 fn pop_request(&self) -> Option<Request> {
820 self.queue.lock().pop().map(|scheduled| scheduled.request)
821 }
822}