1#[cfg(feature = "checkpoint")]
26use crate::SchedulerCheckpoint;
27#[cfg(feature = "checkpoint")]
28use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
29
30use crossbeam::queue::SegQueue;
31use kanal::{AsyncReceiver, AsyncSender, bounded_async};
32use log::{debug, error, info, trace, warn};
33use moka::sync::Cache;
34use spider_util::constants::{
35 BLOOM_BUFFER_FLUSH_SIZE, BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
36 BLOOM_FLUSH_INTERVAL_MS, MAX_PENDING_REQUESTS, VISITED_URL_CACHE_CAPACITY,
37 VISITED_URL_CACHE_TTL_SECS,
38};
39use spider_util::error::SpiderError;
40use spider_util::request::Request;
41use std::collections::HashSet;
42use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
44
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add one request to the queue; increments the pending counter.
    Enqueue(Arc<Request>),
    /// Add a batch of requests; pending counter grows by the batch size.
    EnqueueBatch(Vec<Arc<Request>>),
    /// Put a request back on the queue without touching the pending counter.
    Requeue(Arc<Request>),
    /// Record a URL fingerprint in the visited cache and bloom staging buffer.
    MarkAsVisited(String),
    /// Record many URL fingerprints at once.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the run loop and flush the bloom staging buffer.
    Shutdown,
}
60
61use spider_util::bloom::BloomFilter;
62
63use parking_lot::Mutex;
64use tokio::sync::Notify;
65
/// Asynchronous request scheduler: deduplicates URL fingerprints, bounds the
/// number of in-flight requests, and feeds the crawler via an output channel.
pub struct Scheduler {
    /// Requests waiting to be handed to the crawler.
    queue: SegQueue<Request>,
    /// Idle-TTL cache of visited URL fingerprints (authoritative check).
    visited: Cache<String, bool>,
    /// Probabilistic first-pass filter over flushed visited fingerprints.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints staged for the next flush into the bloom filter.
    buffer: Arc<Mutex<HashSet<String>>>,
    /// Wakes the background task that flushes `buffer` into `bloom`.
    notify: Arc<Notify>,
    /// Wakes enqueuers blocked on the `max_pending` capacity limit.
    capacity_notify: Arc<Notify>,
    /// Sender for internal `SchedulerMessage`s processed by `run_loop`.
    tx: AsyncSender<SchedulerMessage>,
    /// Count of requests enqueued but not yet completed.
    pending: AtomicUsize,
    /// Requests rescued when the internal channel was closed mid-send.
    salvaged: SegQueue<Request>,
    /// Set once shutdown begins; checked by the enqueue/requeue paths.
    pub(crate) is_shutting_down: AtomicBool,
    /// Upper bound on `pending` before enqueuers must wait for capacity.
    max_pending: usize,
}
111
112impl Scheduler {
    /// Builds a scheduler and spawns its background tasks.
    ///
    /// Returns the shared scheduler handle plus the receiver from which the
    /// crawler pulls ready requests. Two tasks are spawned onto the current
    /// tokio runtime: the bloom-buffer flush loop and the scheduler run loop.
    ///
    /// `max_pending_requests` is clamped to `1..=MAX_PENDING_REQUESTS`. With
    /// the `checkpoint` feature, `initial_state` restores the request queue,
    /// visited cache, and salvaged requests from a previous run.
    pub fn new(
        #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
        #[cfg(not(feature = "checkpoint"))] _initial_state: Option<()>,
        max_pending_requests: usize,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        let max_pending = max_pending_requests.clamp(1, MAX_PENDING_REQUESTS);
        // Internal channel holds twice the pending budget so enqueue messages
        // rarely back up behind the run loop.
        let (tx, rx_internal) = bounded_async(max_pending.saturating_mul(2).max(1));
        // Output channel scales with max_pending but is clamped to 256..=2048.
        let output_capacity = (max_pending / 8).clamp(256, 2048);
        let (tx_out, rx_out) = bounded_async(output_capacity);

        let queue: SegQueue<Request>;
        let visited: Cache<String, bool>;
        let pending: AtomicUsize;
        let salvaged: SegQueue<Request>;

        #[cfg(feature = "checkpoint")]
        {
            if let Some(state) = initial_state {
                info!(
                    "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                    state.request_queue.len(),
                    state.visited_urls.len(),
                    state.salvaged_requests.len(),
                );
                // Pending covers both queued and salvaged restored requests.
                let pend = state.request_queue.len() + state.salvaged_requests.len();
                queue = SegQueue::new();
                for request in state.request_queue {
                    queue.push(request);
                }

                visited = Cache::builder()
                    .max_capacity(VISITED_URL_CACHE_CAPACITY)
                    .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                    .eviction_listener(|_key, _value, _cause| {})
                    .build();
                // NOTE(review): restored URLs are inserted into the visited
                // cache only — the bloom filter starts empty, so `is_visited`
                // (which gates on the bloom filter) may not recognize them
                // until they are re-marked. Confirm whether they should also
                // be seeded into the bloom staging buffer.
                for url in state.visited_urls {
                    visited.insert(url, true);
                }

                pending = AtomicUsize::new(pend);
                salvaged = SegQueue::new();
                for request in state.salvaged_requests {
                    salvaged.push(request);
                }
            } else {
                queue = SegQueue::new();
                // NOTE(review): this path uses DEFAULT_VISITED_CACHE_SIZE and
                // no idle TTL, unlike the other two construction paths —
                // confirm the asymmetry is intentional.
                visited = Cache::builder()
                    .max_capacity(DEFAULT_VISITED_CACHE_SIZE)
                    .build();
                pending = AtomicUsize::new(0);
                salvaged = SegQueue::new();
            }
        }

        #[cfg(not(feature = "checkpoint"))]
        {
            queue = SegQueue::new();
            visited = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {})
                .build();
            pending = AtomicUsize::new(0);
            salvaged = SegQueue::new();
        }

        let buffer = Arc::new(Mutex::new(HashSet::new()));
        let notify = Arc::new(Notify::new());
        let capacity_notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            queue,
            visited,
            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            buffer: buffer.clone(),
            notify: notify.clone(),
            capacity_notify: Arc::clone(&capacity_notify),
            tx,
            pending,
            salvaged,
            is_shutting_down: AtomicBool::new(false),
            max_pending,
        });

        // Background task: flush the staged fingerprints into the bloom
        // filter, on notification or periodically.
        let scheduler_bloom = Arc::clone(&scheduler);
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            scheduler_bloom.flush_buffer(notify_clone).await;
        });

        // Background task: the scheduler's main message/dispatch loop.
        let scheduler_task = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_task.run_loop(rx_internal, tx_out).await;
        });

        (scheduler, rx_out)
    }
