// solana_tpu_client_next/workers_cache.rs
//! This module defines [`WorkersCache`] along with the auxiliary struct
//! [`WorkerInfo`]. These structures provide mechanisms for caching workers,
//! sending transaction batches, and gathering send transaction statistics.
4
5use {
6    crate::{
7        connection_worker::ConnectionWorker,
8        logging::{debug, trace},
9        transaction_batch::TransactionBatch,
10        SendTransactionStats,
11    },
12    lru::LruCache,
13    quinn::Endpoint,
14    std::{net::SocketAddr, sync::Arc, time::Duration},
15    thiserror::Error,
16    tokio::{
17        sync::mpsc::{self, error::TrySendError},
18        task::{JoinHandle, JoinSet},
19    },
20    tokio_util::sync::CancellationToken,
21};
22
/// [`WorkerInfo`] holds information about a worker responsible for sending
/// transaction batches.
pub struct WorkerInfo {
    /// Channel used to hand transaction batches over to the worker task.
    sender: mpsc::Sender<TransactionBatch>,
    /// Handle of the spawned worker task; awaited during shutdown.
    handle: JoinHandle<()>,
    /// Token used to signal the worker task to stop.
    cancel: CancellationToken,
}
30
31impl WorkerInfo {
32    pub fn new(
33        sender: mpsc::Sender<TransactionBatch>,
34        handle: JoinHandle<()>,
35        cancel: CancellationToken,
36    ) -> Self {
37        Self {
38            sender,
39            handle,
40            cancel,
41        }
42    }
43
44    fn try_send_transactions(&self, txs_batch: TransactionBatch) -> Result<(), WorkersCacheError> {
45        self.sender.try_send(txs_batch).map_err(|err| match err {
46            TrySendError::Full(_) => WorkersCacheError::FullChannel,
47            TrySendError::Closed(_) => WorkersCacheError::ReceiverDropped,
48        })?;
49        Ok(())
50    }
51
52    async fn send_transactions(
53        &self,
54        txs_batch: TransactionBatch,
55    ) -> Result<(), WorkersCacheError> {
56        self.sender
57            .send(txs_batch)
58            .await
59            .map_err(|_| WorkersCacheError::ReceiverDropped)?;
60        Ok(())
61    }
62
63    /// Closes the worker by dropping the sender and awaiting the worker's
64    /// statistics.
65    async fn shutdown(self) -> Result<(), WorkersCacheError> {
66        self.cancel.cancel();
67        drop(self.sender);
68        self.handle
69            .await
70            .map_err(|_| WorkersCacheError::TaskJoinFailure)?;
71        Ok(())
72    }
73
74    /// Returns `true` if the worker is still active and able to send
75    /// transactions.
76    fn is_active(&self) -> bool {
77        !(self.cancel.is_cancelled() || self.sender.is_closed())
78    }
79}
80
81/// Spawns a worker to handle communication with a given peer.
82pub fn spawn_worker(
83    endpoint: &Endpoint,
84    peer: &SocketAddr,
85    worker_channel_size: usize,
86    skip_check_transaction_age: bool,
87    max_reconnect_attempts: usize,
88    handshake_timeout: Duration,
89    stats: Arc<SendTransactionStats>,
90) -> WorkerInfo {
91    let (txs_sender, txs_receiver) = mpsc::channel(worker_channel_size);
92    let endpoint = endpoint.clone();
93    let peer = *peer;
94
95    let (mut worker, cancel) = ConnectionWorker::new(
96        endpoint,
97        peer,
98        txs_receiver,
99        skip_check_transaction_age,
100        max_reconnect_attempts,
101        stats,
102        handshake_timeout,
103    );
104    let handle = tokio::spawn(async move {
105        worker.run().await;
106    });
107
108    WorkerInfo::new(txs_sender, handle, cancel)
109}
110
/// [`WorkersCache`] manages and caches workers. It uses an LRU cache to store and
/// manage workers. It also tracks transaction statistics for each peer.
pub struct WorkersCache {
    /// LRU mapping from peer address to the worker serving that peer.
    workers: LruCache<SocketAddr, WorkerInfo>,

    /// Indicates that the `WorkersCache` is being `shutdown()`, interrupting any outstanding
    /// `send_transactions_to_address()` invocations.
    cancel: CancellationToken,
}
120
/// Errors produced by [`WorkersCache`] operations and worker interactions.
#[derive(Debug, Error, PartialEq)]
pub enum WorkersCacheError {
    /// Typically happens when the client could not establish the connection.
    #[error("Work receiver has been dropped unexpectedly.")]
    ReceiverDropped,

