1#[cfg(feature = "checkpoint")]
45use spider_util::constants::DEFAULT_VISITED_CACHE_SIZE;
46#[cfg(feature = "checkpoint")]
47use crate::SchedulerCheckpoint;
48
49#[cfg(not(feature = "checkpoint"))]
50use spider_util::constants::MAX_PENDING_REQUESTS;
51
52use spider_util::constants::{
53 BLOOM_FILTER_CAPACITY, BLOOM_FILTER_HASH_FUNCTIONS,
54 VISITED_URL_CACHE_CAPACITY, VISITED_URL_CACHE_TTL_SECS,
55};
56use spider_util::error::SpiderError;
57use spider_util::request::Request;
58use crossbeam::queue::SegQueue;
59use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
60use moka::sync::Cache;
61use std::sync::Arc;
62use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
63use log::{debug, error, info, trace, warn};
64
/// Internal control messages consumed by the scheduler's `run_loop`.
enum SchedulerMessage {
    /// Add a request to the dispatch queue (boxed to keep the message small).
    Enqueue(Box<Request>),
    /// Record a single URL fingerprint as visited.
    MarkAsVisited(String),
    /// Record many URL fingerprints as visited in one message.
    MarkAsVisitedBatch(Vec<String>),
    /// Stop the run loop; flushes buffered fingerprints into the bloom filter.
    Shutdown,
}
71
72use spider_util::bloom::BloomFilter;
73
74use tokio::sync::Notify;
75
/// Request scheduler with visited-URL deduplication.
///
/// Dispatches queued [`Request`]s to the crawler over a bounded channel and
/// tracks visited fingerprints in a moka cache backed by a bloom filter.
pub struct Scheduler {
    /// Requests waiting to be dispatched to the crawler.
    queue: SegQueue<Request>,
    /// Cache of visited URL fingerprints (capacity/TTL bounded).
    visited: Cache<String, bool>,
    /// Bloom filter consulted by `is_visited`; populated by the flush task.
    bloom: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    /// Fingerprints marked visited but not yet flushed into `bloom`.
    buffer: Arc<std::sync::Mutex<Vec<String>>>,
    /// Wakes the flush task early once `buffer` crosses its size threshold.
    notify: Arc<Notify>,
    /// Sender for internal `SchedulerMessage`s consumed by `run_loop`.
    tx: AsyncSender<SchedulerMessage>,
    /// Count of enqueued-but-not-yet-dispatched requests (backpressure gauge).
    pending: AtomicUsize,
    /// Requests rescued when the internal channel was closed mid-enqueue.
    salvaged: SegQueue<Request>,
    /// Set once shutdown begins; used to silence late-error logging.
    pub(crate) is_shutting_down: AtomicBool,
    /// Backpressure limit enforced by `enqueue_request`.
    max_pending: usize,
}
88
89impl Scheduler {
90 #[cfg(feature = "checkpoint")]
92 pub fn new(
93 initial_state: Option<SchedulerCheckpoint>,
94 ) -> (Arc<Self>, AsyncReceiver<Request>) {
95 let (tx, rx_internal) = unbounded_async();
96 let (tx_out, rx_out) = bounded_async(100);
97
98 let queue: SegQueue<Request>;
99 let visited: Cache<String, bool>;
100 let pending: AtomicUsize;
101 let salvaged: SegQueue<Request>;
102
103 if let Some(state) = initial_state {
104 info!(
105 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
106 state.request_queue.len(),
107 state.visited_urls.len(),
108 state.salvaged_requests.len(),
109 );
110 let pend = state.request_queue.len() + state.salvaged_requests.len();
111 queue = SegQueue::new();
112 for request in state.request_queue {
113 queue.push(request);
114 }
115
116 visited = Cache::builder()
117 .max_capacity(VISITED_URL_CACHE_CAPACITY)
118 .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
119 .eviction_listener(|_key, _value, _cause| {})
120 .build();
121 for url in state.visited_urls {
122 visited.insert(url, true);
123 }
124
125 pending = AtomicUsize::new(pend);
126 salvaged = SegQueue::new();
127 for request in state.salvaged_requests {
128 salvaged.push(request);
129 }
130 } else {
131 queue = SegQueue::new();
132 visited = Cache::builder().max_capacity(DEFAULT_VISITED_CACHE_SIZE).build();
133 pending = AtomicUsize::new(0);
134 salvaged = SegQueue::new();
135 }
136
137 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
138 let notify = Arc::new(Notify::new());
139
140 let scheduler = Arc::new(Scheduler {
141 queue,
142 visited,
143 bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
144 BLOOM_FILTER_CAPACITY,
145 BLOOM_FILTER_HASH_FUNCTIONS,
146 ))),
147 buffer: buffer.clone(),
148 notify: notify.clone(),
149 tx,
150 pending,
151 salvaged,
152 is_shutting_down: AtomicBool::new(false),
153 max_pending: 30000,
154 });
155
156 let scheduler_bloom = Arc::clone(&scheduler);
157 let buffer_clone = buffer.clone();
158 let notify_clone = notify.clone();
159 tokio::spawn(async move {
160 scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
161 });
162
163 let scheduler_task = Arc::clone(&scheduler);
164 tokio::spawn(async move {
165 scheduler_task.run_loop(rx_internal, tx_out).await;
166 });
167
168 (scheduler, rx_out)
169 }
170
171 #[cfg(not(feature = "checkpoint"))]
172 pub fn new(
173 _initial_state: Option<()>,
174 ) -> (Arc<Self>, AsyncReceiver<Request>) {
175 let (tx, rx_internal) = unbounded_async();
176 let (tx_out, rx_out) = bounded_async(100);
177
178 let queue = SegQueue::new();
179 let visited = Cache::builder()
180 .max_capacity(VISITED_URL_CACHE_CAPACITY)
181 .time_to_idle(std::time::Duration::from_secs(VISITED_URL_CACHE_TTL_SECS))
182 .eviction_listener(|_key, _value, _cause| {})
183 .build();
184 let pending = AtomicUsize::new(0);
185 let salvaged = SegQueue::new();
186
187 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
188 let notify = Arc::new(Notify::new());
189
190 let scheduler = Arc::new(Scheduler {
191 queue,
192 visited,
193 bloom: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(
194 BLOOM_FILTER_CAPACITY,
195 BLOOM_FILTER_HASH_FUNCTIONS,
196 ))),
197 buffer: buffer.clone(),
198 notify: notify.clone(),
199 tx,
200 pending,
201 salvaged,
202 is_shutting_down: AtomicBool::new(false),
203 max_pending: MAX_PENDING_REQUESTS,
204 });
205
206 let scheduler_bloom = Arc::clone(&scheduler);
207 let buffer_clone = buffer.clone();
208 let notify_clone = notify.clone();
209 tokio::spawn(async move {
210 scheduler_bloom.flush_buffer(buffer_clone, notify_clone).await;
211 });
212
213 let scheduler_task = Arc::clone(&scheduler);
214 tokio::spawn(async move {
215 scheduler_task.run_loop(rx_internal, tx_out).await;
216 });
217
218 (scheduler, rx_out)
219 }
220
    /// Core scheduler loop: drains internal control messages and feeds queued
    /// requests to the crawler over `tx_out`.
    ///
    /// Exits when `handle_message` returns `false` (shutdown message or the
    /// internal channel closing).
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_out: AsyncSender<Request>,
    ) {
        info!(
            "Scheduler run_loop started with max pending: {}",
            self.max_pending
        );
        loop {
            // Drain any already-queued internal message first so bookkeeping
            // (enqueues, visited marks, shutdown) takes priority over dispatch.
            if let Ok(Some(msg)) = rx_internal.try_recv() {
                trace!("Processing pending internal message");
                if !self.handle_message(Ok(msg)).await {
                    break;
                }
                continue;
            }

            // Only pop a request when the crawler can still receive and there
            // is pending work; otherwise fall through and block on the channel.
            let request = if !tx_out.is_closed() && !self.is_idle() {
                self.queue.pop()
            } else {
                None
            };

            if let Some(request) = request {
                trace!("Sending request to crawler: {}", request.url);
                tokio::select! {
                    send_res = tx_out.send(request) => {
                        if send_res.is_err() {
                            // NOTE(review): on send failure the request is
                            // dropped (not salvaged) while `pending` is still
                            // decremented — confirm this is intended.
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        } else {
                            trace!("Successfully sent request to crawler");
                        }
                        self.pending.fetch_sub(1, Ordering::SeqCst);
                    },
                    recv_res = rx_internal.recv() => {
                        // NOTE(review): if this branch wins, the in-flight
                        // `send` future is dropped; the popped request may be
                        // lost and `pending` is not decremented. Verify
                        // kanal's send-cancellation semantics, or re-queue a
                        // copy of the request here.
                        trace!("Received internal message while sending request");
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                        continue;
                    }
                }
            } else {
                // Nothing dispatchable: park on the internal channel.
                trace!("No pending requests, waiting for internal message");
                if !self.handle_message(rx_internal.recv().await).await {
                    break;
                }
            }
        }
        info!(
            "Scheduler run_loop finished with {} pending requests remaining.",
            self.pending.load(Ordering::SeqCst)
        );
    }