241
242 async fn run_loop(
243 &self,
244 rx_internal: AsyncReceiver<SchedulerMessage>,
245 tx_out: AsyncSender<Request>,
246 ) {
247 info!(
248 "Scheduler run_loop started with max pending: {}",
249 self.max_pending
250 );
251 loop {
252 if let Ok(Some(msg)) = rx_internal.try_recv() {
253 trace!("Processing pending internal message");
254 if !self.handle_message(Ok(msg)).await {
255 break;
256 }
257 continue;
258 }
259
260 let request = if !tx_out.is_closed() && !self.is_idle() {
261 self.queue.pop()
262 } else {
263 None
264 };
265
266 if let Some(request) = request {
267 trace!("Sending request to crawler: {}", request.url);
268 tokio::select! {
269 send_res = tx_out.send(request) => {
270 if send_res.is_err() {
271 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
272 } else {
273 trace!("Successfully sent request to crawler");
274 }
275 },
276 recv_res = rx_internal.recv() => {
277 trace!("Received internal message while sending request");
278 if !self.handle_message(recv_res).await {
279 break;
280 }
281 continue;
282 }
283 }
284 } else {
285 trace!("No pending requests, waiting for internal message");
286 if !self.handle_message(rx_internal.recv().await).await {
287 break;
288 }
289 }
290 }
291 info!(
292 "Scheduler run_loop finished with {} pending requests remaining.",
293 self.pending.load(Ordering::SeqCst)
294 );
295 }
296
297 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
298 match msg {
299 Ok(SchedulerMessage::Enqueue(arc_request)) => {
300 let request = Arc::unwrap_or_clone(arc_request);
302 trace!("Enqueuing request: {}", request.url);
303 self.queue.push(request);
304 self.pending.fetch_add(1, Ordering::AcqRel);
305 true
306 }
307 Ok(SchedulerMessage::EnqueueBatch(requests)) => {
308 let count = requests.len();
309 for request in requests {
310 self.queue.push(Arc::unwrap_or_clone(request));
311 }
312 self.pending.fetch_add(count, Ordering::AcqRel);
313 true
314 }
315 Ok(SchedulerMessage::Requeue(arc_request)) => {
316 let request = Arc::unwrap_or_clone(arc_request);
317 trace!("Re-enqueuing request: {}", request.url);
318 self.queue.push(request);
319 true
320 }
321 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
322 trace!("Marking URL fingerprint as visited: {}", fingerprint);
323
324 self.visited.insert(fingerprint.clone(), true);
326
327 debug!("Marked URL as visited: {}", fingerprint);
329
330 {
332 let mut buffer = self.buffer.lock();
333 buffer.insert(fingerprint);
334 if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
335 self.notify.notify_one();
336 }
337 }
338
339 true
340 }
341 Ok(SchedulerMessage::MarkAsVisitedBatch(fingerprints)) => {
342 let count = fingerprints.len();
343 trace!("Marking {} URL fingerprints as visited in batch", count);
344
345 for fingerprint in &fingerprints {
347 self.visited.insert(fingerprint.clone(), true);
348 }
349
350 {
352 let mut buffer = self.buffer.lock();
353 buffer.extend(fingerprints);
354 if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
355 self.notify.notify_one();
356 }
357 }
358
359 debug!("Marked {} URLs as visited in batch", count);
360 true
361 }
362 Ok(SchedulerMessage::Shutdown) => {
363 info!("Scheduler received shutdown signal. Exiting run_loop.");
364 self.is_shutting_down.store(true, Ordering::SeqCst);
365 self.flush_buffer_now();
366 false
367 }
368 Err(_) => {
369 warn!("Scheduler internal message channel closed. Exiting run_loop.");
370 self.is_shutting_down.store(true, Ordering::SeqCst);
371 false
372 }
373 }
374 }
375
    /// Captures the scheduler's current state for persistence.
    ///
    /// Copies the visited cache, then drains the request queue and the
    /// salvaged queue, recording each entry. Entries are pushed back unless a
    /// shutdown is in progress — during shutdown the checkpoint becomes their
    /// owner.
    ///
    /// NOTE(review): the drain/re-push is not atomic with respect to the run
    /// loop — requests popped concurrently can be missing from the
    /// checkpoint. Presumably snapshots are taken at quiescent points;
    /// confirm with the callers.
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        // Drain the lock-free queue so it can be iterated.
        while let Some(request) = self.queue.pop() {
            temp_requests.push(request);
        }

        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            // Restore the queue unless shutting down.
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.queue.push(request);
            }
        }

        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        while let Some(request) = self.salvaged.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