    /// The worker's bounded channel currently has no free capacity.
    #[error("Worker's channel is full.")]
    FullChannel,

    /// Awaiting the worker task's join handle failed.
    #[error("Task failed to join.")]
    TaskJoinFailure,

    /// The cache's cancellation token has been triggered.
    #[error("The WorkersCache is being shutdown.")]
    ShutdownError,

    /// No worker is cached for the requested peer address.
    #[error("No worker exists for the specified peer.")]
    WorkerNotFound,
}
139
impl WorkersCache {
    /// Creates an empty cache holding at most `capacity` workers.
    ///
    /// `cancel` is triggered by [`Self::shutdown`] and interrupts any
    /// outstanding `send_transactions_to_address()` invocations.
    pub fn new(capacity: usize, cancel: CancellationToken) -> Self {
        Self {
            workers: LruCache::new(capacity),
            cancel,
        }
    }

    /// Checks whether a worker for the given peer is present in the cache.
    ///
    /// NOTE(review): despite the original phrasing, this only checks presence —
    /// it does not verify the worker hasn't been cancelled; use
    /// `WorkerInfo::is_active` (via `ensure_worker`) for liveness.
    pub fn contains(&self, peer: &SocketAddr) -> bool {
        self.workers.contains(peer)
    }

    /// Inserts `peer_worker` for `leader` into the LRU cache.
    ///
    /// If the insertion evicts an entry (either the LRU entry at capacity, or
    /// an existing worker under the same key), the evicted worker is returned
    /// wrapped in a [`ShutdownWorker`] so the caller can stop it.
    pub fn push(&mut self, leader: SocketAddr, peer_worker: WorkerInfo) -> Option<ShutdownWorker> {
        if let Some((leader, popped_worker)) = self.workers.push(leader, peer_worker) {
            return Some(ShutdownWorker {
                leader,
                worker: popped_worker,
            });
        }
        None
    }

    /// Removes and returns the worker for `leader`, if any, wrapped in a
    /// [`ShutdownWorker`] so the caller can stop it.
    pub fn pop(&mut self, leader: SocketAddr) -> Option<ShutdownWorker> {
        if let Some(popped_worker) = self.workers.pop(&leader) {
            return Some(ShutdownWorker {
                leader,
                worker: popped_worker,
            });
        }
        None
    }

    /// Ensures a worker exists for the given peer, creating one if necessary.
    ///
    /// Returns any evicted worker that needs shutdown.
    pub fn ensure_worker(
        &mut self,
        peer: SocketAddr,
        endpoint: &Endpoint,
        worker_channel_size: usize,
        skip_check_transaction_age: bool,
        max_reconnect_attempts: usize,
        handshake_timeout: Duration,
        stats: Arc<SendTransactionStats>,
    ) -> Option<ShutdownWorker> {
        // `peek` (rather than `get`) avoids promoting the entry in the LRU
        // order before we know whether it is still usable.
        if let Some(worker) = self.workers.peek(&peer) {
            // if worker is active, we will reuse it. Otherwise, we will spawn
            // the new one and the existing will be popped out.
            if worker.is_active() {
                return None;
            }
        }
        trace!("No active worker for peer {peer}, respawning.");

        let worker = spawn_worker(
            endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            handshake_timeout,
            stats,
        );

        // Pushing either replaces the inactive worker for `peer` or may evict
        // the LRU entry; either evicted worker is handed back for shutdown.
        self.push(peer, worker)
    }

    /// Attempts to send immediately a batch of transactions to the worker for a
    /// given peer.
    ///
    /// This method returns immediately if the channel of worker corresponding
    /// to this peer is full returning error [`WorkersCacheError::FullChannel`].
    /// If no worker exists for the peer, it returns
    /// [`WorkersCacheError::WorkerNotFound`]. If it happens that the peer's
    /// worker is stopped, it returns [`WorkersCacheError::ShutdownError`].
    /// In case if the worker is not stopped but it's channel is unexpectedly
    /// dropped, it returns [`WorkersCacheError::ReceiverDropped`].
    ///
    /// Note: The worker existence check is necessary because workers can fail
    /// asynchronously between creation and sending. Worker tasks may exit
    /// due to connection failures, network issues, or cache evictions,
    /// making a previously created worker unavailable.
    pub fn try_send_transactions_to_address(
        &mut self,
        peer: &SocketAddr,
        txs_batch: TransactionBatch,
    ) -> Result<(), WorkersCacheError> {
        // Destructure to borrow `workers` and `cancel` independently.
        let Self {
            workers, cancel, ..
        } = self;
        if cancel.is_cancelled() {
            return Err(WorkersCacheError::ShutdownError);
        }

        let current_worker = workers.get(peer).ok_or(WorkersCacheError::WorkerNotFound)?;

        let send_res = current_worker.try_send_transactions(txs_batch);

        // A dropped receiver means the worker task has exited; evict it and
        // trigger its (asynchronous) graceful shutdown.
        if let Err(WorkersCacheError::ReceiverDropped) = send_res {
            debug!(
                "Failed to deliver transaction batch for leader {}, drop batch.",
                peer.ip()
            );
            if let Some(current_worker) = workers.pop(peer) {
                shutdown_worker(ShutdownWorker {
                    leader: *peer,
                    worker: current_worker,
                })
            }
        }

        send_res
    }

