1#[cfg(feature = "checkpoint")]
45use crate::SchedulerCheckpoint;
46
47use spider_util::error::SpiderError;
48use spider_util::request::Request;
49use crossbeam::queue::SegQueue;
50use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
51use moka::sync::Cache;
52use std::sync::Arc;
53use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
54use tracing::{debug, error, info, trace, warn};
55
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add a request to the queue (boxed to keep the enum variant small).
    Enqueue(Box<Request>),
    /// Record a single URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record many URL fingerprints as visited in one message.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the run loop.
    Shutdown,
}
62
63use spider_util::bloom_filter::BloomFilter;
64
65use tokio::sync::Notify;
66
/// Central crawl scheduler: owns the pending-request queue, visited-URL
/// tracking (moka cache + bloom filter), and backpressure accounting.
pub struct Scheduler {
    /// Lock-free FIFO of requests waiting to be handed to the crawler.
    request_queue: SegQueue<Request>,
    /// Bounded cache of recently visited URL fingerprints.
    visited_urls: Cache<String, bool>,
    /// Probabilistic long-term record of visited fingerprints.
    bloom_filter: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints waiting to be folded into the bloom filter by the
    /// background flush task.
    bloom_filter_buffer: Arc<std::sync::Mutex<Vec<String>>>,
    /// Wakes the flush task early when the buffer grows large.
    bloom_filter_notify: Arc<Notify>,
    /// Producer side of the internal control channel drained by `run_loop`.
    tx_internal: AsyncSender<SchedulerMessage>,
    /// Requests accepted but not yet delivered to the crawler.
    pending_requests: AtomicUsize,
    /// Requests rescued when the internal channel was closed; recovered via
    /// checkpointing when that feature is enabled.
    salvaged_requests: SegQueue<Request>,
    /// Set once shutdown begins; used to silence expected channel errors.
    pub(crate) is_shutting_down: AtomicBool,
    /// Backpressure cap: `enqueue_request` fails above this pending count.
    max_pending_requests: usize,
}
79
80impl Scheduler {
81 #[cfg(feature = "checkpoint")]
83 pub fn new(
84 initial_state: Option<SchedulerCheckpoint>,
85 ) -> (Arc<Self>, AsyncReceiver<Request>) {
86 let (tx_internal, rx_internal) = unbounded_async();
87
88 let (tx_req_out, rx_req_out) = bounded_async(100);
89
90 let request_queue: SegQueue<Request>;
91 let visited_urls: Cache<String, bool>;
92 let pending_requests: AtomicUsize;
93 let salvaged_requests: SegQueue<Request>;
94
95 if let Some(state) = initial_state {
96 info!(
97 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
98 state.request_queue.len(),
99 state.visited_urls.len(),
100 state.salvaged_requests.len(),
101 );
102 let pending = state.request_queue.len() + state.salvaged_requests.len();
103 request_queue = SegQueue::new();
104 for request in state.request_queue {
105 request_queue.push(request);
106 }
107
108 visited_urls = Cache::builder()
109 .max_capacity(500000) .time_to_idle(std::time::Duration::from_secs(3600)) .eviction_listener(|_key, _value, _cause| {
112 })
114 .build();
115 for url in state.visited_urls {
116 visited_urls.insert(url, true);
117 }
118
119 pending_requests = AtomicUsize::new(pending);
120 salvaged_requests = SegQueue::new();
121 for request in state.salvaged_requests {
122 salvaged_requests.push(request);
123 }
124 } else {
125 request_queue = SegQueue::new();
126 visited_urls = Cache::builder().max_capacity(200000).build(); pending_requests = AtomicUsize::new(0);
128 salvaged_requests = SegQueue::new();
129 }
130
131 let bloom_filter_buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
132 let bloom_filter_notify = Arc::new(Notify::new());
133
134 let scheduler = Arc::new(Scheduler {
135 request_queue,
136 visited_urls,
137 bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(5000000, 5))), bloom_filter_buffer: bloom_filter_buffer.clone(),
139 bloom_filter_notify: bloom_filter_notify.clone(),
140 tx_internal,
141 pending_requests,
142 salvaged_requests,
143 is_shutting_down: AtomicBool::new(false),
144 max_pending_requests: 30000, });
146
147 let scheduler_clone_for_bloom = Arc::clone(&scheduler);
149 let bloom_filter_buffer_clone = bloom_filter_buffer.clone();
150 let bloom_filter_notify_clone = bloom_filter_notify.clone();
151 tokio::spawn(async move {
152 scheduler_clone_for_bloom.flush_bloom_filter_buffer(bloom_filter_buffer_clone, bloom_filter_notify_clone).await;
153 });
154
155 let scheduler_clone = Arc::clone(&scheduler);
156 tokio::spawn(async move {
157 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
158 });
159
160 (scheduler, rx_req_out)
161 }
162
163 #[cfg(not(feature = "checkpoint"))]
165 pub fn new(
166 _initial_state: Option<()>, ) -> (Arc<Self>, AsyncReceiver<Request>) {
168 let (tx_internal, rx_internal) = unbounded_async();
169
170 let (tx_req_out, rx_req_out) = bounded_async(100);
171
172 let request_queue = SegQueue::new();
173 let visited_urls = Cache::builder()
174 .max_capacity(500000) .time_to_idle(std::time::Duration::from_secs(3600)) .eviction_listener(|_key, _value, _cause| {
177 })
179 .build();
180 let pending_requests = AtomicUsize::new(0);
181 let salvaged_requests = SegQueue::new();
182
183 let bloom_filter_buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
184 let bloom_filter_notify = Arc::new(Notify::new());
185
186 let scheduler = Arc::new(Scheduler {
187 request_queue,
188 visited_urls,
189 bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(5000000, 5))), bloom_filter_buffer: bloom_filter_buffer.clone(),
191 bloom_filter_notify: bloom_filter_notify.clone(),
192 tx_internal,
193 pending_requests,
194 salvaged_requests,
195 is_shutting_down: AtomicBool::new(false),
196 max_pending_requests: 30000, });
198
199 let scheduler_clone_for_bloom = Arc::clone(&scheduler);
201 let bloom_filter_buffer_clone = bloom_filter_buffer.clone();
202 let bloom_filter_notify_clone = bloom_filter_notify.clone();
203 tokio::spawn(async move {
204 scheduler_clone_for_bloom.flush_bloom_filter_buffer(bloom_filter_buffer_clone, bloom_filter_notify_clone).await;
205 });
206
207 let scheduler_clone = Arc::clone(&scheduler);
208 tokio::spawn(async move {
209 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
210 });
211
212 (scheduler, rx_req_out)
213 }
214
215 async fn run_loop(
216 &self,
217 rx_internal: AsyncReceiver<SchedulerMessage>,
218 tx_req_out: AsyncSender<Request>,
219 ) {
220 info!(
221 "Scheduler run_loop started with max pending requests: {}",
222 self.max_pending_requests
223 );
224 loop {
225 if let Ok(Some(msg)) = rx_internal.try_recv() {
226 trace!("Processing pending internal message");
227 if !self.handle_message(Ok(msg)).await {
228 break;
229 }
230 continue;
231 }
232
233 let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
234 self.request_queue.pop()
235 } else {
236 None
237 };
238
239 if let Some(request) = maybe_request {
240 trace!("Sending request to crawler: {}", request.url);
241 tokio::select! {
242 send_res = tx_req_out.send(request) => {
243 if send_res.is_err() {
244 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
245 } else {
246 trace!("Successfully sent request to crawler");
247 }
248 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
249 },
250 recv_res = rx_internal.recv() => {
251 trace!("Received internal message while sending request");
252 if !self.handle_message(recv_res).await {
253 break;
254 }
255 continue;
256 }
257 }
258 } else {
259 trace!("No pending requests, waiting for internal message");
260 if !self.handle_message(rx_internal.recv().await).await {
261 break;
262 }
263 }
264 }
265 info!(
266 "Scheduler run_loop finished with {} pending requests remaining.",
267 self.pending_requests.load(Ordering::SeqCst)
268 );
269 }
270
271 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
272 match msg {
273 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
274 let request = *boxed_request;
275 trace!("Enqueuing request: {}", request.url);
276 self.request_queue.push(request);
277 self.pending_requests.fetch_add(1, Ordering::SeqCst);
278 true
279 }
280 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
281 trace!("Marking URL fingerprint as visited: {}", fingerprint);
282 self.visited_urls.insert(fingerprint.clone(), true);
283
284 {
286 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
287 buffer.push(fingerprint.clone());
288 if buffer.len() >= 100 { self.bloom_filter_notify.notify_one();
290 }
291 }
292
293 debug!("Marked URL as visited: {}", fingerprint);
294 true
295 }
296 Ok(SchedulerMessage::MarkAsVisitedBatch(fingerprints)) => {
297 let count = fingerprints.len();
298 trace!("Marking {} URL fingerprints as visited in batch", count);
299 for fingerprint in &fingerprints {
300 self.visited_urls.insert(fingerprint.clone(), true);
301 }
302
303 {
305 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
306 buffer.extend(fingerprints);
307 if buffer.len() >= 100 { self.bloom_filter_notify.notify_one();
309 }
310 }
311
312 debug!("Marked {} URLs as visited in batch", count);
313 true
314 }
315 Ok(SchedulerMessage::Shutdown) => {
316 info!("Scheduler received shutdown signal. Exiting run_loop.");
317 self.is_shutting_down.store(true, Ordering::SeqCst);
318
319 self.flush_bloom_filter_buffer_now();
321
322 false
323 }
324 Err(_) => {
325 warn!("Scheduler internal message channel closed. Exiting run_loop.");
326 self.is_shutting_down.store(true, Ordering::SeqCst);
327 false
328 }
329 }
330 }
331
    #[cfg(feature = "checkpoint")]
    /// Captures the scheduler state (pending requests, visited URLs, salvaged
    /// requests) for persistence.
    ///
    /// Both SegQueues are copied by draining them and pushing the items back,
    /// since `SegQueue` cannot be iterated in place. Items are only restored
    /// when no shutdown is in progress; during shutdown the checkpoint becomes
    /// the sole owner of the drained requests.
    ///
    /// NOTE(review): draining is not atomic with respect to `run_loop` —
    /// requests popped or enqueued concurrently can be missed or reordered.
    /// Presumably acceptable for best-effort checkpoints; confirm callers
    /// quiesce the scheduler first if an exact snapshot is required.
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited_urls.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        // Drain the request queue into a temporary list, then clone each item
        // into the checkpoint while re-queueing the original (unless shutting
        // down).
        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        while let Some(request) = self.request_queue.pop() {
            temp_requests.push(request);
        }

        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.request_queue.push(request);
            }
        }

        // Same drain-and-restore pass for salvaged requests.
        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        while let Some(request) = self.salvaged_requests.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged_requests.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
