1use {
6 crate::{
7 connection_worker::ConnectionWorker,
8 logging::{debug, trace},
9 transaction_batch::TransactionBatch,
10 SendTransactionStats,
11 },
12 lru::LruCache,
13 quinn::Endpoint,
14 std::{net::SocketAddr, sync::Arc, time::Duration},
15 thiserror::Error,
16 tokio::{
17 sync::mpsc::{self, error::TrySendError},
18 task::{JoinHandle, JoinSet},
19 },
20 tokio_util::sync::CancellationToken,
21};
22
/// Handle to a spawned connection worker: the channel used to feed it
/// transaction batches, the join handle for its task, and a token used to
/// cancel it.
pub struct WorkerInfo {
    // Channel used to deliver `TransactionBatch`es to the worker task.
    sender: mpsc::Sender<TransactionBatch>,
    // Join handle of the spawned worker task; awaited during shutdown.
    handle: JoinHandle<()>,
    // Token used to signal the worker task to stop.
    cancel: CancellationToken,
}
30
31impl WorkerInfo {
32 pub fn new(
33 sender: mpsc::Sender<TransactionBatch>,
34 handle: JoinHandle<()>,
35 cancel: CancellationToken,
36 ) -> Self {
37 Self {
38 sender,
39 handle,
40 cancel,
41 }
42 }
43
44 fn try_send_transactions(&self, txs_batch: TransactionBatch) -> Result<(), WorkersCacheError> {
45 self.sender.try_send(txs_batch).map_err(|err| match err {
46 TrySendError::Full(_) => WorkersCacheError::FullChannel,
47 TrySendError::Closed(_) => WorkersCacheError::ReceiverDropped,
48 })?;
49 Ok(())
50 }
51
52 async fn send_transactions(
53 &self,
54 txs_batch: TransactionBatch,
55 ) -> Result<(), WorkersCacheError> {
56 self.sender
57 .send(txs_batch)
58 .await
59 .map_err(|_| WorkersCacheError::ReceiverDropped)?;
60 Ok(())
61 }
62
63 async fn shutdown(self) -> Result<(), WorkersCacheError> {
66 self.cancel.cancel();
67 drop(self.sender);
68 self.handle
69 .await
70 .map_err(|_| WorkersCacheError::TaskJoinFailure)?;
71 Ok(())
72 }
73
74 fn is_active(&self) -> bool {
77 !(self.cancel.is_cancelled() || self.sender.is_closed())
78 }
79}
80
81pub fn spawn_worker(
83 endpoint: &Endpoint,
84 peer: &SocketAddr,
85 worker_channel_size: usize,
86 skip_check_transaction_age: bool,
87 max_reconnect_attempts: usize,
88 handshake_timeout: Duration,
89 stats: Arc<SendTransactionStats>,
90) -> WorkerInfo {
91 let (txs_sender, txs_receiver) = mpsc::channel(worker_channel_size);
92 let endpoint = endpoint.clone();
93 let peer = *peer;
94
95 let (mut worker, cancel) = ConnectionWorker::new(
96 endpoint,
97 peer,
98 txs_receiver,
99 skip_check_transaction_age,
100 max_reconnect_attempts,
101 stats,
102 handshake_timeout,
103 );
104 let handle = tokio::spawn(async move {
105 worker.run().await;
106 });
107
108 WorkerInfo::new(txs_sender, handle, cancel)
109}
110
/// An LRU cache of connection workers, keyed by the peer address each worker
/// serves.
pub struct WorkersCache {
    // LRU cache mapping peer addresses to their workers; least-recently-used
    // entries are evicted when capacity is exceeded.
    workers: LruCache<SocketAddr, WorkerInfo>,

    // Token checked on sends and used to signal shutdown of the whole cache.
    cancel: CancellationToken,
}
120
/// Errors that can occur when interacting with [`WorkersCache`] or an
/// individual worker.
#[derive(Debug, Error, PartialEq)]
pub enum WorkersCacheError {
    /// The worker's receiving half was dropped, so batches can no longer be
    /// delivered to it.
    #[error("Work receiver has been dropped unexpectedly.")]
    ReceiverDropped,

