1#[cfg(feature = "checkpoint")]
46use crate::SchedulerCheckpoint;
47#[cfg(feature = "checkpoint")]
48use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
49
50use crossbeam::queue::SegQueue;
51use kanal::{AsyncReceiver, AsyncSender, bounded_async};
52use log::{debug, error, info, trace, warn};
53use moka::sync::Cache;
54use spider_util::constants::{
55 BLOOM_BUFFER_FLUSH_SIZE, BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
56 BLOOM_FLUSH_INTERVAL_MS, MAX_PENDING_REQUESTS, VISITED_URL_CACHE_CAPACITY,
57 VISITED_URL_CACHE_TTL_SECS,
58};
59use spider_util::error::SpiderError;
60use spider_util::request::Request;
61use std::collections::HashSet;
62use std::sync::Arc;
63use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
64
/// Control messages processed by the scheduler's internal run loop.
enum SchedulerMessage {
    /// Add a request to the dispatch queue.
    Enqueue(Arc<Request>),
    /// Record a single URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record many URL fingerprints as visited in one message.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the run loop (a final bloom-buffer flush is performed).
    Shutdown,
}
76
77use spider_util::bloom::BloomFilter;
78
79use parking_lot::Mutex;
80use tokio::sync::Notify;
81
/// Central crawl scheduler: deduplicates URLs, applies backpressure, and
/// feeds requests to the crawler through an internal actor-style run loop.
pub struct Scheduler {
    /// Lock-free FIFO of requests awaiting dispatch to the crawler.
    queue: SegQueue<Request>,
    /// Exact (but capacity/TTL-bounded) cache of visited URL fingerprints.
    visited: Cache<String, bool>,
    /// Long-term probabilistic record of visited fingerprints; entries are
    /// added in batches from `buffer` by the background flush task.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Staging set of fingerprints not yet flushed into `bloom`.
    buffer: Arc<Mutex<HashSet<String>>>,
    /// Wakes the background flush task when `buffer` crosses the flush size.
    notify: Arc<Notify>,
    /// Sender for control messages consumed by `run_loop`.
    tx: AsyncSender<SchedulerMessage>,
    /// Count of requests enqueued but not yet handed to the crawler.
    pending: AtomicUsize,
    /// Requests rescued after a failed send on the internal channel.
    salvaged: SegQueue<Request>,
    /// Set once shutdown begins; used to silence expected channel-closed errors.
    pub(crate) is_shutting_down: AtomicBool,
    /// Backpressure cap applied to `pending` in `enqueue_request`.
    max_pending: usize,
}
125
126impl Scheduler {
    /// Build a scheduler, optionally restoring state from a checkpoint, and
    /// spawn its two background tasks (bloom-buffer flusher and the run loop).
    ///
    /// Returns the shared scheduler handle plus the receiver from which the
    /// crawler pulls requests.
    ///
    /// NOTE(review): each spawned task holds an `Arc<Self>`, so the scheduler
    /// stays alive for as long as those tasks keep running.
    pub fn new(
        #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
        #[cfg(not(feature = "checkpoint"))] _initial_state: Option<()>,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        // Internal control channel (sized to absorb bursts) and the outgoing
        // request channel consumed by the crawler.
        let (tx, rx_internal) = bounded_async(MAX_PENDING_REQUESTS * 2);
        let (tx_out, rx_out) = bounded_async(100);

        let queue: SegQueue<Request>;
        let visited: Cache<String, bool>;
        let pending: AtomicUsize;
        let salvaged: SegQueue<Request>;

        #[cfg(feature = "checkpoint")]
        {
            if let Some(state) = initial_state {
                info!(
                    "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                    state.request_queue.len(),
                    state.visited_urls.len(),
                    state.salvaged_requests.len(),
                );
                // Pending counts queued *and* salvaged requests.
                // NOTE(review): salvaged requests are not placed on `queue`,
                // so unless something re-enqueues them the pending count may
                // never drain to zero — confirm who consumes `salvaged`.
                let pend = state.request_queue.len() + state.salvaged_requests.len();
                queue = SegQueue::new();
                for request in state.request_queue {
                    queue.push(request);
                }

                visited = Cache::builder()
                    .max_capacity(VISITED_URL_CACHE_CAPACITY)
                    .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                    // No-op listener: eviction requires no cleanup here.
                    .eviction_listener(|_key, _value, _cause| {})
                    .build();
                for url in state.visited_urls {
                    visited.insert(url, true);
                }

                pending = AtomicUsize::new(pend);
                salvaged = SegQueue::new();
                for request in state.salvaged_requests {
                    salvaged.push(request);
                }
            } else {
                // Fresh start (checkpoint feature enabled, no state supplied).
                // NOTE(review): this branch uses DEFAULT_VISITED_CACHE_SIZE
                // and sets no idle TTL, unlike the two other construction
                // paths — confirm the asymmetry is intentional.
                queue = SegQueue::new();
                visited = Cache::builder()
                    .max_capacity(DEFAULT_VISITED_CACHE_SIZE)
                    .build();
                pending = AtomicUsize::new(0);
                salvaged = SegQueue::new();
            }
        }

        #[cfg(not(feature = "checkpoint"))]
        {
            queue = SegQueue::new();
            visited = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {})
                .build();
            pending = AtomicUsize::new(0);
            salvaged = SegQueue::new();
        }

        let buffer = Arc::new(Mutex::new(HashSet::new()));
        let notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            queue,
            visited,
            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            buffer: buffer.clone(),
            notify: notify.clone(),
            tx,
            pending,
            salvaged,
            is_shutting_down: AtomicBool::new(false),
            max_pending: MAX_PENDING_REQUESTS,
        });

        // Background task: drains the fingerprint buffer into the bloom filter.
        let scheduler_bloom = Arc::clone(&scheduler);
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            scheduler_bloom.flush_buffer(notify_clone).await;
        });

        // Background task: the scheduler's main event loop.
        let scheduler_task = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_task.run_loop(rx_internal, tx_out).await;
        });

        (scheduler, rx_out)
    }
