1#[cfg(feature = "checkpoint")]
45use crate::SchedulerCheckpoint;
46
47use spider_util::error::SpiderError;
48use spider_util::request::Request;
49use crossbeam::queue::SegQueue;
50use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
51use moka::sync::Cache;
52use std::sync::Arc;
53use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
54use tracing::{debug, error, info, trace, warn};
55
/// Control messages consumed by the scheduler's internal run loop.
enum SchedulerMessage {
    /// Add a request to the dispatch queue (request is boxed, keeping the
    /// message itself small).
    Enqueue(Box<Request>),
    /// Record a URL fingerprint in the visited cache and bloom filter.
    MarkAsVisited(String),
    /// Stop the run loop.
    Shutdown,
}
61
62use spider_util::bloom_filter::BloomFilter;
63
/// Central request scheduler: deduplicates URLs, queues requests, and hands
/// them to the crawler through a bounded output channel.
pub struct Scheduler {
    // Lock-free FIFO of requests awaiting dispatch to the crawler.
    request_queue: SegQueue<Request>,
    // Authoritative visited set; capacity-bounded, so moka may evict old
    // entries over time.
    visited_urls: Cache<String, bool>,
    // Fast negative pre-check consulted before `visited_urls` in
    // `has_been_visited`; a bloom "maybe" falls through to the cache.
    bloom_filter: std::sync::Arc<parking_lot::RwLock<BloomFilter>>,
    // Producer side of the internal control-message channel.
    tx_internal: AsyncSender<SchedulerMessage>,
    // Requests accepted but not yet delivered to the crawler.
    pending_requests: AtomicUsize,
    // Requests rescued when the internal channel was closed mid-enqueue.
    salvaged_requests: SegQueue<Request>,
    // Set on shutdown (or channel closure); gates requeue/salvage logging.
    pub(crate) is_shutting_down: AtomicBool,
    // Backpressure limit: enqueues beyond this count are rejected.
    max_pending_requests: usize,
}
74
75impl Scheduler {
76 #[cfg(feature = "checkpoint")]
78 pub fn new(
79 initial_state: Option<SchedulerCheckpoint>,
80 ) -> (Arc<Self>, AsyncReceiver<Request>) {
81 let (tx_internal, rx_internal) = unbounded_async();
82
83 let (tx_req_out, rx_req_out) = bounded_async(100);
84
85 let request_queue: SegQueue<Request>;
86 let visited_urls: Cache<String, bool>;
87 let pending_requests: AtomicUsize;
88 let salvaged_requests: SegQueue<Request>;
89
90 if let Some(state) = initial_state {
91 info!(
92 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
93 state.request_queue.len(),
94 state.visited_urls.len(),
95 state.salvaged_requests.len(),
96 );
97 let pending = state.request_queue.len() + state.salvaged_requests.len();
98 request_queue = SegQueue::new();
99 for request in state.request_queue {
100 request_queue.push(request);
101 }
102
103 visited_urls = Cache::builder().max_capacity(100000).build();
104 for url in state.visited_urls {
105 visited_urls.insert(url, true);
106 }
107
108 pending_requests = AtomicUsize::new(pending);
109 salvaged_requests = SegQueue::new();
110 for request in state.salvaged_requests {
111 salvaged_requests.push(request);
112 }
113 } else {
114 request_queue = SegQueue::new();
115 visited_urls = Cache::builder().max_capacity(100000).build();
116 pending_requests = AtomicUsize::new(0);
117 salvaged_requests = SegQueue::new();
118 }
119
120 let scheduler = Arc::new(Scheduler {
121 request_queue,
122 visited_urls,
123 bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(1000000, 3))),
124 tx_internal,
125 pending_requests,
126 salvaged_requests,
127 is_shutting_down: AtomicBool::new(false),
128 max_pending_requests: 10000,
129 });
130
131 let scheduler_clone = Arc::clone(&scheduler);
132 tokio::spawn(async move {
133 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
134 });
135
136 (scheduler, rx_req_out)
137 }
138
139 #[cfg(not(feature = "checkpoint"))]
141 pub fn new(
142 _initial_state: Option<()>, ) -> (Arc<Self>, AsyncReceiver<Request>) {
144 let (tx_internal, rx_internal) = unbounded_async();
145
146 let (tx_req_out, rx_req_out) = bounded_async(100);
147
148 let request_queue = SegQueue::new();
149 let visited_urls = Cache::builder().max_capacity(100000).build();
150 let pending_requests = AtomicUsize::new(0);
151 let salvaged_requests = SegQueue::new();
152
153 let scheduler = Arc::new(Scheduler {
154 request_queue,
155 visited_urls,
156 bloom_filter: std::sync::Arc::new(parking_lot::RwLock::new(BloomFilter::new(1000000, 3))),
157 tx_internal,
158 pending_requests,
159 salvaged_requests,
160 is_shutting_down: AtomicBool::new(false),
161 max_pending_requests: 10000,
162 });
163
164 let scheduler_clone = Arc::clone(&scheduler);
165 tokio::spawn(async move {
166 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
167 });
168
169 (scheduler, rx_req_out)
170 }
171
    /// Drives the scheduler: drains internal control messages and feeds
    /// queued requests to the crawler until shutdown or channel closure.
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_req_out: AsyncSender<Request>,
    ) {
        info!(
            "Scheduler run_loop started with max pending requests: {}",
            self.max_pending_requests
        );
        loop {
            // Control messages take priority over dispatching requests.
            if let Ok(Some(msg)) = rx_internal.try_recv() {
                trace!("Processing pending internal message");
                if !self.handle_message(Ok(msg)).await {
                    break;
                }
                continue;
            }

            // Only pop work when there is a live consumer and pending requests.
            let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
                self.request_queue.pop()
            } else {
                None
            };

            if let Some(request) = maybe_request {
                trace!("Sending request to crawler: {}", request.url);
                // Race the outbound send against incoming control messages so
                // shutdown is not stalled behind a full output channel.
                tokio::select! {
                    send_res = tx_req_out.send(request) => {
                        if send_res.is_err() {
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        } else {
                            trace!("Successfully sent request to crawler");
                        }
                        // Decremented even on send failure: the request has
                        // left the queue either way.
                        self.pending_requests.fetch_sub(1, Ordering::SeqCst);
                    },
                    recv_res = rx_internal.recv() => {
                        // NOTE(review): if this branch wins, the cancelled send
                        // future still owns the popped request. Whether kanal
                        // re-delivers or drops a message from a cancelled send
                        // future determines if that request is silently lost;
                        // `pending_requests` is also NOT decremented on this
                        // path. Verify against kanal's cancellation semantics.
                        trace!("Received internal message while sending request");
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                        continue;
                    }
                }
            } else {
                // Idle: block until the next control message arrives.
                trace!("No pending requests, waiting for internal message");
                if !self.handle_message(rx_internal.recv().await).await {
                    break;
                }
            }
        }
        info!(
            "Scheduler run_loop finished with {} pending requests remaining.",
            self.pending_requests.load(Ordering::SeqCst)
        );
    }
