Skip to main content

spider_lib/
scheduler.rs

1use crate::checkpoint::SchedulerCheckpoint;
2use crate::error::SpiderError;
3use crate::request::Request;
4use dashmap::DashSet;
5use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::sync::{Mutex, oneshot};
10use tracing::{debug, error, info};
11
12enum SchedulerMessage {
13    Enqueue(Box<Request>),
14    Shutdown,
15    TakeSnapshot(oneshot::Sender<SchedulerCheckpoint>),
16}
17
18pub struct Scheduler {
19    request_queue: Arc<Mutex<VecDeque<Request>>>,
20    visited_urls: DashSet<String>,
21    tx_internal: AsyncSender<SchedulerMessage>,
22    pending_requests: AtomicUsize,
23}
24
25impl Scheduler {
26    /// Creates a new `Scheduler` and returns a tuple containing the scheduler and a request receiver.
27    pub fn new(initial_state: Option<SchedulerCheckpoint>) -> (Arc<Self>, AsyncReceiver<Request>) {
28        let (tx_internal, rx_internal) = unbounded_async();
29        let (tx_req_out, rx_req_out) = bounded_async(1);
30
31        let (request_queue, visited_urls, pending_requests) = if let Some(state) = initial_state {
32            info!(
33                "Initializing scheduler from checkpoint with {} requests and {} visited URLs.",
34                state.request_queue.len(),
35                state.visited_urls.len()
36            );
37            let pending = state.request_queue.len();
38            (
39                Arc::new(Mutex::new(state.request_queue)),
40                state.visited_urls,
41                AtomicUsize::new(pending),
42            )
43        } else {
44            (
45                Arc::new(Mutex::new(VecDeque::new())),
46                DashSet::new(),
47                AtomicUsize::new(0),
48            )
49        };
50
51        let scheduler = Arc::new(Scheduler {
52            request_queue,
53            visited_urls,
54            tx_internal,
55            pending_requests,
56        });
57
58        let scheduler_clone = Arc::clone(&scheduler);
59        tokio::spawn(async move {
60            scheduler_clone.run_loop(rx_internal, tx_req_out).await;
61        });
62
63        (scheduler, rx_req_out)
64    }
65
66    async fn run_loop(
67        &self,
68        rx_internal: AsyncReceiver<SchedulerMessage>,
69        tx_req_out: AsyncSender<Request>,
70    ) {
71        info!("Scheduler run_loop started.");
72        loop {
73            // Only try to pop if we know there are items
74            let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
75                self.request_queue.lock().await.pop_front()
76            } else {
77                None
78            };
79
80            if let Some(request) = maybe_request {
81                tokio::select! {
82                    biased;
83                    send_res = tx_req_out.send(request) => {
84                        if send_res.is_err() {
85                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
86                        }
87                        self.pending_requests.fetch_sub(1, Ordering::SeqCst);
88                    },
89                    recv_res = rx_internal.recv() => {
90                        self.pending_requests.fetch_sub(1, Ordering::SeqCst);
91                        if !self.handle_message(recv_res).await {
92                            break;
93                        }
94                    }
95                }
96            } else {
97                // If the queue is empty, we must wait for a new message to arrive.
98                if !self.handle_message(rx_internal.recv().await).await {
99                    break;
100                }
101            }
102        }
103        info!("Scheduler run_loop finished.");
104    }
105
106    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
107        match msg {
108            Ok(SchedulerMessage::Enqueue(boxed_request)) => {
109                let request = *boxed_request;
110                let fingerprint = request.fingerprint();
111
112                if self.visited_urls.insert(fingerprint.clone()) {
113                    self.request_queue.lock().await.push_back(request);
114                    self.pending_requests.fetch_add(1, Ordering::SeqCst);
115                } else {
116                    debug!(
117                        "Skipping already visited URL: {} (fingerprint: {})",
118                        request.url, fingerprint
119                    );
120                }
121                true
122            }
123            Ok(SchedulerMessage::TakeSnapshot(responder)) => {
124                let visited_urls = self.visited_urls.iter().map(|item| item.clone()).collect();
125                let request_queue = self.request_queue.lock().await.clone();
126
127                let _ = responder.send(SchedulerCheckpoint {
128                    request_queue,
129                    visited_urls,
130                });
131                true
132            }
133            Ok(SchedulerMessage::Shutdown) | Err(_) => {
134                info!("Scheduler received shutdown signal or channel closed. Exiting run_loop.");
135                false
136            }
137        }
138    }
139
140    /// Takes a snapshot of the current state of the scheduler.
141    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
142        let (tx, rx) = oneshot::channel();
143        self.tx_internal
144            .send(SchedulerMessage::TakeSnapshot(tx))
145            .await
146            .map_err(|e| {
147                SpiderError::GeneralError(format!(
148                    "Scheduler: Failed to send snapshot request: {}",
149                    e
150                ))
151            })?;
152        rx.await.map_err(|e| {
153            SpiderError::GeneralError(format!("Scheduler: Failed to receive snapshot: {}", e))
154        })
155    }
156
157    /// Enqueues a new request to be processed.
158    pub async fn enqueue_request(&self, request: Request) -> Result<(), (Request, SpiderError)> {
159        let original_request = request.clone();
160
161        self.tx_internal
162            .send(SchedulerMessage::Enqueue(Box::new(request)))
163            .await
164            .map_err(|e| {
165                (
166                    original_request,
167                    SpiderError::GeneralError(format!(
168                        "Scheduler: Failed to enqueue request: {}",
169                        e
170                    )),
171                )
172            })
173    }
174
175    /// Sends a shutdown signal to the scheduler.
176    pub async fn shutdown(&self) -> Result<(), SpiderError> {
177        self.tx_internal
178            .send(SchedulerMessage::Shutdown)
179            .await
180            .map_err(|e| {
181                SpiderError::GeneralError(format!(
182                    "Scheduler: Failed to send shutdown signal: {}",
183                    e
184                ))
185            })
186    }
187
188    /// Returns the number of pending requests in the scheduler.
189    #[inline]
190    pub fn len(&self) -> usize {
191        self.pending_requests.load(Ordering::SeqCst)
192    }
193
194    /// Checks if the scheduler has no pending requests.
195    #[inline]
196    pub fn is_empty(&self) -> bool {
197        self.len() == 0
198    }
199
200    /// Checks if the scheduler is idle (has no pending requests).
201    #[inline]
202    pub fn is_idle(&self) -> bool {
203        self.is_empty()
204    }
205}