    /// The worker's bounded channel is at capacity (non-blocking send only).
    #[error("Worker's channel is full.")]
    FullChannel,

    /// The worker task could not be joined during shutdown.
    #[error("Task failed to join.")]
    TaskJoinFailure,

    /// The cache's cancellation token was triggered; no further sends are
    /// accepted.
    #[error("The WorkersCache is being shutdown.")]
    ShutdownError,

    /// No cached worker exists for the requested peer address.
    #[error("No worker exists for the specified peer.")]
    WorkerNotFound,
}
139
impl WorkersCache {
    /// Creates a cache holding at most `capacity` workers; `cancel` is used to
    /// reject sends once the cache is being shut down.
    pub fn new(capacity: usize, cancel: CancellationToken) -> Self {
        Self {
            workers: LruCache::new(capacity),
            cancel,
        }
    }

    /// Returns `true` if a worker (active or not) exists for `peer`.
    pub fn contains(&self, peer: &SocketAddr) -> bool {
        self.workers.contains(peer)
    }

    /// Inserts `peer_worker` for `leader`, returning any displaced entry
    /// (the evicted LRU entry, or — per `lru`'s `push` semantics — the
    /// previous worker for the same leader) wrapped for shutdown. The caller
    /// is responsible for shutting the returned worker down.
    pub fn push(&mut self, leader: SocketAddr, peer_worker: WorkerInfo) -> Option<ShutdownWorker> {
        if let Some((leader, popped_worker)) = self.workers.push(leader, peer_worker) {
            return Some(ShutdownWorker {
                leader,
                worker: popped_worker,
            });
        }
        None
    }

    /// Removes and returns the worker for `leader`, if any, wrapped for
    /// shutdown.
    pub fn pop(&mut self, leader: SocketAddr) -> Option<ShutdownWorker> {
        if let Some(popped_worker) = self.workers.pop(&leader) {
            return Some(ShutdownWorker {
                leader,
                worker: popped_worker,
            });
        }
        None
    }

    /// Ensures an active worker exists for `peer`, spawning a replacement
    /// when the cached one is missing or no longer active.
    ///
    /// Returns the displaced worker (the stale one, or an LRU-evicted entry),
    /// which the caller must shut down, or `None` when nothing was displaced.
    pub fn ensure_worker(
        &mut self,
        peer: SocketAddr,
        endpoint: &Endpoint,
        worker_channel_size: usize,
        skip_check_transaction_age: bool,
        max_reconnect_attempts: usize,
        handshake_timeout: Duration,
        stats: Arc<SendTransactionStats>,
    ) -> Option<ShutdownWorker> {
        // `peek` checks liveness without promoting the entry in LRU order.
        if let Some(worker) = self.workers.peek(&peer) {
            if worker.is_active() {
                return None;
            }
        }
        trace!("No active worker for peer {peer}, respawning.");

        let worker = spawn_worker(
            endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            handshake_timeout,
            stats,
        );

        self.push(peer, worker)
    }

    /// Sends a batch to the worker for `peer` without waiting for channel
    /// capacity.
    ///
    /// When the send fails because the worker's receiver was dropped, the
    /// dead worker is removed from the cache and scheduled for shutdown; the
    /// batch is lost in that case. A full channel leaves the worker cached.
    pub fn try_send_transactions_to_address(
        &mut self,
        peer: &SocketAddr,
        txs_batch: TransactionBatch,
    ) -> Result<(), WorkersCacheError> {
        // Destructure so `workers` can be borrowed mutably while `cancel` is
        // read.
        let Self {
            workers, cancel, ..
        } = self;
        if cancel.is_cancelled() {
            return Err(WorkersCacheError::ShutdownError);
        }

        // `get` (unlike `peek`) promotes the worker in LRU order.
        let current_worker = workers.get(peer).ok_or(WorkersCacheError::WorkerNotFound)?;

        let send_res = current_worker.try_send_transactions(txs_batch);

        if let Err(WorkersCacheError::ReceiverDropped) = send_res {
            debug!(
                "Failed to deliver transaction batch for leader {}, drop batch.",
                peer.ip()
            );
            if let Some(current_worker) = workers.pop(peer) {
                shutdown_worker(ShutdownWorker {
                    leader: *peer,
                    worker: current_worker,
                })
            }
        }

        send_res
    }