227
228 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
229 match msg {
230 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
231 let request = *boxed_request;
232 trace!("Enqueuing request: {}", request.url);
233 self.request_queue.push(request);
234 self.pending_requests.fetch_add(1, Ordering::SeqCst);
235 true
236 }
237 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
238 trace!("Marking URL fingerprint as visited: {}", fingerprint);
239 self.visited_urls.insert(fingerprint.clone(), true);
240 self.bloom_filter.write().add(&fingerprint);
241 debug!("Marked URL as visited: {}", fingerprint);
242 true
243 }
244 Ok(SchedulerMessage::Shutdown) => {
245 info!("Scheduler received shutdown signal. Exiting run_loop.");
246 self.is_shutting_down.store(true, Ordering::SeqCst);
247 false
248 }
249 Err(_) => {
250 warn!("Scheduler internal message channel closed. Exiting run_loop.");
251 self.is_shutting_down.store(true, Ordering::SeqCst);
252 false
253 }
254 }
255 }
256
257 #[cfg(feature = "checkpoint")]
260 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
261 let visited_urls = dashmap::DashSet::new();
262 for entry in self.visited_urls.iter() {
263 let (key, _) = entry;
264 visited_urls.insert(key.as_ref().clone());
265 }
266
267 let mut request_queue = std::collections::VecDeque::new();
270 let mut temp_requests = Vec::new();
271
272 while let Some(request) = self.request_queue.pop() {
274 temp_requests.push(request);
275 }
276
277 for request in temp_requests.into_iter() {
279 request_queue.push_back(request.clone());
280 if !self.is_shutting_down.load(Ordering::SeqCst) {
282 self.request_queue.push(request);
283 }
284 }
285
286 let mut salvaged_requests = std::collections::VecDeque::new();
288 let mut temp_salvaged = Vec::new();
289
290 while let Some(request) = self.salvaged_requests.pop() {
291 temp_salvaged.push(request);
292 }
293
294 for request in temp_salvaged.into_iter() {
295 salvaged_requests.push_back(request.clone());
296 if !self.is_shutting_down.load(Ordering::SeqCst) {
298 self.salvaged_requests.push(request);
299 }
300 }
301
302 Ok(SchedulerCheckpoint {
303 request_queue,
304 visited_urls,
305 salvaged_requests,
306 })
307 }
308
    /// No-op snapshot used when the "checkpoint" feature is disabled;
    /// always succeeds with no state captured.
    #[cfg(not(feature = "checkpoint"))]
    pub async fn snapshot(&self) -> Result<(), SpiderError> {
        Ok(())
    }