250
251 async fn run_loop(
252 &self,
253 rx_internal: AsyncReceiver<SchedulerMessage>,
254 tx_out: AsyncSender<Request>,
255 ) {
256 info!(
257 "Scheduler run_loop started with max pending: {}",
258 self.max_pending
259 );
260 loop {
261 if let Ok(Some(msg)) = rx_internal.try_recv() {
262 trace!("Processing pending internal message");
263 if !self.handle_message(Ok(msg)).await {
264 break;
265 }
266 continue;
267 }
268
269 let request = if !tx_out.is_closed() && !self.is_idle() {
270 self.queue.pop()
271 } else {
272 None
273 };
274
275 if let Some(request) = request {
276 trace!("Sending request to crawler: {}", request.url);
277 tokio::select! {
278 send_res = tx_out.send(request) => {
279 if send_res.is_err() {
280 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
281 } else {
282 trace!("Successfully sent request to crawler");
283 }
284 self.pending.fetch_sub(1, Ordering::AcqRel);
285 },
286 recv_res = rx_internal.recv() => {
287 trace!("Received internal message while sending request");
288 if !self.handle_message(recv_res).await {
289 break;
290 }
291 continue;
292 }
293 }
294 } else {
295 trace!("No pending requests, waiting for internal message");
296 if !self.handle_message(rx_internal.recv().await).await {
297 break;
298 }
299 }
300 }
301 info!(
302 "Scheduler run_loop finished with {} pending requests remaining.",
303 self.pending.load(Ordering::SeqCst)
304 );
305 }
306
    /// Process one control message; returns `false` when the run loop should
    /// exit (explicit shutdown or closed channel).
    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
        match msg {
            Ok(SchedulerMessage::Enqueue(arc_request)) => {
                // Take ownership cheaply when this is the last Arc reference;
                // otherwise fall back to a deep clone.
                let request = Arc::unwrap_or_clone(arc_request);
                trace!("Enqueuing request: {}", request.url);
                self.queue.push(request);
                self.pending.fetch_add(1, Ordering::AcqRel);
                true
            }
            Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
                trace!("Marking URL fingerprint as visited: {}", fingerprint);

                // Exact record first, then stage for the bloom filter below.
                self.visited.insert(fingerprint.clone(), true);

                debug!("Marked URL as visited: {}", fingerprint);

                // Scope keeps the buffer lock short; released before returning.
                {
                    let mut buffer = self.buffer.lock();
                    buffer.insert(fingerprint);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        // Ask the background task to flush ahead of schedule.
                        self.notify.notify_one();
                    }
                }

                true
            }
            Ok(SchedulerMessage::MarkAsVisitedBatch(fingerprints)) => {
                let count = fingerprints.len();
                trace!("Marking {} URL fingerprints as visited in batch", count);

                for fingerprint in &fingerprints {
                    self.visited.insert(fingerprint.clone(), true);
                }

                // Stage the whole batch under a single, short lock hold.
                {
                    let mut buffer = self.buffer.lock();
                    buffer.extend(fingerprints);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        self.notify.notify_one();
                    }
                }

                debug!("Marked {} URLs as visited in batch", count);
                true
            }
            Ok(SchedulerMessage::Shutdown) => {
                info!("Scheduler received shutdown signal. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                // Final synchronous flush so staged fingerprints reach the
                // bloom filter before the loop exits.
                self.flush_buffer_now();
                false
            }
            Err(_) => {
                // All senders dropped: treat as an implicit shutdown.
                warn!("Scheduler internal message channel closed. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                false
            }
        }
    }