    /// Sends a batch to the worker for `peer`, applying backpressure by
    /// awaiting channel capacity. The send is aborted with
    /// [`WorkersCacheError::ShutdownError`] if the cache's cancellation token
    /// fires first. As in the non-blocking variant, a dropped receiver causes
    /// the dead worker to be removed and scheduled for shutdown.
    #[allow(
        dead_code,
        reason = "This method will be used in the upcoming changes to implement optional \
                  backpressure on the sender."
    )]
    pub async fn send_transactions_to_address(
        &mut self,
        peer: &SocketAddr,
        txs_batch: TransactionBatch,
    ) -> Result<(), WorkersCacheError> {
        let Self {
            workers, cancel, ..
        } = self;

        let body = async move {
            let current_worker = workers.get(peer).ok_or(WorkersCacheError::WorkerNotFound)?;

            let send_res = current_worker.send_transactions(txs_batch).await;
            if let Err(WorkersCacheError::ReceiverDropped) = send_res {
                if let Some(current_worker) = workers.pop(peer) {
                    shutdown_worker(ShutdownWorker {
                        leader: *peer,
                        worker: current_worker,
                    })
                }
            }

            send_res
        };

        // `run_until_cancelled` yields `None` when the token fires before
        // `body` completes; map that to a shutdown error.
        cancel
            .run_until_cancelled(body)
            .await
            .unwrap_or(Err(WorkersCacheError::ShutdownError))
    }

    /// Removes every worker from the cache, scheduling a detached shutdown
    /// task for each. Does not wait for those tasks to finish.
    pub(crate) fn flush(&mut self) {
        while let Some((peer, current_worker)) = self.workers.pop_lru() {
            shutdown_worker(ShutdownWorker {
                leader: peer,
                worker: current_worker,
            });
        }
    }

    /// Cancels the cache, shuts down all cached workers concurrently, and
    /// awaits their completion, logging any join failures.
    pub async fn shutdown(&mut self) {
        self.cancel.cancel();

        let mut tasks = JoinSet::new();
        while let Some((peer, current_worker)) = self.workers.pop_lru() {
            let shutdown_worker = ShutdownWorker {
                leader: peer,
                worker: current_worker,
            };
            tasks.spawn(shutdown_worker.shutdown());
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(err) = res {
                debug!("A shutdown task failed: {err}");
            }
        }
    }
}
333
/// A worker paired with the leader address it served, so it can be shut down
/// and any failure attributed to that peer.
pub struct ShutdownWorker {
    // Address of the leader this worker was connected to.
    leader: SocketAddr,
    // The worker awaiting shutdown.
    worker: WorkerInfo,
}
341
342impl ShutdownWorker {
343 pub(crate) fn leader(&self) -> SocketAddr {
344 self.leader
345 }
346
347 pub(crate) async fn shutdown(self) -> Result<(), WorkersCacheError> {
348 self.worker.shutdown().await
349 }
350}
351
352pub fn shutdown_worker(worker: ShutdownWorker) {
353 tokio::spawn(async move {
354 let leader = worker.leader();
355 let res = worker.shutdown().await;
356 if let Err(err) = res {
357 debug!("Error while shutting down worker for {leader}: {err}");
358 }
359 });
360}
361
#[cfg(test)]
mod tests {
    use {
        crate::{
            connection_worker::DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            connection_workers_scheduler::BindTarget,
            quic_networking::{create_client_config, create_client_endpoint},
            send_transaction_stats::SendTransactionStatsNonAtomic,
            transaction_batch::TransactionBatch,
            workers_cache::{spawn_worker, WorkersCache, WorkersCacheError},
            SendTransactionStats,
        },
        quinn::Endpoint,
        solana_net_utils::sockets::{bind_to_localhost_unique, unique_port_range_for_tests},
        solana_tls_utils::QuicClientCertificate,
        std::{
            net::{Ipv4Addr, SocketAddr},
            sync::Arc,
            time::Duration,
        },
        tokio::time::{sleep, timeout, Instant},
        tokio_util::sync::CancellationToken,
    };

