1#[cfg(feature = "checkpoint")]
45use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
46#[cfg(feature = "checkpoint")]
47use crate::SchedulerCheckpoint;
48
49#[cfg(not(feature = "checkpoint"))]
50use spider_util::constants::MAX_PENDING_REQUESTS;
51
52use spider_util::constants::{
53 BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
54 VISITED_URL_CACHE_CAPACITY, VISITED_URL_CACHE_TTL_SECS,
55};
56use spider_util::error::SpiderError;
57use spider_util::request::Request;
58use crossbeam::queue::SegQueue;
59use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
60use moka::sync::Cache;
61use std::sync::Arc;
62use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
63use log::{debug, error, info, trace, warn};
64
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Queue a request for dispatch (boxed to keep the enum small).
    Enqueue(Box<Request>),
    /// Record one URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record many URL fingerprints as visited in a single message.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the run loop; triggers a final bloom-filter buffer flush.
    Shutdown,
}
71
72use spider_util::bloom_filter::BloomFilter;
73
74use tokio::sync::Notify;
75
/// Central request scheduler: deduplicates URLs, queues pending requests,
/// and feeds them to the crawler through a bounded channel.
pub struct Scheduler {
    // Lock-free FIFO of requests waiting to be sent to the crawler.
    request_queue: SegQueue<Request>,
    // TTL cache of recently visited URL fingerprints (value is always `true`).
    visited_urls: Cache<String, bool>,
    // Probabilistic long-term record of visited fingerprints; populated in
    // batches from `bloom_filter_buffer` by a background flush task.
    bloom_filter: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    // Staging buffer of fingerprints not yet folded into the bloom filter.
    bloom_filter_buffer: Arc<std::sync::Mutex<Vec<String>>>,
    // Wakes the background flush task when the staging buffer grows large.
    bloom_filter_notify: Arc<Notify>,
    // Sender for internal control messages handled by `run_loop`.
    tx_internal: AsyncSender<SchedulerMessage>,
    // Number of requests enqueued but not yet dispatched (backpressure gauge).
    pending_requests: AtomicUsize,
    // Requests rescued when the internal channel was closed mid-enqueue.
    salvaged_requests: SegQueue<Request>,
    // Set once shutdown begins; read by loops and error-logging paths.
    pub(crate) is_shutting_down: AtomicBool,
    // Cap on `pending_requests`; new requests are rejected beyond this.
    max_pending_requests: usize,
}
88
89impl Scheduler {
    /// Create a scheduler, optionally restoring state from a checkpoint.
    ///
    /// Spawns two detached background tasks on the current tokio runtime:
    /// one flushing the bloom-filter staging buffer and one running the
    /// scheduler loop. Returns the shared handle plus the receiver the
    /// crawler pulls requests from (bounded to 100 in-flight sends).
    #[cfg(feature = "checkpoint")]
    pub fn new(
        initial_state: Option<SchedulerCheckpoint>,
    ) -> (Arc<Self>, AsyncReceiver<Request>) {
        // Unbounded control channel; bounded crawler-facing channel.
        let (tx_internal, rx_internal) = unbounded_async();

        let (tx_req_out, rx_req_out) = bounded_async(100);

        let request_queue: SegQueue<Request>;
        let visited_urls: Cache<String, bool>;
        let pending_requests: AtomicUsize;
        let salvaged_requests: SegQueue<Request>;

        if let Some(state) = initial_state {
            info!(
                "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                state.request_queue.len(),
                state.visited_urls.len(),
                state.salvaged_requests.len(),
            );
            // Both queued and salvaged requests count toward backpressure.
            let pending = state.request_queue.len() + state.salvaged_requests.len();
            request_queue = SegQueue::new();
            for request in state.request_queue {
                request_queue.push(request);
            }

            // NOTE(review): the eviction listener is a no-op — presumably a
            // placeholder hook; confirm it is intentional.
            visited_urls = Cache::builder()
                .max_capacity(VISITED_URL_CACHE_CAPACITY)
                .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
                .eviction_listener(|_key, _value, _cause| {
                })
                .build();
            for url in state.visited_urls {
                visited_urls.insert(url, true);
            }

            pending_requests = AtomicUsize::new(pending);
            salvaged_requests = SegQueue::new();
            for request in state.salvaged_requests {
                salvaged_requests.push(request);
            }
        } else {
            // Fresh start: empty queues, default-sized visited cache.
            // NOTE(review): this branch uses DEFAULT_VISITED_CACHE_SIZE and no
            // TTL, unlike the checkpoint branch above — confirm the asymmetry
            // is intended.
            request_queue = SegQueue::new();
            visited_urls = Cache::builder().max_capacity(DEFAULT_VISITED_CACHE_SIZE).build();
            pending_requests = AtomicUsize::new(0);
            salvaged_requests = SegQueue::new();
        }

        let bloom_filter_buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let bloom_filter_notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            request_queue,
            visited_urls,
            bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            bloom_filter_buffer: bloom_filter_buffer.clone(),
            bloom_filter_notify: bloom_filter_notify.clone(),
            tx_internal,
            pending_requests,
            salvaged_requests,
            is_shutting_down: AtomicBool::new(false),
            // NOTE(review): hard-coded here while the non-checkpoint
            // constructor uses MAX_PENDING_REQUESTS — unify if they should
            // be the same limit.
            max_pending_requests: 30000, });

        // Background task: periodically folds the staging buffer into the
        // bloom filter.
        let scheduler_clone_for_bloom = Arc::clone(&scheduler);
        let bloom_filter_buffer_clone = bloom_filter_buffer.clone();
        let bloom_filter_notify_clone = bloom_filter_notify.clone();
        tokio::spawn(async move {
            scheduler_clone_for_bloom.flush_bloom_filter_buffer(bloom_filter_buffer_clone, bloom_filter_notify_clone).await;
        });

        // Background task: the main scheduling loop.
        let scheduler_clone = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_clone.run_loop(rx_internal, tx_req_out).await;
        });

        (scheduler, rx_req_out)
    }
