1#[cfg(feature = "checkpoint")]
46use crate::SchedulerCheckpoint;
47#[cfg(feature = "checkpoint")]
48use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
49
50use crossbeam::queue::SegQueue;
51use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
52use log::{debug, error, info, trace, warn};
53use moka::sync::Cache;
54use spider_util::constants::{
55 BLOOM_BUFFER_FLUSH_SIZE, BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
56 BLOOM_FLUSH_INTERVAL_MS, MAX_PENDING_REQUESTS, VISITED_URL_CACHE_CAPACITY,
57 VISITED_URL_CACHE_TTL_SECS,
58};
59use spider_util::error::SpiderError;
60use spider_util::request::Request;
61use std::sync::Arc;
62use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
63
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add a request to the scheduling queue.
    Enqueue(Arc<Request>),
    /// Record a single URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record a batch of URL fingerprints as visited in one message.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the scheduler's run loop.
    Shutdown,
}
75
76use spider_util::bloom::BloomFilter;
77
78use parking_lot::Mutex;
79use tokio::sync::Notify;
80
/// Central request scheduler: deduplicates URLs and feeds requests to crawlers.
pub struct Scheduler {
    /// Lock-free FIFO of requests waiting to be dispatched.
    queue: SegQueue<Request>,
    /// Exact (but capacity/TTL-bounded) cache of visited URL fingerprints.
    visited: Cache<String, bool>,
    /// Probabilistic long-term record of flushed visited fingerprints.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints marked visited but not yet flushed into the bloom filter.
    buffer: Arc<Mutex<Vec<String>>>,
    /// Wakes the background flush task when the buffer reaches the flush size.
    notify: Arc<Notify>,
    /// Sender side of the internal control-message channel.
    tx: AsyncSender<SchedulerMessage>,
    /// Count of requests accepted but not yet handed to a crawler.
    pending: AtomicUsize,
    /// Requests that could not be enqueued (internal channel closed).
    salvaged: SegQueue<Request>,
    /// Set once shutdown has been requested; read by background tasks.
    pub(crate) is_shutting_down: AtomicBool,
    /// Backpressure limit on `pending`.
    max_pending: usize,
}
124
125impl Scheduler {
    /// Builds a `Scheduler`, spawns its background tasks, and returns the
    /// shared handle plus the receiver that crawlers pull requests from.
    ///
    /// Two tasks are spawned: the bloom-filter flush loop and the main
    /// `run_loop` that drains the internal control channel and dispatches
    /// queued requests.
    ///
    /// With the `checkpoint` feature enabled, `initial_state` restores the
    /// request queue, visited URLs, and salvaged requests from a snapshot.
    pub fn new(
        #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
        #[cfg(not(feature = "checkpoint"))] _initial_state: Option<()>,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        // Control channel is unbounded; the outbound request channel is
        // bounded (100) so slow crawlers exert backpressure on dispatch.
        let (tx, rx_internal) = unbounded_async();
        let (tx_out, rx_out) = bounded_async(100);

        let queue: SegQueue<Request>;
        let visited: Cache<String, bool>;
        let pending: AtomicUsize;
        let salvaged: SegQueue<Request>;

        #[cfg(feature = "checkpoint")]
        {
            if let Some(state) = initial_state {
                info!(
                    "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                    state.request_queue.len(),
                    state.visited_urls.len(),
                    state.salvaged_requests.len(),
                );
                // Pending counts both queued and salvaged restored requests.
                let pend = state.request_queue.len() + state.salvaged_requests.len();
                queue = SegQueue::new();
                for request in state.request_queue {
                    queue.push(request);
                }

                visited = Cache::builder()
                    .max_capacity(VISITED_URL_CACHE_CAPACITY)
                    .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                    .eviction_listener(|_key, _value, _cause| {})
                    .build();
                for url in state.visited_urls {
                    visited.insert(url, true);
                }

                pending = AtomicUsize::new(pend);
                salvaged = SegQueue::new();
                for request in state.salvaged_requests {
                    salvaged.push(request);
                }
            } else {
                // Fresh start under the checkpoint feature.
                // NOTE(review): this branch uses DEFAULT_VISITED_CACHE_SIZE
                // and sets no TTL / eviction listener, unlike the other two
                // construction paths — confirm the inconsistency is intended.
                queue = SegQueue::new();
                visited = Cache::builder()
                    .max_capacity(DEFAULT_VISITED_CACHE_SIZE)
                    .build();
                pending = AtomicUsize::new(0);
                salvaged = SegQueue::new();
            }
        }

        #[cfg(not(feature = "checkpoint"))]
        {
            queue = SegQueue::new();
            visited = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {})
                .build();
            pending = AtomicUsize::new(0);
            salvaged = SegQueue::new();
        }

        let buffer = Arc::new(Mutex::new(Vec::new()));
        let notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            queue,
            visited,
            bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            buffer: buffer.clone(),
            notify: notify.clone(),
            tx,
            pending,
            salvaged,
            is_shutting_down: AtomicBool::new(false),
            max_pending: MAX_PENDING_REQUESTS,
        });

        // Background task: folds buffered fingerprints into the bloom filter
        // periodically and whenever woken via `notify`.
        let scheduler_bloom = Arc::clone(&scheduler);
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            scheduler_bloom.flush_buffer(notify_clone).await;
        });

        // Background task: the main scheduling loop.
        let scheduler_task = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_task.run_loop(rx_internal, tx_out).await;
        });

        (scheduler, rx_out)
    }