371
    /// Capture the scheduler state for persistence.
    ///
    /// Both `SegQueue`s are drained and — unless shutdown is in progress —
    /// re-pushed, because `SegQueue` offers no non-destructive iteration.
    /// During shutdown the drained items are intentionally consumed into the
    /// checkpoint instead of being restored.
    ///
    /// NOTE(review): this is not atomic with respect to the run loop, which
    /// may pop or push concurrently while the snapshot drains/re-pushes —
    /// confirm callers quiesce the scheduler first.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` is kept for API stability.
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        // Drain the live queue in FIFO order...
        while let Some(request) = self.queue.pop() {
            temp_requests.push(request);
        }

        // ...copy each request into the checkpoint and, when not shutting
        // down, restore it to the live queue in the same order.
        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.queue.push(request);
            }
        }

        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        // Same drain/copy/restore treatment for salvaged requests.
        while let Some(request) = self.salvaged.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
414
    /// No-op snapshot used when the `checkpoint` feature is disabled; keeps
    /// the public API uniform across feature configurations.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
419
    /// Submit a request for crawling, deduplicating against visited URLs and
    /// enforcing the `max_pending` backpressure cap.
    ///
    /// # Errors
    /// Returns `SpiderError::GeneralError` when the scheduler is at capacity
    /// (request dropped) or when the internal channel is closed (request
    /// salvaged onto `self.salvaged` for later recovery).
    pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
        if !self.should_enqueue(&request) {
            trace!("Request already visited, skipping: {}", request.url);
            return Ok(());
        }

        // NOTE(review): load-then-compare is not atomic with the run loop's
        // counter updates, so the cap can be overshot slightly under
        // concurrency — presumably acceptable as soft backpressure; confirm.
        let pending = self.pending.load(Ordering::SeqCst);
        if pending >= self.max_pending {
            warn!(
                "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
                self.max_pending, request.url
            );
            return Err(SpiderError::GeneralError(
                "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
            ));
        }

        trace!("Enqueuing request: {}", request.url);
        // The request is wrapped in an Arc and a *clone* of the handle is
        // sent: kanal consumes the message even on failure, so keeping a
        // local handle lets us recover the request if the send fails. On
        // failure the sent clone has been dropped inside kanal, leaving this
        // handle as (usually) the sole owner, so `try_unwrap` below recovers
        // the request without a deep clone.
        let request_arc = Arc::new(request);
        if self
            .tx
            .send(SchedulerMessage::Enqueue(Arc::clone(&request_arc)))
            .await
            .is_err()
        {
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                error!(
                    "Scheduler internal message channel is closed. Salvaging request: {}",
                    request_arc.url
                );
            }
            let salvaged_request =
                Arc::try_unwrap(request_arc).unwrap_or_else(|shared| shared.as_ref().clone());
            self.salvaged.push(salvaged_request);
            return Err(SpiderError::GeneralError(
                "Scheduler internal channel closed, request salvaged.".into(),
            ));
        }

        trace!("Successfully enqueued request: {}", request_arc.url);
        Ok(())
    }