276
277 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
278 match msg {
279 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
280 let request = *boxed_request;
281 trace!("Enqueuing request: {}", request.url);
282 self.queue.push(request);
283 self.pending.fetch_add(1, Ordering::SeqCst);
284 true
285 }
286 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
287 trace!("Marking URL fingerprint as visited: {}", fingerprint);
288
289 self.visited.insert(fingerprint.clone(), true);
291
292 debug!("Marked URL as visited: {}", fingerprint);
294
295 {
297 let mut buffer = self.buffer.lock().unwrap();
298 buffer.push(fingerprint);
299 if buffer.len() >= 100 {
300 self.notify.notify_one();
301 }
302 }
303
304 true
305 }
306 Ok(SchedulerMessage::MarkAsVisitedBatch(mut fingerprints)) => {
307 let count = fingerprints.len();
308 trace!("Marking {} URL fingerprints as visited in batch", count);
309
310 for fingerprint in &fingerprints {
312 self.visited.insert(fingerprint.clone(), true);
313 }
314
315 {
317 let mut buffer = self.buffer.lock().unwrap();
318 buffer.append(&mut fingerprints);
319 if buffer.len() >= 100 {
320 self.notify.notify_one();
321 }
322 }
323
324 debug!("Marked {} URLs as visited in batch", count);
325 true
326 }
327 Ok(SchedulerMessage::Shutdown) => {
328 info!("Scheduler received shutdown signal. Exiting run_loop.");
329 self.is_shutting_down.store(true, Ordering::SeqCst);
330 self.flush_buffer_now();
331 false
332 }
333 Err(_) => {
334 warn!("Scheduler internal message channel closed. Exiting run_loop.");
335 self.is_shutting_down.store(true, Ordering::SeqCst);
336 false
337 }
338 }
339 }
340
    /// Produces a point-in-time checkpoint of the scheduler: cached visited
    /// fingerprints, the pending request queue, and salvaged requests.
    ///
    /// The queues are lock-free (`SegQueue`), so they are snapshotted by
    /// draining and — unless shutdown is in progress — re-pushing every
    /// element.
    ///
    /// NOTE(review): this is not atomic. `run_loop` may pop requests while
    /// the queue is drained/re-filled, and re-pushed elements go to the back,
    /// so ordering is not preserved. Confirm callers only snapshot when this
    /// race is acceptable (e.g. during quiescence or shutdown).
    #[cfg(feature = "checkpoint")]
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        // Copy every cached fingerprint into the checkpoint set.
        let visited_urls = dashmap::DashSet::new();
        for entry in self.visited.iter() {
            let (key, _) = entry;
            visited_urls.insert(key.as_ref().clone());
        }

        let mut request_queue = std::collections::VecDeque::new();
        let mut temp_requests = Vec::new();

        // Drain the live queue into a temporary list...
        while let Some(request) = self.queue.pop() {
            temp_requests.push(request);
        }

        // ...record each request, and put it back unless we are shutting down
        // (during shutdown the drained queue is intentionally left empty).
        for request in temp_requests.into_iter() {
            request_queue.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.queue.push(request);
            }
        }

        // Same drain/record/re-push procedure for salvaged requests.
        let mut salvaged_requests = std::collections::VecDeque::new();
        let mut temp_salvaged = Vec::new();

        while let Some(request) = self.salvaged.pop() {
            temp_salvaged.push(request);
        }

        for request in temp_salvaged.into_iter() {
            salvaged_requests.push_back(request.clone());
            if !self.is_shutting_down.load(Ordering::SeqCst) {
                self.salvaged.push(request);
            }
        }

        Ok(SchedulerCheckpoint {
            request_queue,
            visited_urls,
            salvaged_requests,
        })
    }