174
    /// Create a scheduler with empty state (no `checkpoint` feature).
    ///
    /// Spawns two detached background tasks on the current tokio runtime:
    /// one flushing the bloom-filter staging buffer and one running the
    /// scheduler loop. Returns the shared handle plus the receiver the
    /// crawler pulls requests from (bounded to 100 in-flight sends).
    #[cfg(not(feature = "checkpoint"))]
    pub fn new(
        _initial_state: Option<()>, ) -> (Arc<Self>, AsyncReceiver<Request>) {
        // Unbounded control channel; bounded crawler-facing channel.
        let (tx_internal, rx_internal) = unbounded_async();

        let (tx_req_out, rx_req_out) = bounded_async(100);

        let request_queue = SegQueue::new();
        // TTL cache of visited fingerprints.
        // NOTE(review): the eviction listener is a no-op — presumably a
        // placeholder hook; confirm it is intentional.
        let visited_urls = Cache::builder()
            .max_capacity(VISITED_URL_CACHE_CAPACITY)
            .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
            .eviction_listener(|_key, _value, _cause| {
            })
            .build();
        let pending_requests = AtomicUsize::new(0);
        let salvaged_requests = SegQueue::new();

        let bloom_filter_buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let bloom_filter_notify = Arc::new(Notify::new());

        let scheduler = Arc::new(Scheduler {
            request_queue,
            visited_urls,
            bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
                BLOOM_FILTER_CAPACITY,
                BLOOM_FILTER_HASH_FUNCTIONS,
            ))),
            bloom_filter_buffer: bloom_filter_buffer.clone(),
            bloom_filter_notify: bloom_filter_notify.clone(),
            tx_internal,
            pending_requests,
            salvaged_requests,
            is_shutting_down: AtomicBool::new(false),
            max_pending_requests: MAX_PENDING_REQUESTS,
        });

        // Background task: periodically folds the staging buffer into the
        // bloom filter.
        let scheduler_clone_for_bloom = Arc::clone(&scheduler);
        let bloom_filter_buffer_clone = bloom_filter_buffer.clone();
        let bloom_filter_notify_clone = bloom_filter_notify.clone();
        tokio::spawn(async move {
            scheduler_clone_for_bloom.flush_bloom_filter_buffer(bloom_filter_buffer_clone, bloom_filter_notify_clone).await;
        });

        // Background task: the main scheduling loop.
        let scheduler_clone = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_clone.run_loop(rx_internal, tx_req_out).await;
        });

        (scheduler, rx_req_out)
    }