418
    /// No-op snapshot used when the `checkpoint` feature is disabled.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
423
424 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
425 if !self.should_enqueue(&request) {
426 trace!("Request already visited, skipping: {}", request.url);
427 return Ok(());
428 }
429
430 loop {
431 let pending = self.pending.load(Ordering::SeqCst);
432 if pending < self.max_pending {
433 break;
434 }
435
436 if self.is_shutting_down.load(Ordering::SeqCst) {
437 return Err(SpiderError::GeneralError(
438 "Scheduler is shutting down.".into(),
439 ));
440 }
441
442 trace!(
443 "Scheduler capacity reached ({} pending), waiting to enqueue: {}",
444 self.max_pending, request.url
445 );
446 self.capacity_notify.notified().await;
447 }
448
449 trace!("Enqueuing request: {}", request.url);
450 let request_arc = Arc::new(request);
451 if self
452 .tx
453 .send(SchedulerMessage::Enqueue(Arc::clone(&request_arc)))
454 .await
455 .is_err()
456 {
457 if !self.is_shutting_down.load(Ordering::SeqCst) {
458 error!(
459 "Scheduler internal message channel is closed. Salvaging request: {}",
460 request_arc.url
461 );
462 }
463 let salvaged_request =
464 Arc::try_unwrap(request_arc).unwrap_or_else(|shared| shared.as_ref().clone());
465 self.salvaged.push(salvaged_request);
466 return Err(SpiderError::GeneralError(
467 "Scheduler internal channel closed, request salvaged.".into(),
468 ));
469 }
470
471 trace!("Successfully enqueued request: {}", request_arc.url);
472 Ok(())
473 }
474
475 pub async fn enqueue_requests_batch(
477 &self,
478 requests: Vec<Request>,
479 ) -> Result<usize, SpiderError> {
480 if requests.is_empty() {
481 return Ok(0);
482 }
483
484 let mut filtered = Vec::with_capacity(requests.len());
485 let mut seen_fingerprints = HashSet::with_capacity(requests.len());
486 for request in requests {
487 let fingerprint = request.fingerprint();
488 if seen_fingerprints.insert(fingerprint) && self.should_enqueue(&request) {
489 filtered.push(Arc::new(request));
490 }
491 }
492
493 if filtered.is_empty() {
494 return Ok(0);
495 }
496
497 let batch_len = filtered.len();
498 loop {
499 let pending = self.pending.load(Ordering::SeqCst);
500 if pending.saturating_add(batch_len) <= self.max_pending {
501 break;
502 }
503
504 if self.is_shutting_down.load(Ordering::SeqCst) {
505 return Err(SpiderError::GeneralError(
506 "Scheduler is shutting down.".into(),
507 ));
508 }
509
510 self.capacity_notify.notified().await;
511 }
512
513 if self
514 .tx
515 .send(SchedulerMessage::EnqueueBatch(filtered.clone()))
516 .await
517 .is_err()
518 {
519 if !self.is_shutting_down.load(Ordering::SeqCst) {
520 error!(
521 "Scheduler internal message channel is closed. Salvaging batch request set."
522 );
523 }
524 for request in filtered {
525 let salvaged =
526 Arc::try_unwrap(request).unwrap_or_else(|shared| shared.as_ref().clone());
527 self.salvaged.push(salvaged);
528 }
529 return Err(SpiderError::GeneralError(
530 "Scheduler internal channel closed, request batch salvaged.".into(),
531 ));
532 }
533
534 Ok(batch_len)
535 }
536
    /// Puts a previously-dispatched request back on the queue (e.g. for a
    /// retry) without passing the capacity gate and, normally, without
    /// touching the pending counter — the original enqueue already counted
    /// it. Skips the request if its fingerprint is now marked visited.
    pub async fn requeue_request(&self, request: Request) -> Result<(), SpiderError> {
        if !self.should_enqueue(&request) {
            trace!(
                "Request already visited during requeue, skipping: {}",
                request.url
            );
            return Ok(());
        }

        // If the counter has already dropped to zero, bump it to one so the
        // scheduler does not look idle while this requeue is in flight.
        // NOTE(review): the reservation is released below only on the send
        // failure path — presumably a later `complete_request` balances the
        // success path; confirm against the retry callers.
        let reserved_slot = self
            .pending
            .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();

        trace!(
            "Re-enqueuing request without changing pending count: {}",
            request.url
        );
        let request_arc = Arc::new(request);
        if self
            .tx
            .send(SchedulerMessage::Requeue(Arc::clone(&request_arc)))
            .await
            .is_err()
        {
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                error!(
                    "Scheduler internal message channel is closed. Salvaging re-queued request: {}",
                    request_arc.url
                );
            }
            // Keep the request so a later snapshot can persist it.
            let salvaged_request =
                Arc::try_unwrap(request_arc).unwrap_or_else(|shared| shared.as_ref().clone());
            self.salvaged.push(salvaged_request);
            // Undo the speculative slot reservation taken above.
            if reserved_slot {
                self.complete_request();
            }
            return Err(SpiderError::GeneralError(
                "Scheduler internal channel closed, request salvaged.".into(),
            ));
        }

        Ok(())
    }