249
    /// Main scheduling loop: drains internal control messages and dispatches
    /// queued requests to crawlers via `tx_out`.
    ///
    /// Control messages are given priority over dispatch so visited marks
    /// and shutdown signals are not starved by a full queue. Exits when
    /// `handle_message` returns `false` (shutdown, or the channel closed).
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_out: AsyncSender<Request>,
    ) {
        info!(
            "Scheduler run_loop started with max pending: {}",
            self.max_pending
        );
        loop {
            // Drain any already-buffered control message first.
            if let Ok(Some(msg)) = rx_internal.try_recv() {
                trace!("Processing pending internal message");
                if !self.handle_message(Ok(msg)).await {
                    break;
                }
                continue;
            }

            // Only pop a request if a receiver exists and work is pending.
            let request = if !tx_out.is_closed() && !self.is_idle() {
                self.queue.pop()
            } else {
                None
            };

            if let Some(request) = request {
                trace!("Sending request to crawler: {}", request.url);
                // Race the (possibly blocking) outbound send against incoming
                // control messages so the loop stays responsive while
                // crawlers are busy.
                tokio::select! {
                    send_res = tx_out.send(request) => {
                        if send_res.is_err() {
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        } else {
                            trace!("Successfully sent request to crawler");
                        }
                        // Decremented even on send failure: the request has
                        // left the queue either way.
                        self.pending.fetch_sub(1, Ordering::AcqRel);
                    },
                    recv_res = rx_internal.recv() => {
                        trace!("Received internal message while sending request");
                        // NOTE(review): if this branch wins, the cancelled
                        // send future drops the popped request and `pending`
                        // is not decremented — confirm this loss is intended
                        // or re-queue the request here.
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                        continue;
                    }
                }
            } else {
                // Idle: block until the next control message arrives.
                trace!("No pending requests, waiting for internal message");
                if !self.handle_message(rx_internal.recv().await).await {
                    break;
                }
            }
        }
        info!(
            "Scheduler run_loop finished with {} pending requests remaining.",
            self.pending.load(Ordering::SeqCst)
        );
    }
305
    /// Processes one internal control message.
    ///
    /// Returns `true` to keep the run loop going, `false` to stop it
    /// (shutdown requested, or the channel reported an error/closure).
    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
        match msg {
            Ok(SchedulerMessage::Enqueue(arc_request)) => {
                // Take ownership of the request; this only clones if another
                // Arc reference is still alive on the sender side.
                let request = Arc::unwrap_or_clone(arc_request);
                trace!("Enqueuing request: {}", request.url);
                self.queue.push(request);
                self.pending.fetch_add(1, Ordering::AcqRel);
                true
            }
            Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
                trace!("Marking URL fingerprint as visited: {}", fingerprint);

                // Exact cache is updated before staging so lookups can see
                // the mark without waiting for the bloom flush.
                self.visited.insert(fingerprint.clone(), true);

                debug!("Marked URL as visited: {}", fingerprint);

                {
                    // Stage for the bloom filter; wake the flush task once
                    // the buffer reaches the flush threshold.
                    let mut buffer = self.buffer.lock();
                    buffer.push(fingerprint);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        self.notify.notify_one();
                    }
                }

                true
            }
            Ok(SchedulerMessage::MarkAsVisitedBatch(mut fingerprints)) => {
                let count = fingerprints.len();
                trace!("Marking {} URL fingerprints as visited in batch", count);

                for fingerprint in &fingerprints {
                    self.visited.insert(fingerprint.clone(), true);
                }

                {
                    // Move the whole batch into the staging buffer at once.
                    let mut buffer = self.buffer.lock();
                    buffer.append(&mut fingerprints);
                    if buffer.len() >= BLOOM_BUFFER_FLUSH_SIZE {
                        self.notify.notify_one();
                    }
                }

                debug!("Marked {} URLs as visited in batch", count);
                true
            }
            Ok(SchedulerMessage::Shutdown) => {
                info!("Scheduler received shutdown signal. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                // Persist any staged fingerprints before the loop exits.
                self.flush_buffer_now();
                false
            }
            Err(_) => {
                warn!("Scheduler internal message channel closed. Exiting run_loop.");
                self.is_shutting_down.store(true, Ordering::SeqCst);
                false
            }
        }
    }