229
230 async fn run_loop(
231 &self,
232 rx_internal: AsyncReceiver<SchedulerMessage>,
233 tx_req_out: AsyncSender<Request>,
234 ) {
235 info!(
236 "Scheduler run_loop started with max pending requests: {}",
237 self.max_pending_requests
238 );
239 loop {
240 if let Ok(Some(msg)) = rx_internal.try_recv() {
241 trace!("Processing pending internal message");
242 if !self.handle_message(Ok(msg)).await {
243 break;
244 }
245 continue;
246 }
247
248 let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
249 self.request_queue.pop()
250 } else {
251 None
252 };
253
254 if let Some(request) = maybe_request {
255 trace!("Sending request to crawler: {}", request.url);
256 tokio::select! {
257 send_res = tx_req_out.send(request) => {
258 if send_res.is_err() {
259 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
260 } else {
261 trace!("Successfully sent request to crawler");
262 }
263 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
264 },
265 recv_res = rx_internal.recv() => {
266 trace!("Received internal message while sending request");
267 if !self.handle_message(recv_res).await {
268 break;
269 }
270 continue;
271 }
272 }
273 } else {
274 trace!("No pending requests, waiting for internal message");
275 if !self.handle_message(rx_internal.recv().await).await {
276 break;
277 }
278 }
279 }
280 info!(
281 "Scheduler run_loop finished with {} pending requests remaining.",
282 self.pending_requests.load(Ordering::SeqCst)
283 );
284 }
285
286 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
287 match msg {
288 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
289 let request = *boxed_request;
290 trace!("Enqueuing request: {}", request.url);
291 self.request_queue.push(request);
292 self.pending_requests.fetch_add(1, Ordering::SeqCst);
293 true
294 }
295 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
296 trace!("Marking URL fingerprint as visited: {}", fingerprint);
297 self.visited_urls.insert(fingerprint.clone(), true);
298
299 {
301 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
302 buffer.push(fingerprint.clone());
303 if buffer.len() >= 100 { self.bloom_filter_notify.notify_one();
305 }
306 }
307
308 debug!("Marked URL as visited: {}", fingerprint);
309 true
310 }
311 Ok(SchedulerMessage::MarkAsVisitedBatch(fingerprints)) => {
312 let count = fingerprints.len();
313 trace!("Marking {} URL fingerprints as visited in batch", count);
314 for fingerprint in &fingerprints {
315 self.visited_urls.insert(fingerprint.clone(), true);
316 }
317
318 {
320 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
321 buffer.extend(fingerprints);
322 if buffer.len() >= 100 { self.bloom_filter_notify.notify_one();
324 }
325 }
326
327 debug!("Marked {} URLs as visited in batch", count);
328 true
329 }
330 Ok(SchedulerMessage::Shutdown) => {
331 info!("Scheduler received shutdown signal. Exiting run_loop.");
332 self.is_shutting_down.store(true, Ordering::SeqCst);
333
334 self.flush_bloom_filter_buffer_now();
336
337 false
338 }
339 Err(_) => {
340 warn!("Scheduler internal message channel closed. Exiting run_loop.");
341 self.is_shutting_down.store(true, Ordering::SeqCst);
342 false
343 }
344 }
345 }
346
    /// Capture the scheduler's state for persistence.
    ///
    /// Both lock-free queues are drained to read their contents (SegQueue
    /// has no iterator) and then refilled in order — unless shutdown has
    /// begun, in which case the contents are not pushed back.
    /// NOTE(review): this drain/refill is not atomic with respect to
    /// `run_loop` popping concurrently; presumably callers pause the crawl
    /// first — confirm.
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        // Copy visited fingerprints out of the moka cache into a
        // serializable set.
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited_urls.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        // Drain the request queue into a temporary buffer…
        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        while let Some(request) = self.request_queue.pop() {
            temp_requests.push(request);
        }

        // …record each request in the checkpoint, and restore it to the
        // live queue only if the scheduler is still running.
        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.request_queue.push(request);
            }
        }

        // Same drain/record/restore dance for salvaged requests.
        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        while let Some(request) = self.salvaged_requests.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged_requests.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
398
    /// No-op snapshot stub used when the `checkpoint` feature is disabled,
    /// so callers can invoke `snapshot()` unconditionally.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