    /// Sends a batch of transactions to the worker for a given peer.
    ///
    /// If the worker for the peer is disconnected or fails, it
    /// is removed from the cache. If no worker exists for the peer,
    /// it returns [`WorkersCacheError::WorkerNotFound`].
    #[allow(
        dead_code,
        reason = "This method will be used in the upcoming changes to implement optional \
                  backpressure on the sender."
    )]
    pub async fn send_transactions_to_address(
        &mut self,
        peer: &SocketAddr,
        txs_batch: TransactionBatch,
    ) -> Result<(), WorkersCacheError> {
        let Self {
            workers, cancel, ..
        } = self;

        let body = async move {
            let current_worker = workers.get(peer).ok_or(WorkersCacheError::WorkerNotFound)?;

            let send_res = current_worker.send_transactions(txs_batch).await;
            if let Err(WorkersCacheError::ReceiverDropped) = send_res {
                // Remove the worker from the cache, if the peer has disconnected.
                if let Some(current_worker) = workers.pop(peer) {
                    shutdown_worker(ShutdownWorker {
                        leader: *peer,
                        worker: current_worker,
                    })
                }
            }

            send_res
        };

        // Abort `body` as soon as the cache's token is cancelled; a cancelled
        // run yields `None`, which is mapped to `ShutdownError`.
        cancel
            .run_until_cancelled(body)
            .await
            .unwrap_or(Err(WorkersCacheError::ShutdownError))
    }

    /// Flushes the cache and asynchronously shuts down all workers. This method
    /// doesn't wait for the completion of all the shutdown tasks.
    pub(crate) fn flush(&mut self) {
        while let Some((peer, current_worker)) = self.workers.pop_lru() {
            shutdown_worker(ShutdownWorker {
                leader: peer,
                worker: current_worker,
            });
        }
    }

    /// Closes and removes all workers in the cache. This is typically done when
    /// shutting down the system.
    ///
    /// The method awaits the completion of all shutdown tasks, ensuring that
    /// each worker is properly terminated.
    pub async fn shutdown(&mut self) {
        // Interrupt any outstanding `send_transactions()` calls.
        self.cancel.cancel();

        let mut tasks = JoinSet::new();
        while let Some((peer, current_worker)) = self.workers.pop_lru() {
            let shutdown_worker = ShutdownWorker {
                leader: peer,
                worker: current_worker,
            };
            tasks.spawn(shutdown_worker.shutdown());
        }
        // Await every shutdown task; join failures are logged, not propagated.
        while let Some(res) = tasks.join_next().await {
            if let Err(err) = res {
                debug!("A shutdown task failed: {err}");
            }
        }
    }
}
333
/// [`ShutdownWorker`] takes care of stopping the worker. It's method
/// `shutdown()` should be executed in a separate task to hide the latency of
/// finishing worker gracefully.
pub struct ShutdownWorker {
    /// Address of the leader this worker was serving.
    leader: SocketAddr,
    /// The worker to be shut down.
    worker: WorkerInfo,
}
341
342impl ShutdownWorker {
343    pub(crate) fn leader(&self) -> SocketAddr {
344        self.leader
345    }
346
347    pub(crate) async fn shutdown(self) -> Result<(), WorkersCacheError> {
348        self.worker.shutdown().await
349    }
350}
351
352pub fn shutdown_worker(worker: ShutdownWorker) {
353    tokio::spawn(async move {
354        let leader = worker.leader();
355        let res = worker.shutdown().await;
356        if let Err(err) = res {
357            debug!("Error while shutting down worker for {leader}: {err}");
358        }
359    });
360}
361
#[cfg(test)]
mod tests {
    use {
        crate::{
            connection_worker::DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            connection_workers_scheduler::BindTarget,
            quic_networking::{create_client_config, create_client_endpoint},
            send_transaction_stats::SendTransactionStatsNonAtomic,
            transaction_batch::TransactionBatch,
            workers_cache::{spawn_worker, WorkersCache, WorkersCacheError},
            SendTransactionStats,
        },
        quinn::Endpoint,
        solana_net_utils::sockets::{bind_to_localhost_unique, unique_port_range_for_tests},
        solana_tls_utils::QuicClientCertificate,
        std::{
            net::{Ipv4Addr, SocketAddr},
            sync::Arc,
            time::Duration,
        },
        tokio::time::{sleep, timeout, Instant},
        tokio_util::sync::CancellationToken,
    };