    // Upper bound on how long any single asynchronous test step may take.
    const TEST_MAX_TIME: Duration = Duration::from_secs(5);

    // Builds a QUIC client endpoint bound to a unique localhost socket.
    fn create_test_endpoint() -> Endpoint {
        let socket = bind_to_localhost_unique().unwrap();
        let client_config = create_client_config(&QuicClientCertificate::new(None));
        create_client_endpoint(BindTarget::Socket(socket), client_config).unwrap()
    }

    // A worker pointed at a port with no listener and zero reconnect attempts
    // should finish on its own and record a connection timeout in the stats.
    #[tokio::test]
    async fn test_worker_stopped_after_failed_connect() {
        let endpoint = create_test_endpoint();

        // Reserve a port range so nothing is listening at `peer`.
        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);

        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker_info = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );

        timeout(TEST_MAX_TIME, worker_info.handle)
            .await
            .unwrap_or_else(|_| panic!("Should stop in less than {TEST_MAX_TIME:?}."))
            .expect("Worker task should finish successfully.");
        assert_eq!(
            stats.read_and_reset(),
            SendTransactionStatsNonAtomic {
                connection_error_timed_out: 1,
                ..Default::default()
            }
        );
    }

    // Explicitly shutting down a worker should complete promptly and without
    // error, even when the worker never managed to connect.
    #[tokio::test]
    async fn test_worker_shutdown() {
        let endpoint = create_test_endpoint();

        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);

        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker_info = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );

        timeout(TEST_MAX_TIME, worker_info.shutdown())
            .await
            .unwrap_or_else(|_| panic!("Should stop in less than {TEST_MAX_TIME:?}."))
            .expect("Worker task should finish successfully.");
    }

    // Sending to a worker whose task has exited should report
    // `ReceiverDropped` and evict the dead worker from the cache.
    #[tokio::test]
    async fn test_worker_removed_after_exit() {
        let endpoint = create_test_endpoint();

        let cancel = CancellationToken::new();
        let mut cache = WorkersCache::new(10, cancel.clone());

        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);
        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );
        assert!(cache.push(peer, worker).is_none());

        // Poll until the worker task exits (its receiver closes), bounded by
        // TEST_MAX_TIME.
        let worker_info = cache.workers.peek(&peer).unwrap();
        let start = Instant::now();
        while !worker_info.sender.is_closed() {
            if start.elapsed() > TEST_MAX_TIME {
                panic!("Sender did not close in {TEST_MAX_TIME:?}");
            }
            sleep(Duration::from_millis(500)).await;
        }

        assert!(!worker_info.is_active(), "Worker should be inactive");

        let result = cache
            .try_send_transactions_to_address(&peer, TransactionBatch::new(vec![vec![0u8; 1]]));

        assert_eq!(result, Err(WorkersCacheError::ReceiverDropped));
        assert!(
            !cache.contains(&peer),
            "worker should be removed after failure"
        );

        cancel.cancel();
        cache.shutdown().await;

        assert_eq!(
            stats.read_and_reset(),
            SendTransactionStatsNonAtomic {
                connection_error_timed_out: 1,
                ..Default::default()
            }
        );
    }
}