1use crate::checkpoint::SchedulerCheckpoint;
2use crate::error::SpiderError;
3use crate::request::Request;
4use dashmap::DashSet;
5use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use tokio::sync::{Mutex, oneshot};
10use tracing::{debug, error, info};
11
/// Control messages sent to the scheduler's background run loop over the
/// internal channel.
enum SchedulerMessage {
    /// A request to deduplicate by fingerprint and, if new, append to the queue.
    Enqueue(Box<Request>),
    /// Asks the run loop to stop.
    Shutdown,
    /// Asks the run loop to build a `SchedulerCheckpoint` from its current
    /// state and send it back through the provided one-shot sender.
    TakeSnapshot(oneshot::Sender<SchedulerCheckpoint>),
}
17
/// Request scheduler: keeps a FIFO queue of requests, a set of fingerprints
/// that have already been accepted, and a counter of pending requests.
pub struct Scheduler {
    /// Requests waiting to be handed out; new requests are pushed at the back
    /// and the run loop pops from the front.
    request_queue: Arc<Mutex<VecDeque<Request>>>,
    /// Fingerprints of every request that has been accepted into the queue.
    visited_urls: DashSet<String>,
    /// Sending side of the internal channel read by the run loop.
    tx_internal: AsyncSender<SchedulerMessage>,
    /// Number of requests counted as pending: incremented when a new request
    /// is queued, decremented when the run loop is done with a popped request.
    pending_requests: AtomicUsize,
}
24
25impl Scheduler {
26 pub fn new(initial_state: Option<SchedulerCheckpoint>) -> (Arc<Self>, AsyncReceiver<Request>) {
28 let (tx_internal, rx_internal) = unbounded_async();
29 let (tx_req_out, rx_req_out) = bounded_async(1);
30
31 let (request_queue, visited_urls, pending_requests) = if let Some(state) = initial_state {
32 info!(
33 "Initializing scheduler from checkpoint with {} requests and {} visited URLs.",
34 state.request_queue.len(),
35 state.visited_urls.len()
36 );
37 let pending = state.request_queue.len();
38 (
39 Arc::new(Mutex::new(state.request_queue)),
40 state.visited_urls,
41 AtomicUsize::new(pending),
42 )
43 } else {
44 (
45 Arc::new(Mutex::new(VecDeque::new())),
46 DashSet::new(),
47 AtomicUsize::new(0),
48 )
49 };
50
51 let scheduler = Arc::new(Scheduler {
52 request_queue,
53 visited_urls,
54 tx_internal,
55 pending_requests,
56 });
57
58 let scheduler_clone = Arc::clone(&scheduler);
59 tokio::spawn(async move {
60 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
61 });
62
63 (scheduler, rx_req_out)
64 }
65
66 async fn run_loop(
67 &self,
68 rx_internal: AsyncReceiver<SchedulerMessage>,
69 tx_req_out: AsyncSender<Request>,
70 ) {
71 info!("Scheduler run_loop started.");
72 loop {
73 let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
75 self.request_queue.lock().await.pop_front()
76 } else {
77 None
78 };
79
80 if let Some(request) = maybe_request {
81 tokio::select! {
82 biased;
83 send_res = tx_req_out.send(request) => {
84 if send_res.is_err() {
85 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
86 }
87 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
88 },
89 recv_res = rx_internal.recv() => {
90 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
91 if !self.handle_message(recv_res).await {
92 break;
93 }
94 }
95 }
96 } else {
97 if !self.handle_message(rx_internal.recv().await).await {
99 break;
100 }
101 }
102 }
103 info!("Scheduler run_loop finished.");
104 }
105
106 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
107 match msg {
108 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
109 let request = *boxed_request;
110 let fingerprint = request.fingerprint();
111
112 if self.visited_urls.insert(fingerprint.clone()) {
113 self.request_queue.lock().await.push_back(request);
114 self.pending_requests.fetch_add(1, Ordering::SeqCst);
115 } else {
116 debug!(
117 "Skipping already visited URL: {} (fingerprint: {})",
118 request.url, fingerprint
119 );
120 }
121 true
122 }
123 Ok(SchedulerMessage::TakeSnapshot(responder)) => {
124 let visited_urls = self.visited_urls.iter().map(|item| item.clone()).collect();
125 let request_queue = self.request_queue.lock().await.clone();
126
127 let _ = responder.send(SchedulerCheckpoint {
128 request_queue,
129 visited_urls,
130 });
131 true
132 }
133 Ok(SchedulerMessage::Shutdown) | Err(_) => {
134 info!("Scheduler received shutdown signal or channel closed. Exiting run_loop.");
135 false
136 }
137 }
138 }
139
140 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
142 let (tx, rx) = oneshot::channel();
143 self.tx_internal
144 .send(SchedulerMessage::TakeSnapshot(tx))
145 .await
146 .map_err(|e| {
147 SpiderError::GeneralError(format!(
148 "Scheduler: Failed to send snapshot request: {}",
149 e
150 ))
151 })?;
152 rx.await.map_err(|e| {
153 SpiderError::GeneralError(format!("Scheduler: Failed to receive snapshot: {}", e))
154 })
155 }
156
157 pub async fn enqueue_request(&self, request: Request) -> Result<(), (Request, SpiderError)> {
159 let original_request = request.clone();
160
161 self.tx_internal
162 .send(SchedulerMessage::Enqueue(Box::new(request)))
163 .await
164 .map_err(|e| {
165 (
166 original_request,
167 SpiderError::GeneralError(format!(
168 "Scheduler: Failed to enqueue request: {}",
169 e
170 )),
171 )
172 })
173 }
174
175 pub async fn shutdown(&self) -> Result<(), SpiderError> {
177 self.tx_internal
178 .send(SchedulerMessage::Shutdown)
179 .await
180 .map_err(|e| {
181 SpiderError::GeneralError(format!(
182 "Scheduler: Failed to send shutdown signal: {}",
183 e
184 ))
185 })
186 }
187
188 #[inline]
190 pub fn len(&self) -> usize {
191 self.pending_requests.load(Ordering::SeqCst)
192 }
193
194 #[inline]
196 pub fn is_empty(&self) -> bool {
197 self.len() == 0
198 }
199
200 #[inline]
202 pub fn is_idle(&self) -> bool {
203 self.is_empty()
204 }
205}