    // Specify the pessimistic time to finish generation and result checks.
    const TEST_MAX_TIME: Duration = Duration::from_secs(5);

    // Builds a client QUIC endpoint bound to a unique localhost socket, with
    // an anonymous client certificate.
    fn create_test_endpoint() -> Endpoint {
        let socket = bind_to_localhost_unique().unwrap();
        let client_config = create_client_config(&QuicClientCertificate::new(None));
        create_client_endpoint(BindTarget::Socket(socket), client_config).unwrap()
    }

    // Spawns a worker towards a peer address where nothing is listening and
    // verifies that the worker task stops on its own after the connection
    // attempt fails, recording the timeout in the stats.
    #[tokio::test]
    async fn test_worker_stopped_after_failed_connect() {
        let endpoint = create_test_endpoint();

        // Reserve a port range nothing binds to, so connecting must fail.
        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);

        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker_info = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );

        timeout(TEST_MAX_TIME, worker_info.handle)
            .await
            .unwrap_or_else(|_| panic!("Should stop in less than {TEST_MAX_TIME:?}."))
            .expect("Worker task should finish successfully.");
        assert_eq!(
            stats.read_and_reset(),
            SendTransactionStatsNonAtomic {
                connection_error_timed_out: 1,
                ..Default::default()
            }
        );
    }

    // Verifies that an explicit `WorkerInfo::shutdown()` completes promptly
    // even when the worker never managed to connect.
    #[tokio::test]
    async fn test_worker_shutdown() {
        let endpoint = create_test_endpoint();

        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);

        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker_info = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );

        timeout(TEST_MAX_TIME, worker_info.shutdown())
            .await
            .unwrap_or_else(|_| panic!("Should stop in less than {TEST_MAX_TIME:?}."))
            .expect("Worker task should finish successfully.");
    }

    // Verifies that a worker which terminates (e.g. due to connection failure)
    // is properly detected, its sender is closed, and it is removed from the
    // `WorkersCache`.
    #[tokio::test]
    async fn test_worker_removed_after_exit() {
        let endpoint = create_test_endpoint();

        let cancel = CancellationToken::new();
        let mut cache = WorkersCache::new(10, cancel.clone());

        let port_range = unique_port_range_for_tests(2);
        let peer: SocketAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port_range.start);
        let worker_channel_size = 1;
        let skip_check_transaction_age = true;
        let max_reconnect_attempts = 0;
        let stats = Arc::new(SendTransactionStats::default());
        let worker = spawn_worker(
            &endpoint,
            &peer,
            worker_channel_size,
            skip_check_transaction_age,
            max_reconnect_attempts,
            DEFAULT_MAX_CONNECTION_HANDSHAKE_TIMEOUT,
            stats.clone(),
        );
        assert!(cache.push(peer, worker).is_none());

        let worker_info = cache.workers.peek(&peer).unwrap();
        // wait until sender is closed which happens when task has finished.
        let start = Instant::now();
        while !worker_info.sender.is_closed() {
            if start.elapsed() > TEST_MAX_TIME {
                panic!("Sender did not close in {TEST_MAX_TIME:?}");
            }
            sleep(Duration::from_millis(500)).await;
        }

        assert!(!worker_info.is_active(), "Worker should be inactive");

        // try to send to this worker — should fail and remove the worker
        let result = cache
            .try_send_transactions_to_address(&peer, TransactionBatch::new(vec![vec![0u8; 1]]));

        assert_eq!(result, Err(WorkersCacheError::ReceiverDropped));
        assert!(
            !cache.contains(&peer),
            "worker should be removed after failure"
        );

        // Cleanup
        cancel.cancel();
        cache.shutdown().await;

        assert_eq!(
            stats.read_and_reset(),
            SendTransactionStatsNonAtomic {
                connection_error_timed_out: 1,
                ..Default::default()
            }
        );
    }
}