1use crate::SchedulerCheckpoint;
15
16use crate::error::SpiderError;
17use crate::request::Request;
18use dashmap::DashSet;
19use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
20use std::collections::VecDeque;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use tokio::sync::Mutex;
24#[cfg(feature = "checkpoint")]
25use tokio::sync::oneshot;
26use tracing::{debug, error, info};
27
28enum SchedulerMessage {
29 Enqueue(Box<Request>),
30 MarkAsVisited(String),
31 Shutdown,
32 #[cfg(feature = "checkpoint")]
33 TakeSnapshot(oneshot::Sender<SchedulerCheckpoint>),
34}
35
36pub struct Scheduler {
37 request_queue: Arc<Mutex<VecDeque<Request>>>,
38 visited_urls: DashSet<String>,
39 tx_internal: AsyncSender<SchedulerMessage>,
40 pending_requests: AtomicUsize,
41 salvaged_requests: Arc<Mutex<VecDeque<Request>>>,
42}
43
44impl Scheduler {
45 pub fn new(
47 #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
48 #[cfg(not(feature = "checkpoint"))] _initial_state: Option<SchedulerCheckpoint>,
49 ) -> (Arc<Self>, AsyncReceiver<Request>) {
50 let (tx_internal, rx_internal) = unbounded_async();
51 let (tx_req_out, rx_req_out) = bounded_async(1);
52
53 let request_queue: Arc<Mutex<VecDeque<Request>>>;
55 let visited_urls: DashSet<String>;
56 let pending_requests: AtomicUsize;
57 let salvaged_requests: Arc<Mutex<VecDeque<Request>>>;
58
59 #[cfg(feature = "checkpoint")]
60 if let Some(state) = initial_state {
61 info!(
63 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
64 state.request_queue.len(),
65 state.visited_urls.len(),
66 state.salvaged_requests.len(),
67 );
68 let pending = state.request_queue.len() + state.salvaged_requests.len();
69 request_queue = Arc::new(Mutex::new(state.request_queue));
70 visited_urls = state.visited_urls;
71 pending_requests = AtomicUsize::new(pending);
72 salvaged_requests = Arc::new(Mutex::new(state.salvaged_requests));
73 } else {
74 request_queue = Arc::new(Mutex::new(VecDeque::new()));
76 visited_urls = DashSet::new();
77 pending_requests = AtomicUsize::new(0);
78 salvaged_requests = Arc::new(Mutex::new(VecDeque::new()));
79 }
80
81 #[cfg(not(feature = "checkpoint"))] {
83 request_queue = Arc::new(Mutex::new(VecDeque::new()));
86 visited_urls = DashSet::new();
87 pending_requests = AtomicUsize::new(0);
88 salvaged_requests = Arc::new(Mutex::new(VecDeque::new()));
89 }
90
91 let scheduler = Arc::new(Scheduler {
92 request_queue,
93 visited_urls,
94 tx_internal,
95 pending_requests,
96 salvaged_requests,
97 });
98
99 let scheduler_clone = Arc::clone(&scheduler);
100 tokio::spawn(async move {
101 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
102 });
103
104 (scheduler, rx_req_out)
105 }
106
107 async fn run_loop(
108 &self,
109 rx_internal: AsyncReceiver<SchedulerMessage>,
110 tx_req_out: AsyncSender<Request>,
111 ) {
112 info!("Scheduler run_loop started.");
113 loop {
114 let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
116 self.request_queue.lock().await.pop_front()
117 } else {
118 None
119 };
120
121 if let Some(request) = maybe_request {
122 tokio::select! {
123 biased;
124 send_res = tx_req_out.send(request) => {
125 if send_res.is_err() {
126 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
127 }
128 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
129 },
130 recv_res = rx_internal.recv() => {
131 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
132 if !self.handle_message(recv_res).await {
133 break;
134 }
135 }
136 }
137 } else {
138 if !self.handle_message(rx_internal.recv().await).await {
140 break;
141 }
142 }
143 }
144 info!("Scheduler run_loop finished.");
145 }
146
147 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
148 match msg {
149 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
150 let request = *boxed_request;
151 self.request_queue.lock().await.push_back(request);
152 self.pending_requests.fetch_add(1, Ordering::SeqCst);
153 true
154 }
155 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
156 self.visited_urls.insert(fingerprint.clone());
157 debug!("Marked URL as visited: {}", fingerprint);
158 true
159 }
160 #[cfg(feature = "checkpoint")]
161 Ok(SchedulerMessage::TakeSnapshot(responder)) => {
162 let visited_urls = self.visited_urls.iter().map(|item| item.clone()).collect();
163 let request_queue = self.request_queue.lock().await.clone();
164 let salvaged_requests = self.salvaged_requests.lock().await.clone();
165
166 let _ = responder.send(SchedulerCheckpoint {
167 request_queue,
168 visited_urls,
169 salvaged_requests,
170 });
171 true
172 }
173 Ok(SchedulerMessage::Shutdown) | Err(_) => {
174 info!("Scheduler received shutdown signal or channel closed. Exiting run_loop.");
175 false
176 }
177 }
178 }
179
180 #[cfg(feature = "checkpoint")]
182 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
183 let (tx, rx) = oneshot::channel();
184 self.tx_internal
185 .send(SchedulerMessage::TakeSnapshot(tx))
186 .await
187 .map_err(|e| {
188 SpiderError::GeneralError(format!(
189 "Scheduler: Failed to send snapshot request: {}",
190 e
191 ))
192 })?;
193 rx.await.map_err(|e| {
194 SpiderError::GeneralError(format!("Scheduler: Failed to receive snapshot: {}", e))
195 })
196 }
197
198 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
200 if self
201 .tx_internal
202 .send(SchedulerMessage::Enqueue(Box::new(request.clone()))) .await
204 .is_err()
205 {
206 error!("Scheduler internal message channel is closed. Salvaging request.");
207 self.salvaged_requests.lock().await.push_back(request); return Err(SpiderError::GeneralError(
209 "Scheduler internal channel closed, request salvaged.".into(),
210 ));
211 }
212 Ok(())
213 }
214
215 pub async fn shutdown(&self) -> Result<(), SpiderError> {
217 self.tx_internal
218 .send(SchedulerMessage::Shutdown)
219 .await
220 .map_err(|e| {
221 SpiderError::GeneralError(format!(
222 "Scheduler: Failed to send shutdown signal: {}",
223 e
224 ))
225 })
226 }
227
228 pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
230 self.tx_internal
231 .send(SchedulerMessage::MarkAsVisited(fingerprint))
232 .await
233 .map_err(|e| {
234 SpiderError::GeneralError(format!(
235 "Scheduler: Failed to send MarkAsVisited message: {}",
236 e
237 ))
238 })
239 }
240
241 #[inline]
243 pub fn len(&self) -> usize {
244 self.pending_requests.load(Ordering::SeqCst)
245 }
246
247 #[inline]
249 pub fn is_empty(&self) -> bool {
250 self.len() == 0
251 }
252
253 #[inline]
255 pub fn is_idle(&self) -> bool {
256 self.is_empty()
257 }
258}