405
406 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
408 if !self.should_enqueue_request(&request) {
409 trace!("Request already visited, skipping: {}", request.url);
410 return Ok(());
411 }
412
413 let current_pending = self.pending_requests.load(Ordering::SeqCst);
414 if current_pending >= self.max_pending_requests {
415 warn!(
416 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
417 self.max_pending_requests, request.url
418 );
419 return Err(SpiderError::GeneralError(
420 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
421 ));
422 }
423
424 trace!("Enqueuing request: {}", request.url);
425 if self
426 .tx_internal
427 .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
428 .await
429 .is_err()
430 {
431 if !self.is_shutting_down.load(Ordering::SeqCst) {
432 error!(
433 "Scheduler internal message channel is closed. Salvaging request: {}",
434 request.url
435 );
436 }
437 self.salvaged_requests.push(request);
438 return Err(SpiderError::GeneralError(
439 "Scheduler internal channel closed, request salvaged.".into(),
440 ));
441 }
442
443 trace!("Successfully enqueued request: {}", request.url);
444 Ok(())
445 }
446
447 pub async fn shutdown(&self) -> Result<(), SpiderError> {
449 self.is_shutting_down.store(true, Ordering::SeqCst);
450
451 if !self.tx_internal.is_closed() {
452 self.tx_internal
453 .send(SchedulerMessage::Shutdown)
454 .await
455 .map_err(|e| {
456 SpiderError::GeneralError(format!(
457 "Scheduler: Failed to send shutdown signal: {}",
458 e
459 ))
460 })
461 } else {
462 debug!("Scheduler internal channel already closed, skipping shutdown signal");
463 Ok(())
464 }
465 }
466
467 pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
469 trace!(
470 "Sending MarkAsVisited message for fingerprint: {}",
471 fingerprint
472 );
473 self.tx_internal
474 .send(SchedulerMessage::MarkAsVisited(fingerprint.clone()))
475 .await
476 .map_err(|e| {
477 if !self.is_shutting_down.load(Ordering::SeqCst) {
478 error!("Scheduler internal message channel is closed. Failed to mark URL as visited (fingerprint: {}): {}", fingerprint, e);
479 }
480 SpiderError::GeneralError(format!(
481 "Scheduler: Failed to send MarkAsVisited message: {}",
482 e
483 ))
484 })
485 }
486
487 pub async fn send_mark_as_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
489 if fingerprints.is_empty() {
490 return Ok(());
491 }
492
493 trace!(
494 "Sending MarkAsVisitedBatch message for {} fingerprints",
495 fingerprints.len()
496 );
497 self.tx_internal
498 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
499 .await
500 .map_err(|e| {
501 if !self.is_shutting_down.load(Ordering::SeqCst) {
502 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
503 }
504 SpiderError::GeneralError(format!(
505 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
506 e
507 ))
508 })
509 }
510
511 pub fn has_been_visited(&self, fingerprint: &str) -> bool {
513 if !self.bloom_filter.read().might_contain(fingerprint) {
514 return false;
515 }
516
517 {
518 let buffer = self.bloom_filter_buffer.lock().unwrap();
519 if buffer.iter().any(|item| item == fingerprint) {
520 return true;
521 }
522 }
523
524 self.visited_urls.contains_key(fingerprint)
525 }
526
527 fn flush_bloom_filter_buffer_now(&self) {
529 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
530 if !buffer.is_empty() {
531 let items: Vec<String> = buffer.drain(..).collect();
532 drop(buffer); let mut bloom_filter = self.bloom_filter.write();
535 for item in items {
536 bloom_filter.add(&item);
537 }
538 }
539 }
540
541 async fn flush_bloom_filter_buffer(
543 &self,
544 _buffer: Arc<std::sync::Mutex<Vec<String>>>,
545 notify: Arc<Notify>,
546 ) {
547 loop {
548 tokio::select! {
550 _ = notify.notified() => {
551 self.flush_bloom_filter_buffer_now();
553 }
554 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
555 self.flush_bloom_filter_buffer_now();
557 }
558 }
559
560 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
562 }
563 }
564
565 pub fn should_enqueue_request(&self, request: &Request) -> bool {
567 let fingerprint = request.fingerprint();
568 !self.has_been_visited(&fingerprint)
569 }
570
571 #[inline]
573 pub fn len(&self) -> usize {
574 self.pending_requests.load(Ordering::SeqCst)
575 }
576
577 #[inline]
579 pub fn is_empty(&self) -> bool {
580 self.len() == 0
581 }
582
583 #[inline]
585 pub fn is_idle(&self) -> bool {
586 self.is_empty()
587 }
588}