315
316 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
318 if !self.should_enqueue_request(&request) {
319 trace!("Request already visited, skipping: {}", request.url);
320 return Ok(());
321 }
322
323 let current_pending = self.pending_requests.load(Ordering::SeqCst);
325 if current_pending >= self.max_pending_requests {
326 warn!(
327 "Maximum pending requests reached ({}), request dropped due to backpressure: {}",
328 self.max_pending_requests, request.url
329 );
330 return Err(SpiderError::GeneralError(
331 "Scheduler at maximum capacity, request dropped due to backpressure.".into(),
332 ));
333 }
334
335 trace!("Enqueuing request: {}", request.url);
336 if self
337 .tx_internal
338 .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
339 .await
340 .is_err()
341 {
342 if !self.is_shutting_down.load(Ordering::SeqCst) {
343 error!(
344 "Scheduler internal message channel is closed. Salvaging request: {}",
345 request.url
346 );
347 }
348 self.salvaged_requests.push(request);
349 return Err(SpiderError::GeneralError(
350 "Scheduler internal channel closed, request salvaged.".into(),
351 ));
352 }
353
354 trace!("Successfully enqueued request: {}", request.url);
355 Ok(())
356 }
357
358 pub async fn shutdown(&self) -> Result<(), SpiderError> {
360 self.is_shutting_down.store(true, Ordering::SeqCst);
361
362 if !self.tx_internal.is_closed() {
363 self.tx_internal
364 .send(SchedulerMessage::Shutdown)
365 .await
366 .map_err(|e| {
367 SpiderError::GeneralError(format!(
368 "Scheduler: Failed to send shutdown signal: {}",
369 e
370 ))
371 })
372 } else {
373 debug!("Scheduler internal channel already closed, skipping shutdown signal");
374 Ok(())
375 }
376 }
377
378 pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
380 trace!(
381 "Sending MarkAsVisited message for fingerprint: {}",
382 fingerprint
383 );
384 self.tx_internal
385 .send(SchedulerMessage::MarkAsVisited(fingerprint.clone()))
386 .await
387 .map_err(|e| {
388 if !self.is_shutting_down.load(Ordering::SeqCst) {
389 error!("Scheduler internal message channel is closed. Failed to mark URL as visited (fingerprint: {}): {}", fingerprint, e);
390 }
391 SpiderError::GeneralError(format!(
392 "Scheduler: Failed to send MarkAsVisited message: {}",
393 e
394 ))
395 })
396 }
397
398 pub fn has_been_visited(&self, fingerprint: &str) -> bool {
400 if !self.bloom_filter.read().might_contain(fingerprint) {
401 return false;
402 }
403
404 self.visited_urls.contains_key(fingerprint)
405 }
406
407 pub fn should_enqueue_request(&self, request: &Request) -> bool {
409 let fingerprint = request.fingerprint();
410 !self.has_been_visited(&fingerprint)
411 }
412
    /// Number of requests accepted but not yet handed to a crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending_requests.load(Ordering::SeqCst)
    }

    /// True when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// True when the scheduler has no work to hand out; currently an
    /// alias for `is_empty`.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
430}