585
586 pub fn complete_request(&self) {
588 let mut current = self.pending.load(Ordering::Acquire);
589 loop {
590 if current == 0 {
591 warn!("Scheduler pending request counter underflow prevented.");
592 return;
593 }
594
595 match self.pending.compare_exchange_weak(
596 current,
597 current - 1,
598 Ordering::AcqRel,
599 Ordering::Acquire,
600 ) {
601 Ok(_) => break,
602 Err(actual) => current = actual,
603 }
604 }
605
606 self.capacity_notify.notify_waiters();
607 }
608
609 pub async fn shutdown(&self) -> Result<(), SpiderError> {
615 self.is_shutting_down.store(true, Ordering::SeqCst);
616
617 if !self.tx.is_closed() {
618 self.tx.send(SchedulerMessage::Shutdown).await.map_err(|e| {
619 SpiderError::GeneralError(format!(
620 "Scheduler: Failed to send shutdown signal: {}",
621 e
622 ))
623 })
624 } else {
625 info!("Scheduler internal channel already closed, skipping shutdown signal");
626 Ok(())
627 }
628 }
629
630 pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
636 trace!(
637 "Sending MarkAsVisited message for fingerprint: {}",
638 fingerprint
639 );
640 self.tx
641 .send(SchedulerMessage::MarkAsVisited(fingerprint))
642 .await
643 .map_err(|e| {
644 if !self.is_shutting_down.load(Ordering::SeqCst) {
645 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
646 }
647 SpiderError::GeneralError(format!(
648 "Scheduler: Failed to send MarkAsVisited message: {}",
649 e
650 ))
651 })
652 }
653
654 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
662 if fingerprints.is_empty() {
663 return Ok(());
664 }
665
666 trace!(
667 "Sending MarkAsVisitedBatch message for {} fingerprints",
668 fingerprints.len()
669 );
670 self.tx
671 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
672 .await
673 .map_err(|e| {
674 if !self.is_shutting_down.load(Ordering::SeqCst) {
675 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
676 }
677 SpiderError::GeneralError(format!(
678 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
679 e
680 ))
681 })
682 }
683
684 pub fn is_visited(&self, fingerprint: &str) -> bool {
686 if !self.bloom.read().might_contain(fingerprint) {
687 return false;
688 }
689
690 {
691 let buffer = self.buffer.lock();
692 if buffer.contains(fingerprint) {
693 return true;
694 }
695 }
696
697 self.visited.contains_key(fingerprint)
698 }
699
700 fn flush_buffer_now(&self) {
701 let mut buffer = self.buffer.lock();
702 if !buffer.is_empty() {
703 let items: Vec<String> = buffer.drain().collect();
704 drop(buffer);
705
706 let mut bloom = self.bloom.write();
707 for item in items {
708 bloom.add(&item);
709 }
710 }
711 }
712
713 async fn flush_buffer(&self, notify: Arc<Notify>) {
714 loop {
715 tokio::select! {
716 _ = notify.notified() => {
717 self.flush_buffer_now();
718 }
719 _ = tokio::time::sleep(tokio::time::Duration::from_millis(BLOOM_FLUSH_INTERVAL_MS)) => {
720 self.flush_buffer_now();
721 }
722 }
723 }
724 }
725
726 pub fn should_enqueue(&self, request: &Request) -> bool {
728 let fingerprint = request.fingerprint();
729 !self.is_visited(&fingerprint)
730 }
731
    /// Number of requests currently counted as pending (enqueued or in
    /// flight).
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::Acquire)
    }

    /// Returns `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Alias for `is_empty`; the run loop uses this to decide whether to pop
    /// from the queue.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
749}