383
    /// No-op snapshot used when the `checkpoint` feature is disabled; exists
    /// so callers can invoke `snapshot()` unconditionally.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
388
389 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
390 if !self.should_enqueue(&request) {
391 trace!("Request already visited, skipping: {}", request.url);
392 return Ok(());
393 }
394
395 let pending = self.pending.load(Ordering::SeqCst);
396 if pending >= self.max_pending {
397 warn!(
398 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
399 self.max_pending, request.url
400 );
401 return Err(SpiderError::GeneralError(
402 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
403 ));
404 }
405
406 trace!("Enqueuing request: {}", request.url);
407 if self
408 .tx
409 .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
410 .await
411 .is_err()
412 {
413 if !self.is_shutting_down.load(Ordering::SeqCst) {
414 error!(
415 "Scheduler internal message channel is closed. Salvaging request: {}",
416 request.url
417 );
418 }
419 self.salvaged.push(request);
420 return Err(SpiderError::GeneralError(
421 "Scheduler internal channel closed, request salvaged.".into(),
422 ));
423 }
424
425 trace!("Successfully enqueued request: {}", request.url);
426 Ok(())
427 }
428
429 pub async fn shutdown(&self) -> Result<(), SpiderError> {
430 self.is_shutting_down.store(true, Ordering::SeqCst);
431
432 if !self.tx.is_closed() {
433 self.tx
434 .send(SchedulerMessage::Shutdown)
435 .await
436 .map_err(|e| {
437 SpiderError::GeneralError(format!(
438 "Scheduler: Failed to send shutdown signal: {}",
439 e
440 ))
441 })
442 } else {
443 debug!("Scheduler internal channel already closed, skipping shutdown signal");
444 Ok(())
445 }
446 }
447
448 pub async fn mark_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
449 trace!(
450 "Sending MarkAsVisited message for fingerprint: {}",
451 fingerprint
452 );
453 self.tx
454 .send(SchedulerMessage::MarkAsVisited(fingerprint))
455 .await
456 .map_err(|e| {
457 if !self.is_shutting_down.load(Ordering::SeqCst) {
458 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
459 }
460 SpiderError::GeneralError(format!(
461 "Scheduler: Failed to send MarkAsVisited message: {}",
462 e
463 ))
464 })
465 }
466
467 pub async fn mark_visited_batch(&self, fingerprints: Vec<String>) -> Result<(), SpiderError> {
468 if fingerprints.is_empty() {
469 return Ok(());
470 }
471
472 trace!(
473 "Sending MarkAsVisitedBatch message for {} fingerprints",
474 fingerprints.len()
475 );
476 self.tx
477 .send(SchedulerMessage::MarkAsVisitedBatch(fingerprints))
478 .await
479 .map_err(|e| {
480 if !self.is_shutting_down.load(Ordering::SeqCst) {
481 error!("Scheduler internal message channel is closed. Failed to mark URLs as visited in batch: {}", e);
482 }
483 SpiderError::GeneralError(format!(
484 "Scheduler: Failed to send MarkAsVisitedBatch message: {}",
485 e
486 ))
487 })
488 }
489
490 pub fn is_visited(&self, fingerprint: &str) -> bool {
491 if !self.bloom.read().might_contain(fingerprint) {
492 return false;
493 }
494
495 {
496 let buffer = self.buffer.lock().unwrap();
497 if buffer.iter().any(|item| item == fingerprint) {
498 return true;
499 }
500 }
501
502 self.visited.contains_key(fingerprint)
503 }
504
505 fn flush_buffer_now(&self) {
506 let mut buffer = self.buffer.lock().unwrap();
507 if !buffer.is_empty() {
508 let items: Vec<String> = buffer.drain(..).collect();
509 drop(buffer);
510
511 let mut bloom = self.bloom.write();
512 for item in items {
513 bloom.add(&item);
514 }
515 }
516 }
517
518 async fn flush_buffer(
519 &self,
520 _buffer: Arc<std::sync::Mutex<Vec<String>>>,
521 notify: Arc<Notify>,
522 ) {
523 loop {
524 tokio::select! {
525 _ = notify.notified() => {
526 self.flush_buffer_now();
527 }
528 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
529 self.flush_buffer_now();
530 }
531 }
532
533 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
534 }
535 }
536
537 pub fn should_enqueue(&self, request: &Request) -> bool {
538 let fingerprint = request.fingerprint();
539 !self.is_visited(&fingerprint)
540 }
541
    /// Number of enqueued-but-not-yet-dispatched requests.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending.load(Ordering::SeqCst)
    }
546
    /// `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
551
    /// Alias for `is_empty`; the scheduler is idle when nothing is pending.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
556}