462
463 pub async fn shutdown(&self) -> Result<(), SpiderError> {
469 self.is_shutting_down.store(true, Ordering::SeqCst);
470
471 if !self.tx.is_closed() {
472 self.tx.send(SchedulerMessage::Shutdown).await.map_err(|e| {
473 SpiderError::GeneralError(format!(
474 "Scheduler: Failed to send shutdown signal: {}",
475 e
476 ))
477 })
478 } else {
479 debug!("Scheduler internal channel already closed, skipping shutdown signal");
480 Ok(())
481 }
482 }
483
484 pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
490 trace!(
491 "Sending MarkAsVisited message for fingerprint: {}",
492 fingerprint
493 );
494 self.tx
495 .send(SchedulerMessage::MarkAsVisited(fingerprint))
496 .await
497 .map_err(|e| {
498 if !self.is_shutting_down.load(Ordering::SeqCst) {
499 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
500 }
501 SpiderError::GeneralError(format!(
502 "Scheduler: Failed to send MarkAsVisited message: {}",
503 e
504 ))
505 })
506 }
507
508 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
516 if fingerprints.is_empty() {
517 return Ok(());
518 }
519
520 trace!(
521 "Sending MarkAsVisitedBatch message for {} fingerprints",
522 fingerprints.len()
523 );
524 self.tx
525 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
526 .await
527 .map_err(|e| {
528 if !self.is_shutting_down.load(Ordering::SeqCst) {
529 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
530 }
531 SpiderError::GeneralError(format!(
532 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
533 e
534 ))
535 })
536 }
537
538 pub fn is_visited(&self, fingerprint: &str) -> bool {
540 if !self.bloom.read().might_contain(fingerprint) {
541 return false;
542 }
543
544 {
545 let buffer = self.buffer.lock();
546 if buffer.contains(fingerprint) {
547 return true;
548 }
549 }
550
551 self.visited.contains_key(fingerprint)
552 }
553
554 fn flush_buffer_now(&self) {
555 let mut buffer = self.buffer.lock();
556 if !buffer.is_empty() {
557 let items: Vec<String> = buffer.drain().collect();
558 drop(buffer);
559
560 let mut bloom = self.bloom.write();
561 for item in items {
562 bloom.add(&item);
563 }
564 }
565 }
566
567 async fn flush_buffer(&self, notify: Arc<Notify>) {
568 loop {
569 tokio::select! {
570 _ = notify.notified() => {
571 self.flush_buffer_now();
572 }
573 _ = tokio::time::sleep(tokio::time::Duration::from_millis(BLOOM_FLUSH_INTERVAL_MS)) => {
574 self.flush_buffer_now();
575 }
576 }
577 }
578 }
579
580 pub fn should_enqueue(&self, request: &Request) -> bool {
582 let fingerprint = request.fingerprint();
583 !self.is_visited(&fingerprint)
584 }
585
    /// Number of requests accepted but not yet handed to the crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::Acquire)
    }
591
    /// `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
597
    /// Alias for [`Self::is_empty`]; the run loop stops popping requests
    /// while idle.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
603}