383
    #[cfg(not(feature = "checkpoint"))]
    /// No-op stub keeping the `snapshot` call site valid when the
    /// `checkpoint` feature is disabled.
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
390
391 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
393 if !self.should_enqueue_request(&request) {
394 trace!("Request already visited, skipping: {}", request.url);
395 return Ok(());
396 }
397
398 let current_pending = self.pending_requests.load(Ordering::SeqCst);
400 if current_pending >= self.max_pending_requests {
401 warn!(
402 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
403 self.max_pending_requests, request.url
404 );
405 return Err(SpiderError::GeneralError(
406 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
407 ));
408 }
409
410 trace!("Enqueuing request: {}", request.url);
411 if self
412 .tx_internal
413 .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
414 .await
415 .is_err()
416 {
417 if !self.is_shutting_down.load(Ordering::SeqCst) {
418 error!(
419 "Scheduler internal message channel is closed. Salvaging request: {}",
420 request.url
421 );
422 }
423 self.salvaged_requests.push(request);
424 return Err(SpiderError::GeneralError(
425 "Scheduler internal channel closed, request salvaged.".into(),
426 ));
427 }
428
429 trace!("Successfully enqueued request: {}", request.url);
430 Ok(())
431 }
432
433 pub async fn shutdown(&self) -> Result<(), SpiderError> {
435 self.is_shutting_down.store(true, Ordering::SeqCst);
436
437 if !self.tx_internal.is_closed() {
438 self.tx_internal
439 .send(SchedulerMessage::Shutdown)
440 .await
441 .map_err(|e| {
442 SpiderError::GeneralError(format!(
443 "Scheduler: Failed to send shutdown signal: {}",
444 e
445 ))
446 })
447 } else {
448 debug!("Scheduler internal channel already closed, skipping shutdown signal");
449 Ok(())
450 }
451 }
452
453 pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
455 trace!(
456 "Sending MarkAsVisited message for fingerprint: {}",
457 fingerprint
458 );
459 self.tx_internal
460 .send(SchedulerMessage::MarkAsVisited(fingerprint.clone()))
461 .await
462 .map_err(|e| {
463 if !self.is_shutting_down.load(Ordering::SeqCst) {
464 error!("Scheduler internal message channel is closed. Failed to mark URL as visited (fingerprint: {}): {}", fingerprint, e);
465 }
466 SpiderError::GeneralError(format!(
467 "Scheduler: Failed to send MarkAsVisited message: {}",
468 e
469 ))
470 })
471 }
472
473 pub async fn send_mark_as_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
475 if fingerprints.is_empty() {
476 return Ok(());
477 }
478
479 trace!(
480 "Sending MarkAsVisitedBatch message for {} fingerprints",
481 fingerprints.len()
482 );
483 self.tx_internal
484 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
485 .await
486 .map_err(|e| {
487 if !self.is_shutting_down.load(Ordering::SeqCst) {
488 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
489 }
490 SpiderError::GeneralError(format!(
491 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
492 e
493 ))
494 })
495 }
496
497 pub fn has_been_visited(&self, fingerprint: &str) -> bool {
499 if !self.bloom_filter.read().might_contain(fingerprint) {
501 return false;
502 }
503
504 {
506 let buffer = self.bloom_filter_buffer.lock().unwrap();
507 if buffer.iter().any(|item| item == fingerprint) {
508 return true;
509 }
510 }
511
512 self.visited_urls.contains_key(fingerprint)
513 }
514
515 fn flush_bloom_filter_buffer_now(&self) {
517 let mut buffer = self.bloom_filter_buffer.lock().unwrap();
518 if !buffer.is_empty() {
519 let items: Vec<String> = buffer.drain(..).collect();
520 drop(buffer); let mut bloom_filter = self.bloom_filter.write();
523 for item in items {
524 bloom_filter.add(&item);
525 }
526 }
527 }
528
529 async fn flush_bloom_filter_buffer(
531 &self,
532 _buffer: Arc<std::sync::Mutex<Vec<String>>>,
533 notify: Arc<Notify>,
534 ) {
535 loop {
536 tokio::select! {
538 _ = notify.notified() => {
539 self.flush_bloom_filter_buffer_now();
541 }
542 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
543 self.flush_bloom_filter_buffer_now();
545 }
546 }
547
548 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
550 }
551 }
552
553 pub fn should_enqueue_request(&self, request: &Request) -> bool {
555 let fingerprint = request.fingerprint();
556 !self.has_been_visited(&fingerprint)
557 }
558
    /// Number of requests accepted but not yet delivered to the crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending_requests.load(Ordering::SeqCst)
    }
564
    /// `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
570
    /// `true` when the scheduler has no outstanding work (alias of
    /// `is_empty`, kept as a separate name for call-site readability).
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
576}