370
371 #[cfg(feature = "checkpoint")]
372 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
373 let visited_urls = dashmap::DashSet::new();
374 for entry in self.visited.iter() {
375 let (key, _) = entry;
376 visited_urls.insert(key.as_ref().clone());
377 }
378
379 let mut request_queue = std::collections::VecDeque::new();
380 let mut temp_requests = Vec::new();
381
382 while let Some(request) = self.queue.pop() {
383 temp_requests.push(request);
384 }
385
386 for request in temp_requests.into_iter() {
387 request_queue.push_back(request.clone());
388 if !self.is_shutting_down.load(Ordering::SeqCst) {
389 self.queue.push(request);
390 }
391 }
392
393 let mut salvaged_requests = std::collections::VecDeque::new();
394 let mut temp_salvaged = Vec::new();
395
396 while let Some(request) = self.salvaged.pop() {
397 temp_salvaged.push(request);
398 }
399
400 for request in temp_salvaged.into_iter() {
401 salvaged_requests.push_back(request.clone());
402 if !self.is_shutting_down.load(Ordering::SeqCst) {
403 self.salvaged.push(request);
404 }
405 }
406
407 Ok(SchedulerCheckpoint {
408 request_queue,
409 visited_urls,
410 salvaged_requests,
411 })
412 }
413
    /// No-op snapshot stub used when the `checkpoint` feature is disabled;
    /// always succeeds with `()`.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
418
419 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
420 if !self.should_enqueue(&request) {
421 trace!("Request already visited, skipping: {}", request.url);
422 return Ok(());
423 }
424
425 let pending = self.pending.load(Ordering::SeqCst);
426 if pending >= self.max_pending {
427 warn!(
428 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
429 self.max_pending, request.url
430 );
431 return Err(SpiderError::GeneralError(
432 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
433 ));
434 }
435
436 trace!("Enqueuing request: {}", request.url);
437 if self
438 .tx
439 .send(SchedulerMessage::Enqueue(Arc::new(request.clone())))
440 .await
441 .is_err()
442 {
443 if !self.is_shutting_down.load(Ordering::SeqCst) {
444 error!(
445 "Scheduler internal message channel is closed. Salvaging request: {}",
446 request.url
447 );
448 }
449 self.salvaged.push(request);
450 return Err(SpiderError::GeneralError(
451 "Scheduler internal channel closed, request salvaged.".into(),
452 ));
453 }
454
455 trace!("Successfully enqueued request: {}", request.url);
456 Ok(())
457 }
458
459 pub async fn shutdown(&self) -> Result<(), SpiderError> {
465 self.is_shutting_down.store(true, Ordering::SeqCst);
466
467 if !self.tx.is_closed() {
468 self.tx.send(SchedulerMessage::Shutdown).await.map_err(|e| {
469 SpiderError::GeneralError(format!(
470 "Scheduler: Failed to send shutdown signal: {}",
471 e
472 ))
473 })
474 } else {
475 debug!("Scheduler internal channel already closed, skipping shutdown signal");
476 Ok(())
477 }
478 }
479
480 pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
486 trace!(
487 "Sending MarkAsVisited message for fingerprint: {}",
488 fingerprint
489 );
490 self.tx
491 .send(SchedulerMessage::MarkAsVisited(fingerprint))
492 .await
493 .map_err(|e| {
494 if !self.is_shutting_down.load(Ordering::SeqCst) {
495 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
496 }
497 SpiderError::GeneralError(format!(
498 "Scheduler: Failed to send MarkAsVisited message: {}",
499 e
500 ))
501 })
502 }
503
504 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
512 if fingerprints.is_empty() {
513 return Ok(());
514 }
515
516 trace!(
517 "Sending MarkAsVisitedBatch message for {} fingerprints",
518 fingerprints.len()
519 );
520 self.tx
521 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
522 .await
523 .map_err(|e| {
524 if !self.is_shutting_down.load(Ordering::SeqCst) {
525 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
526 }
527 SpiderError::GeneralError(format!(
528 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
529 e
530 ))
531 })
532 }
533
534 pub fn is_visited(&self, fingerprint: &str) -> bool {
536 if !self.bloom.read().might_contain(fingerprint) {
537 return false;
538 }
539
540 {
541 let buffer = self.buffer.lock();
542 if buffer.iter().any(|item| item == fingerprint) {
543 return true;
544 }
545 }
546
547 self.visited.contains_key(fingerprint)
548 }
549
550 fn flush_buffer_now(&self) {
551 let mut buffer = self.buffer.lock();
552 if !buffer.is_empty() {
553 let items: Vec<String> = buffer.drain(..).collect();
554 drop(buffer);
555
556 let mut bloom = self.bloom.write();
557 for item in items {
558 bloom.add(&item);
559 }
560 }
561 }
562
563 async fn flush_buffer(&self, notify: Arc<Notify>) {
564 loop {
565 tokio::select! {
566 _ = notify.notified() => {
567 self.flush_buffer_now();
568 }
569 _ = tokio::time::sleep(tokio::time::Duration::from_millis(BLOOM_FLUSH_INTERVAL_MS)) => {
570 self.flush_buffer_now();
571 }
572 }
573
574 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
575 }
576 }
577
578 pub fn should_enqueue(&self, request: &Request) -> bool {
580 let fingerprint = request.fingerprint();
581 !self.is_visited(&fingerprint)
582 }
583
    /// Number of requests accepted but not yet dispatched to a crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::Acquire)
    }
589
    /// Returns `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
595
    /// Returns `true` when the scheduler has no queued work (alias of
    /// `is_empty`).
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
601}