1#[cfg(feature = "checkpoint")]
46use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
47#[cfg(feature = "checkpoint")]
48use crate::SchedulerCheckpoint;
49
50#[cfg(not(feature = "checkpoint"))]
51use spider_util::constants::MAX_PENDING_REQUESTS;
52
53use spider_util::constants::{
54 BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
55 VISITED_URL_CACHE_CAPACITY, VISITED_URL_CACHE_TTL_SECS,
56};
57use spider_util::error::SpiderError;
58use spider_util::request::Request;
59use crossbeam::queue::SegQueue;
60use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
61use moka::sync::Cache;
62use std::sync::Arc;
63use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
64use log::{debug, error, info, trace, warn};
65
/// Control messages consumed by the scheduler's internal `run_loop`.
enum SchedulerMessage {
 /// Add a request (shared via `Arc` to keep the channel payload cheap) to the dispatch queue.
 Enqueue(Arc<Request>),
 /// Record a single URL fingerprint as visited.
 MarkAsVisited(String),
 /// Record many URL fingerprints as visited in one message.
 MarkAsVisitedBatch(Vec<String>),
 /// Stop the run loop and flush buffered dedup state.
 Shutdown,
}
77
78use spider_util::bloom::BloomFilter;
79
80use tokio::sync::Notify;
81
/// Central crawl scheduler: deduplicates URLs and feeds pending requests to the crawler.
pub struct Scheduler {
 // Lock-free FIFO of requests waiting to be dispatched to the crawler.
 queue: SegQueue<Request>,
 // TTL-bounded cache of visited URL fingerprints (authoritative recent-history check).
 visited: Cache<String, bool>,
 // Probabilistic first-pass filter for visited fingerprints; consulted before `visited`.
 bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
 // Staging buffer of fingerprints awaiting a batched flush into `bloom`.
 buffer: Arc<std::sync::Mutex<Vec<String>>>,
 // Wakes the background flush task when `buffer` crosses its size threshold.
 notify: Arc<Notify>,
 // Sender for internal control messages handled by `run_loop`.
 tx: AsyncSender<SchedulerMessage>,
 // Count of requests enqueued but not yet handed to the crawler.
 pending: AtomicUsize,
 // Requests that could not be enqueued (internal channel closed) and are kept for recovery.
 salvaged: SegQueue<Request>,
 // Set once shutdown starts; suppresses error logging on expected channel closure.
 pub(crate) is_shutting_down: AtomicBool,
 // Backpressure limit: enqueue attempts beyond this many pending requests are rejected.
 max_pending: usize,
}
125
126impl Scheduler {
127 #[cfg(feature = "checkpoint")]
156 pub fn new(
157 initial_state: Option<SchedulerCheckpoint>,
158 ) -> (Arc<Self>, AsyncReceiver<Request>) {
159 let (tx, rx_internal) = unbounded_async();
160 let (tx_out, rx_out) = bounded_async(100);
161
162 let queue: SegQueue<Request>;
163 let visited: Cache<String, bool>;
164 let pending: AtomicUsize;
165 let salvaged: SegQueue<Request>;
166
167 if let Some(state) = initial_state {
168 info!(
169 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
170 state.request_queue.len(),
171 state.visited_urls.len(),
172 state.salvaged_requests.len(),
173 );
174 let pend = state.request_queue.len() + state.salvaged_requests.len();
175 queue = SegQueue::new();
176 for request in state.request_queue {
177 queue.push(request);
178 }
179
180 visited = Cache::builder()
181 .max_capacity(VISITED_URL_CACHE_CAPACITY)
182 .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
183 .eviction_listener(|_key, _value, _cause| {})
184 .build();
185 for url in state.visited_urls {
186 visited.insert(url, true);
187 }
188
189 pending = AtomicUsize::new(pend);
190 salvaged = SegQueue::new();
191 for request in state.salvaged_requests {
192 salvaged.push(request);
193 }
194 } else {
195 queue = SegQueue::new();
196 visited = Cache::builder().max_capacity(DEFAULT_VISITED_CACHE_SIZE).build();
197 pending = AtomicUsize::new(0);
198 salvaged = SegQueue::new();
199 }
200
201 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
202 let notify = Arc::new(Notify::new());
203
204 let scheduler = Arc::new(Scheduler {
205 queue,
206 visited,
207 bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
208 BLOOM_FILTER_CAPACITY,
209 BLOOM_FILTER_HASH_FUNCTIONS,
210 ))),
211 buffer: buffer.clone(),
212 notify: notify.clone(),
213 tx,
214 pending,
215 salvaged,
216 is_shutting_down: AtomicBool::new(false),
217 max_pending: 30000,
218 });
219
220 let scheduler_bloom = Arc::clone(&scheduler);
221 let buffer_clone = buffer.clone();
222 let notify_clone = notify.clone();
223 tokio::spawn(async move {
224 scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
225 });
226
227 let scheduler_task = Arc::clone(&scheduler);
228 tokio::spawn(async move {
229 scheduler_task.run_loop(rx_internal, tx_out).await;
230 });
231
232 (scheduler, rx_out)
233 }
234
235 #[cfg(not(feature = "checkpoint"))]
236 pub fn new(
237 _initial_state: Option<()>,
238 ) -> (Arc<Self>, AsyncReceiver<Request>) {
239 let (tx, rx_internal) = unbounded_async();
240 let (tx_out, rx_out) = bounded_async(100);
241
242 let queue = SegQueue::new();
243 let visited = Cache::builder()
244 .max_capacity(VISITED_URL_CACHE_CAPACITY)
245 .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
246 .eviction_listener(|_key, _value, _cause| {})
247 .build();
248 let pending = AtomicUsize::new(0);
249 let salvaged = SegQueue::new();
250
251 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
252 let notify = Arc::new(Notify::new());
253
254 let scheduler = Arc::new(Scheduler {
255 queue,
256 visited,
257 bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
258 BLOOM_FILTER_CAPACITY,
259 BLOOM_FILTER_HASH_FUNCTIONS,
260 ))),
261 buffer: buffer.clone(),
262 notify: notify.clone(),
263 tx,
264 pending,
265 salvaged,
266 is_shutting_down: AtomicBool::new(false),
267 max_pending: MAX_PENDING_REQUESTS,
268 });
269
270 let scheduler_bloom = Arc::clone(&scheduler);
271 let buffer_clone = buffer.clone();
272 let notify_clone = notify.clone();
273 tokio::spawn(async move {
274 scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
275 });
276
277 let scheduler_task = Arc::clone(&scheduler);
278 tokio::spawn(async move {
279 scheduler_task.run_loop(rx_internal, tx_out).await;
280 });
281
282 (scheduler, rx_out)
283 }
284
285 async fn run_loop(
286 &self,
287 rx_internal: AsyncReceiver<SchedulerMessage>,
288 tx_out: AsyncSender<Request>,
289 ) {
290 info!(
291 "Scheduler run_loop started with max pending: {}",
292 self.max_pending
293 );
294 loop {
295 if let Ok(Some(msg)) = rx_internal.try_recv() {
296 trace!("Processing pending internal message");
297 if !self.handle_message(Ok(msg)).await {
298 break;
299 }
300 continue;
301 }
302
303 let request = if !tx_out.is_closed() && !self.is_idle() {
304 self.queue.pop()
305 } else {
306 None
307 };
308
309 if let Some(request) = request {
310 trace!("Sending request to crawler: {}", request.url);
311 tokio::select! {
312 send_res = tx_out.send(request) => {
313 if send_res.is_err() {
314 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
315 } else {
316 trace!("Successfully sent request to crawler");
317 }
318 self.pending.fetch_sub(1, Ordering::SeqCst);
319 },
320 recv_res = rx_internal.recv() => {
321 trace!("Received internal message while sending request");
322 if !self.handle_message(recv_res).await {
323 break;
324 }
325 continue;
326 }
327 }
328 } else {
329 trace!("No pending requests, waiting for internal message");
330 if !self.handle_message(rx_internal.recv().await).await {
331 break;
332 }
333 }
334 }
335 info!(
336 "Scheduler run_loop finished with {} pending requests remaining.",
337 self.pending.load(Ordering::SeqCst)
338 );
339 }
340
341 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
342 match msg {
343 Ok(SchedulerMessage::Enqueue(arc_request)) => {
344 let request = Arc::unwrap_or_clone(arc_request);
346 trace!("Enqueuing request: {}", request.url);
347 self.queue.push(request);
348 self.pending.fetch_add(1, Ordering::SeqCst);
349 true
350 }
351 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
352 trace!("Marking URL fingerprint as visited: {}", fingerprint);
353
354 self.visited.insert(fingerprint.clone(), true);
356
357 debug!("Marked URL as visited: {}", fingerprint);
359
360 {
362 let mut buffer = self.buffer.lock().unwrap();
363 buffer.push(fingerprint);
364 if buffer.len() >= 100 {
365 self.notify.notify_one();
366 }
367 }
368
369 true
370 }
371 Ok(SchedulerMessage::MarkAsVisitedBatch(mut fingerprints)) => {
372 let count = fingerprints.len();
373 trace!("Marking {} URL fingerprints as visited in batch", count);
374
375 for fingerprint in &fingerprints {
377 self.visited.insert(fingerprint.clone(), true);
378 }
379
380 {
382 let mut buffer = self.buffer.lock().unwrap();
383 buffer.append(&mut fingerprints);
384 if buffer.len() >= 100 {
385 self.notify.notify_one();
386 }
387 }
388
389 debug!("Marked {} URLs as visited in batch", count);
390 true
391 }
392 Ok(SchedulerMessage::Shutdown) => {
393 info!("Scheduler received shutdown signal. Exiting run_loop.");
394 self.is_shutting_down.store(true, Ordering::SeqCst);
395 self.flush_buffer_now();
396 false
397 }
398 Err(_) => {
399 warn!("Scheduler internal message channel closed. Exiting run_loop.");
400 self.is_shutting_down.store(true, Ordering::SeqCst);
401 false
402 }
403 }
404 }
405
406 #[cfg(feature = "checkpoint")]
407 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
408 let visited_urls = dashmap::DashSet::new();
409 for entry in self.visited.iter() {
410 let (key, _) = entry;
411 visited_urls.insert(key.as_ref().clone());
412 }
413
414 let mut request_queue = std::collections::VecDeque::new();
415 let mut temp_requests = Vec::new();
416
417 while let Some(request) = self.queue.pop() {
418 temp_requests.push(request);
419 }
420
421 for request in temp_requests.into_iter() {
422 request_queue.push_back(request.clone());
423 if !self.is_shutting_down.load(Ordering::SeqCst) {
424 self.queue.push(request);
425 }
426 }
427
428 let mut salvaged_requests = std::collections::VecDeque::new();
429 let mut temp_salvaged = Vec::new();
430
431 while let Some(request) = self.salvaged.pop() {
432 temp_salvaged.push(request);
433 }
434
435 for request in temp_salvaged.into_iter() {
436 salvaged_requests.push_back(request.clone());
437 if !self.is_shutting_down.load(Ordering::SeqCst) {
438 self.salvaged.push(request);
439 }
440 }
441
442 Ok(SchedulerCheckpoint {
443 request_queue,
444 visited_urls,
445 salvaged_requests,
446 })
447 }
448
449 #[cfg(not(feature = "checkpoint"))]
450 pub async fn snapshot(&self) -> Result<(), SpiderError> {
451 Ok(())
452 }
453
454 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
455 if !self.should_enqueue(&request) {
456 trace!("Request already visited, skipping: {}", request.url);
457 return Ok(());
458 }
459
460 let pending = self.pending.load(Ordering::SeqCst);
461 if pending >= self.max_pending {
462 warn!(
463 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
464 self.max_pending, request.url
465 );
466 return Err(SpiderError::GeneralError(
467 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
468 ));
469 }
470
471 trace!("Enqueuing request: {}", request.url);
472 if self
473 .tx
474 .send(SchedulerMessage::Enqueue(Arc::new(request.clone())))
475 .await
476 .is_err()
477 {
478 if !self.is_shutting_down.load(Ordering::SeqCst) {
479 error!(
480 "Scheduler internal message channel is closed. Salvaging request: {}",
481 request.url
482 );
483 }
484 self.salvaged.push(request);
485 return Err(SpiderError::GeneralError(
486 "Scheduler internal channel closed, request salvaged.".into(),
487 ));
488 }
489
490 trace!("Successfully enqueued request: {}", request.url);
491 Ok(())
492 }
493
494 pub async fn shutdown(&self) -> Result<(), SpiderError> {
495 self.is_shutting_down.store(true, Ordering::SeqCst);
496
497 if !self.tx.is_closed() {
498 self.tx
499 .send(SchedulerMessage::Shutdown)
500 .await
501 .map_err(|e| {
502 SpiderError::GeneralError(format!(
503 "Scheduler: Failed to send shutdown signal: {}",
504 e
505 ))
506 })
507 } else {
508 debug!("Scheduler internal channel already closed, skipping shutdown signal");
509 Ok(())
510 }
511 }
512
513 pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
514 trace!(
515 "Sending MarkAsVisited message for fingerprint: {}",
516 fingerprint
517 );
518 self.tx
519 .send(SchedulerMessage::MarkAsVisited(fingerprint))
520 .await
521 .map_err(|e| {
522 if !self.is_shutting_down.load(Ordering::SeqCst) {
523 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
524 }
525 SpiderError::GeneralError(format!(
526 "Scheduler: Failed to send MarkAsVisited message: {}",
527 e
528 ))
529 })
530 }
531
532 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
533 if fingerprints.is_empty() {
534 return Ok(());
535 }
536
537 trace!(
538 "Sending MarkAsVisitedBatch message for {} fingerprints",
539 fingerprints.len()
540 );
541 self.tx
542 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
543 .await
544 .map_err(|e| {
545 if !self.is_shutting_down.load(Ordering::SeqCst) {
546 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
547 }
548 SpiderError::GeneralError(format!(
549 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
550 e
551 ))
552 })
553 }
554
555 pub fn is_visited(&self, fingerprint: &str) -> bool {
556 if !self.bloom.read().might_contain(fingerprint) {
557 return false;
558 }
559
560 {
561 let buffer = self.buffer.lock().unwrap();
562 if buffer.iter().any(|item| item == fingerprint) {
563 return true;
564 }
565 }
566
567 self.visited.contains_key(fingerprint)
568 }
569
570 fn flush_buffer_now(&self) {
571 let mut buffer = self.buffer.lock().unwrap();
572 if !buffer.is_empty() {
573 let items: Vec<String> = buffer.drain(..).collect();
574 drop(buffer);
575
576 let mut bloom = self.bloom.write();
577 for item in items {
578 bloom.add(&item);
579 }
580 }
581 }
582
583 async fn flush_buffer(
584 &self,
585 _buffer: Arc<std::sync::Mutex<Vec<String>>>,
586 notify: Arc<Notify>,
587 ) {
588 loop {
589 tokio::select! {
590 _ = notify.notified() => {
591 self.flush_buffer_now();
592 }
593 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
594 self.flush_buffer_now();
595 }
596 }
597
598 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
599 }
600 }
601
602 pub fn should_enqueue(&self, request: &Request) -> bool {
603 let fingerprint = request.fingerprint();
604 !self.is_visited(&fingerprint)
605 }
606
607 #[inline]
608 pub fn len(&self) -> usize {
609 self.pending.load(Ordering::SeqCst)
610 }
611
612 #[inline]
613 pub fn is_empty(&self) -> bool {
614 self.len() == 0
615 }
616
617 #[inline]
618 pub fn is_idle(&self) -> bool {
619 self.is_empty()
620 }
621}