solana_tpu_client/nonblocking/tpu_client.rs

pub use crate::tpu_client::Result;
use {
    crate::tpu_client::{RecentLeaderSlots, TpuClientConfig, MAX_FANOUT_SLOTS},
    bincode::serialize,
    futures_util::{future::join_all, stream::StreamExt},
    log::*,
    solana_clock::{Slot, DEFAULT_MS_PER_SLOT, NUM_CONSECUTIVE_LEADER_SLOTS},
    solana_commitment_config::CommitmentConfig,
    solana_connection_cache::{
        connection_cache::{
            ConnectionCache, ConnectionManager, ConnectionPool, NewConnectionConfig, Protocol,
            DEFAULT_CONNECTION_POOL_SIZE,
        },
        nonblocking::client_connection::ClientConnection,
    },
    solana_epoch_schedule::EpochSchedule,
    solana_pubkey::Pubkey,
    solana_pubsub_client::nonblocking::pubsub_client::{PubsubClient, PubsubClientError},
    solana_quic_definitions::QUIC_PORT_OFFSET,
    solana_rpc_client::nonblocking::rpc_client::RpcClient,
    solana_rpc_client_api::{
        client_error::{Error as ClientError, ErrorKind, Result as ClientResult},
        request::RpcError,
        response::{RpcContactInfo, SlotUpdate},
    },
    solana_signer::SignerError,
    solana_transaction::Transaction,
    solana_transaction_error::{TransportError, TransportResult},
    std::{
        collections::{HashMap, HashSet},
        net::SocketAddr,
        str::FromStr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, RwLock,
        },
    },
    thiserror::Error,
    tokio::{
        task::JoinHandle,
        time::{sleep, timeout, Duration, Instant},
    },
};
#[cfg(feature = "spinner")]
use {
    crate::tpu_client::{SEND_TRANSACTION_INTERVAL, TRANSACTION_RESEND_INTERVAL},
    futures_util::FutureExt,
    indicatif::ProgressBar,
    solana_message::Message,
    solana_rpc_client::spinner::{self, SendTransactionProgress},
    solana_rpc_client_api::request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS,
    solana_signer::signers::Signers,
    solana_transaction_error::TransactionError,
    std::{future::Future, iter},
};

#[derive(Error, Debug)]
pub enum TpuSenderError {
    #[error("Pubsub error: {0:?}")]
    PubsubError(#[from] PubsubClientError),
    #[error("RPC error: {0:?}")]
    RpcError(#[from] ClientError),
    #[error("IO error: {0:?}")]
    IoError(#[from] std::io::Error),
    #[error("Signer error: {0:?}")]
    SignerError(#[from] SignerError),
    #[error("Custom error: {0}")]
    Custom(String),
}

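/// Freshly fetched RPC data used to refresh the leader TPU cache; each field is `None` when
/// that piece of information did not need to be (re)fetched this refresh cycle.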
struct LeaderTpuCacheUpdateInfo {
    pub(super) maybe_cluster_nodes: Option<ClientResult<Vec<RpcContactInfo>>>,
    pub(super) maybe_epoch_schedule: Option<ClientResult<EpochSchedule>>,
    pub(super) maybe_slot_leaders: Option<ClientResult<Vec<Pubkey>>>,
}
impl LeaderTpuCacheUpdateInfo {
    pub fn has_some(&self) -> bool {
        self.maybe_cluster_nodes.is_some()
            || self.maybe_epoch_schedule.is_some()
            || self.maybe_slot_leaders.is_some()
    }
}

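/// Cache of the upcoming leader schedule (`leaders`, starting at `first_slot`) plus a map from
/// leader pubkey to TPU socket address for the configured protocol.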
struct LeaderTpuCache {
    protocol: Protocol,
    first_slot: Slot,
    leaders: Vec<Pubkey>,
    leader_tpu_map: HashMap<Pubkey, SocketAddr>,
    slots_in_epoch: Slot,
    last_slot_in_epoch: Slot,
}

impl LeaderTpuCache {
    pub fn new(
        first_slot: Slot,
        slots_in_epoch: Slot,
        last_slot_in_epoch: Slot,
        leaders: Vec<Pubkey>,
        cluster_nodes: Vec<RpcContactInfo>,
        protocol: Protocol,
    ) -> Self {
        let leader_tpu_map = Self::extract_cluster_tpu_sockets(protocol, cluster_nodes);
        Self {
            protocol,
            first_slot,
            leaders,
            leader_tpu_map,
            slots_in_epoch,
            last_slot_in_epoch,
        }
    }

    // Last slot that has a cached leader pubkey
    pub fn last_slot(&self) -> Slot {
        self.first_slot + self.leaders.len().saturating_sub(1) as u64
    }

    pub fn slot_info(&self) -> (Slot, Slot, Slot) {
        (
            self.last_slot(),
            self.last_slot_in_epoch,
            self.slots_in_epoch,
        )
    }

    // Get the TPU sockets for the current leader and upcoming *unique* leaders according to fanout size.
    fn get_unique_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let all_leader_sockets = self.get_leader_sockets(estimated_current_slot, fanout_slots);

        let mut unique_sockets = Vec::new();
        let mut seen = HashSet::new();

        for socket in all_leader_sockets {
            if seen.insert(socket) {
                unique_sockets.push(socket);
            }
        }

        unique_sockets
    }

    // Get the TPU sockets for the current leader and upcoming leaders according to fanout size.
    fn get_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let mut leader_sockets = Vec::new();
        // `first_slot` might have been advanced since caller last read the `estimated_current_slot`
        // value. Take the greater of the two values to ensure we are reading from the latest
        // leader schedule.
        let current_slot = std::cmp::max(estimated_current_slot, self.first_slot);
        for leader_slot in (current_slot..current_slot + fanout_slots)
            .step_by(NUM_CONSECUTIVE_LEADER_SLOTS as usize)
        {
            if let Some(leader) = self.get_slot_leader(leader_slot) {
                if let Some(tpu_socket) = self.leader_tpu_map.get(leader) {
                    leader_sockets.push(*tpu_socket);
                } else {
                    // The leader is probably delinquent
                    trace!("TPU not available for leader {}", leader);
                }
            } else {
                // Overran the local leader schedule cache
                warn!(
                    "Leader not known for slot {}; cache holds slots [{},{}]",
                    leader_slot,
                    self.first_slot,
                    self.last_slot()
                );
            }
        }
        leader_sockets
    }

    pub fn get_slot_leader(&self, slot: Slot) -> Option<&Pubkey> {
        if slot >= self.first_slot {
            let index = slot - self.first_slot;
            self.leaders.get(index as usize)
        } else {
            None
        }
    }

    fn extract_cluster_tpu_sockets(
        protocol: Protocol,
        cluster_contact_info: Vec<RpcContactInfo>,
    ) -> HashMap<Pubkey, SocketAddr> {
        cluster_contact_info
            .into_iter()
            .filter_map(|contact_info| {
                let pubkey = Pubkey::from_str(&contact_info.pubkey).ok()?;
                let socket = match protocol {
                    Protocol::QUIC => contact_info.tpu_quic.or_else(|| {
                        let mut socket = contact_info.tpu?;
                        let port = socket.port().checked_add(QUIC_PORT_OFFSET)?;
                        socket.set_port(port);
                        Some(socket)
                    }),
                    Protocol::UDP => contact_info.tpu,
                }?;
                Some((pubkey, socket))
            })
            .collect()
    }

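    // Number of upcoming slots to fetch leaders for: twice the maximum send fanout, capped at
    // the epoch length, giving headroom so the schedule is not refetched every slot.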
    pub fn fanout(slots_in_epoch: Slot) -> Slot {
        (2 * MAX_FANOUT_SLOTS).min(slots_in_epoch)
    }

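    /// Apply freshly fetched cluster nodes, epoch schedule, and slot leaders to the cache.
    /// Returns `(has_error, cluster_refreshed)`: whether the cluster-node or slot-leader fetch
    /// failed, and whether the cluster TPU map was actually refreshed.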
    pub fn update_all(
        &mut self,
        estimated_current_slot: Slot,
        cache_update_info: LeaderTpuCacheUpdateInfo,
    ) -> (bool, bool) {
        let mut has_error = false;
        let mut cluster_refreshed = false;
        if let Some(cluster_nodes) = cache_update_info.maybe_cluster_nodes {
            match cluster_nodes {
                Ok(cluster_nodes) => {
                    self.leader_tpu_map =
                        Self::extract_cluster_tpu_sockets(self.protocol, cluster_nodes);
                    cluster_refreshed = true;
                }
                Err(err) => {
                    warn!("Failed to fetch cluster tpu sockets: {}", err);
                    has_error = true;
                }
            }
        }

        if let Some(Ok(epoch_schedule)) = cache_update_info.maybe_epoch_schedule {
            let epoch = epoch_schedule.get_epoch(estimated_current_slot);
            self.slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
            self.last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);
        }

        if let Some(slot_leaders) = cache_update_info.maybe_slot_leaders {
            match slot_leaders {
                Ok(slot_leaders) => {
                    self.first_slot = estimated_current_slot;
                    self.leaders = slot_leaders;
                }
                Err(err) => {
                    warn!(
                        "Failed to fetch slot leaders (current estimated slot: {}): {}",
                        estimated_current_slot, err
                    );
                    has_error = true;
                }
            }
        }
        (has_error, cluster_refreshed)
    }
}

/// Client which sends transactions directly to the current leader's TPU port over the
/// connection cache's transport protocol (UDP or QUIC). The client uses RPC to determine
/// the current leader and fetch node contact info.
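///
/// # Example
///
/// A minimal sketch (not compiled here); the `Pool`, `Manager`, and `Config` type parameters,
/// the `connection_manager` value, and `transaction` are placeholders for a concrete
/// connection-cache implementation and a transaction supplied by the caller:
///
/// ```ignore
/// let rpc_client = Arc::new(RpcClient::new("http://127.0.0.1:8899".to_string()));
/// let tpu_client: TpuClient<Pool, Manager, Config> = TpuClient::new(
///     "my-client",
///     rpc_client,
///     "ws://127.0.0.1:8900",
///     TpuClientConfig::default(),
///     connection_manager,
/// )
/// .await?;
/// tpu_client.try_send_transaction(&transaction).await?;
/// ```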
pub struct TpuClient<
    P, // ConnectionPool
    M, // ConnectionManager
    C, // NewConnectionConfig
> {
    fanout_slots: u64,
    leader_tpu_service: LeaderTpuService,
    exit: Arc<AtomicBool>,
    rpc_client: Arc<RpcClient>,
    connection_cache: Arc<ConnectionCache<P, M, C>>,
}

/// Helper function which generates futures to all be awaited together for maximum
/// throughput
#[cfg(feature = "spinner")]
fn send_wire_transaction_futures<'a, P, M, C>(
    progress_bar: &'a ProgressBar,
    progress: &'a SendTransactionProgress,
    index: usize,
    num_transactions: usize,
    wire_transaction: Vec<u8>,
    leaders: Vec<SocketAddr>,
    connection_cache: &'a ConnectionCache<P, M, C>,
) -> Vec<impl Future<Output = TransportResult<()>> + 'a>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    const SEND_TIMEOUT_INTERVAL: Duration = Duration::from_secs(5);
    let sleep_duration = SEND_TRANSACTION_INTERVAL.saturating_mul(index as u32);
    let send_timeout = SEND_TIMEOUT_INTERVAL.saturating_add(sleep_duration);
    leaders
        .into_iter()
        .map(|addr| {
            timeout_future(
                send_timeout,
                sleep_and_send_wire_transaction_to_addr(
                    sleep_duration,
                    connection_cache,
                    addr,
                    wire_transaction.clone(),
                ),
            )
            .boxed_local() // required to make types work simply
        })
        .chain(iter::once(
            timeout_future(
                send_timeout,
                sleep_and_set_message(
                    sleep_duration,
                    progress_bar,
                    progress,
                    index,
                    num_transactions,
                ),
            )
            .boxed_local(), // required to make types work simply
        ))
        .collect::<Vec<_>>()
}

// Wrap an existing future with a timeout.
//
// Useful for end-users who don't need a persistent connection to each validator,
// and want to abort more quickly.
#[cfg(feature = "spinner")]
async fn timeout_future<Fut: Future<Output = TransportResult<()>>>(
    timeout_duration: Duration,
    future: Fut,
) -> TransportResult<()> {
    timeout(timeout_duration, future)
        .await
        .unwrap_or_else(|_| Err(TransportError::Custom("Timed out".to_string())))
}

#[cfg(feature = "spinner")]
async fn sleep_and_set_message(
    sleep_duration: Duration,
    progress_bar: &ProgressBar,
    progress: &SendTransactionProgress,
    index: usize,
    num_transactions: usize,
) -> TransportResult<()> {
    sleep(sleep_duration).await;
    progress.set_message_for_confirmed_transactions(
        progress_bar,
        &format!("Sending {}/{} transactions", index + 1, num_transactions,),
    );
    Ok(())
}

#[cfg(feature = "spinner")]
async fn sleep_and_send_wire_transaction_to_addr<P, M, C>(
    sleep_duration: Duration,
    connection_cache: &ConnectionCache<P, M, C>,
    addr: SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    sleep(sleep_duration).await;
    send_wire_transaction_to_addr(connection_cache, &addr, wire_transaction).await
}

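// Send a single serialized transaction to `addr` over a (possibly cached) connection from
// the connection cache.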
async fn send_wire_transaction_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data(&wire_transaction).await
}

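// Send a batch of serialized transactions to `addr` over a (possibly cached) connection from
// the connection cache.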
async fn send_wire_transaction_batch_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transactions: &[Vec<u8>],
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data_batch(wire_transactions).await
}

impl<P, M, C> TpuClient<P, M, C>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    pub async fn send_transaction(&self, transaction: &Transaction) -> bool {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    pub async fn send_wire_transaction(&self, wire_transaction: Vec<u8>) -> bool {
        self.try_send_wire_transaction(wire_transaction)
            .await
            .is_ok()
    }

    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    /// Returns the last error if all sends fail
    pub async fn try_send_transaction(&self, transaction: &Transaction) -> TransportResult<()> {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.try_send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    /// Returns the last error if all sends fail
    pub async fn try_send_wire_transaction(
        &self,
        wire_transaction: Vec<u8>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_to_addr(
                    &self.connection_cache,
                    addr,
                    wire_transaction.clone(),
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Send a batch of wire transactions to the current and upcoming leader TPUs according to
    /// fanout size
    /// Returns the last error if all sends fail
    pub async fn try_send_wire_transaction_batch(
        &self,
        wire_transactions: Vec<Vec<u8>>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_batch_to_addr(
                    &self.connection_cache,
                    addr,
                    &wire_transactions,
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Create a new client that disconnects when dropped
    pub async fn new(
        name: &'static str,
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_manager: M,
    ) -> Result<Self> {
        let connection_cache = Arc::new(
            ConnectionCache::new(name, connection_manager, DEFAULT_CONNECTION_POOL_SIZE).unwrap(),
        ); // TODO: Handle error properly, as the ConnectionCache ctor is now fallible.
        Self::new_with_connection_cache(rpc_client, websocket_url, config, connection_cache).await
    }

    /// Create a new client that disconnects when dropped
    pub async fn new_with_connection_cache(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_cache: Arc<ConnectionCache<P, M, C>>,
    ) -> Result<Self> {
        let exit = Arc::new(AtomicBool::new(false));
        let leader_tpu_service =
            LeaderTpuService::new(rpc_client.clone(), websocket_url, M::PROTOCOL, exit.clone())
                .await?;

        Ok(Self {
            fanout_slots: config.fanout_slots.clamp(1, MAX_FANOUT_SLOTS),
            leader_tpu_service,
            exit,
            rpc_client,
            connection_cache,
        })
    }

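    /// Sign and send all `messages`, re-sending them to upcoming leader TPUs until every
    /// transaction is confirmed at the client's commitment level or the blockhash expires;
    /// retries with a fresh blockhash up to five times before giving up. Progress is reported
    /// on a spinner progress bar.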
    #[cfg(feature = "spinner")]
    pub async fn send_and_confirm_messages_with_spinner<T: Signers + ?Sized>(
        &self,
        messages: &[Message],
        signers: &T,
    ) -> Result<Vec<Option<TransactionError>>> {
        let mut progress = SendTransactionProgress::default();
        let progress_bar = spinner::new_progress_bar();
        progress_bar.set_message("Setting up...");

        let mut transactions = messages
            .iter()
            .enumerate()
            .map(|(i, message)| (i, Transaction::new_unsigned(message.clone())))
            .collect::<Vec<_>>();
        progress.total_transactions = transactions.len();
        let mut transaction_errors = vec![None; transactions.len()];
        progress.block_height = self.rpc_client.get_block_height().await?;
        for expired_blockhash_retries in (0..5).rev() {
            let (blockhash, last_valid_block_height) = self
                .rpc_client
                .get_latest_blockhash_with_commitment(self.rpc_client.commitment())
                .await?;
            progress.last_valid_block_height = last_valid_block_height;

            let mut pending_transactions = HashMap::new();
            for (i, mut transaction) in transactions {
                transaction.try_sign(signers, blockhash)?;
                pending_transactions.insert(transaction.signatures[0], (i, transaction));
            }

            let mut last_resend = Instant::now() - TRANSACTION_RESEND_INTERVAL;
            while progress.block_height <= progress.last_valid_block_height {
                let num_transactions = pending_transactions.len();

                // Periodically re-send all pending transactions
                if Instant::now().duration_since(last_resend) > TRANSACTION_RESEND_INTERVAL {
                    // Prepare futures for all transactions
                    let mut futures = vec![];
                    for (index, (_i, transaction)) in pending_transactions.values().enumerate() {
                        let wire_transaction = serialize(transaction).unwrap();
                        let leaders = self
                            .leader_tpu_service
                            .unique_leader_tpu_sockets(self.fanout_slots);
                        futures.extend(send_wire_transaction_futures(
                            &progress_bar,
                            &progress,
                            index,
                            num_transactions,
                            wire_transaction,
                            leaders,
                            &self.connection_cache,
                        ));
                    }

                    // Start the process of sending them all
                    let results = join_all(futures).await;

                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking sent transactions",
                    );
                    for (index, (tx_results, (_i, transaction))) in results
                        .chunks(self.fanout_slots as usize)
                        .zip(pending_transactions.values())
                        .enumerate()
                    {
                        // Only report an error if every future in the chunk errored
                        if tx_results.iter().all(|r| r.is_err()) {
                            progress.set_message_for_confirmed_transactions(
                                &progress_bar,
                                &format!(
                                    "Resending failed transaction {} of {}",
                                    index + 1,
                                    num_transactions,
                                ),
                            );
                            let _result = self.rpc_client.send_transaction(transaction).await.ok();
                        }
                    }
                    last_resend = Instant::now();
                }

                // Wait for the next block before checking for transaction statuses
                let mut block_height_refreshes = 10;
                progress.set_message_for_confirmed_transactions(
                    &progress_bar,
                    &format!("Waiting for next block, {num_transactions} transactions pending..."),
                );
                let mut new_block_height = progress.block_height;
                while progress.block_height == new_block_height && block_height_refreshes > 0 {
                    sleep(Duration::from_millis(500)).await;
                    new_block_height = self.rpc_client.get_block_height().await?;
                    block_height_refreshes -= 1;
                }
                progress.block_height = new_block_height;

                // Collect statuses for the transactions, drop those that are confirmed
                let pending_signatures = pending_transactions.keys().cloned().collect::<Vec<_>>();
                for pending_signatures_chunk in
                    pending_signatures.chunks(MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS)
                {
                    if let Ok(result) = self
                        .rpc_client
                        .get_signature_statuses(pending_signatures_chunk)
                        .await
                    {
                        let statuses = result.value;
                        for (signature, status) in
                            pending_signatures_chunk.iter().zip(statuses.into_iter())
                        {
                            if let Some(status) = status {
                                if status.satisfies_commitment(self.rpc_client.commitment()) {
                                    if let Some((i, _)) = pending_transactions.remove(signature) {
                                        progress.confirmed_transactions += 1;
                                        if status.err.is_some() {
                                            progress_bar
                                                .println(format!("Failed transaction: {status:?}"));
                                        }
                                        transaction_errors[i] = status.err;
                                    }
                                }
                            }
                        }
                    }
                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking transaction status...",
                    );
                }

                if pending_transactions.is_empty() {
                    return Ok(transaction_errors);
                }
            }

            transactions = pending_transactions.into_values().collect();
            progress_bar.println(format!(
                "Blockhash expired. {expired_blockhash_retries} retries remaining"
            ));
        }
        Err(TpuSenderError::Custom("Max retries exceeded".into()))
    }

    pub fn rpc_client(&self) -> &RpcClient {
        &self.rpc_client
    }

    pub async fn shutdown(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
        self.leader_tpu_service.join().await;
    }

    pub fn get_connection_cache(&self) -> &Arc<ConnectionCache<P, M, C>>
    where
        P: ConnectionPool<NewConnectionConfig = C>,
        M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
        C: NewConnectionConfig,
    {
        &self.connection_cache
    }

    pub fn get_leader_tpu_service(&self) -> &LeaderTpuService {
        &self.leader_tpu_service
    }

    pub fn get_fanout_slots(&self) -> u64 {
        self.fanout_slots
    }
}

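// Dropping the client only signals the background leader TPU service to exit; it does not wait
// for the task to finish. Call `shutdown()` to both signal and await termination.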
impl<P, M, C> Drop for TpuClient<P, M, C> {
    fn drop(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
    }
}

/// Service that tracks upcoming leaders and maintains an up-to-date mapping
/// of leader id to TPU socket address.
pub struct LeaderTpuService {
    recent_slots: RecentLeaderSlots,
    leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
    t_leader_tpu_service: Option<JoinHandle<Result<()>>>,
}

impl LeaderTpuService {
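    /// Fetch the initial leader schedule and cluster TPU map over RPC, then spawn a background
    /// task that keeps them fresh: periodic RPC refreshes, plus slot-update notifications over
    /// websocket when a `websocket_url` is given.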
    pub async fn new(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        protocol: Protocol,
        exit: Arc<AtomicBool>,
    ) -> Result<Self> {
        let epoch_schedule = rpc_client.get_epoch_schedule().await?;
        let start_slot = rpc_client
            .get_slot_with_commitment(CommitmentConfig::processed())
            .await?;

        let recent_slots = RecentLeaderSlots::new(start_slot);
        let epoch = epoch_schedule.get_epoch(start_slot);
        let slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
        let last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);

        // When a cluster is starting, we observe an invalid slot range failure that goes away after a
        // retry. It seems as if the leader schedule is not available, but it should be. The logic
        // below retries the RPC call in case of an invalid slot range error.
        let tpu_leader_service_creation_timeout = Duration::from_secs(20);
        let retry_interval = Duration::from_secs(1);
        let leaders = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                // TODO: The root cause appears to lie within the `rpc_client.get_slot_leaders()`.
                // It might be worth debugging further and trying to understand why the RPC
                // call fails. There may be a bug in the `get_slot_leaders()` logic or in the
                // RPC implementation
                match rpc_client
                    .get_slot_leaders(start_slot, LeaderTpuCache::fanout(slots_in_epoch))
                    .await
                {
                    Ok(leaders) => return Ok(leaders),
                    Err(client_error) => {
                        if is_invalid_slot_range_error(&client_error) {
                            sleep(retry_interval).await;
                            continue;
                        } else {
                            return Err(client_error);
                        }
                    }
                }
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
                "Failed to get slot leaders connecting to: {}, timeout: {:?}. Invalid slot range",
                websocket_url, tpu_leader_service_creation_timeout
            ))
        })??;

        let cluster_nodes = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                let cluster_nodes = rpc_client.get_cluster_nodes().await?;
                // Stop once we find at least one leader's contact info
                if cluster_nodes.iter().any(|rpc_contact_info| {
                    Pubkey::from_str(&rpc_contact_info.pubkey)
                        .map(|pubkey| leaders.contains(&pubkey))
                        .unwrap_or(false)
                }) {
                    return Ok::<_, ClientError>(cluster_nodes);
                }
                sleep(retry_interval).await;
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
803                "Failed find any cluster node info for upcoming leaders, timeout: {:?}.",
                tpu_leader_service_creation_timeout
            ))
        })??;
        let leader_tpu_cache = Arc::new(RwLock::new(LeaderTpuCache::new(
            start_slot,
            slots_in_epoch,
            last_slot_in_epoch,
            leaders,
            cluster_nodes,
            protocol,
        )));

        let pubsub_client = if !websocket_url.is_empty() {
            Some(PubsubClient::new(websocket_url).await?)
        } else {
            None
        };

        let t_leader_tpu_service = Some({
            let recent_slots = recent_slots.clone();
            let leader_tpu_cache = leader_tpu_cache.clone();
            tokio::spawn(Self::run(
                rpc_client,
                recent_slots,
                leader_tpu_cache,
                pubsub_client,
                exit,
            ))
        });

        Ok(LeaderTpuService {
            recent_slots,
            leader_tpu_cache,
            t_leader_tpu_service,
        })
    }

    pub async fn join(&mut self) {
        if let Some(t_handle) = self.t_leader_tpu_service.take() {
            t_handle.await.unwrap().unwrap();
        }
    }

    pub fn estimated_current_slot(&self) -> Slot {
        self.recent_slots.estimated_current_slot()
    }

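    /// Return the TPU sockets for the estimated current leader and the upcoming leaders within
    /// `fanout_slots`, with duplicate sockets removed.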
    pub fn unique_leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_unique_leader_sockets(current_slot, fanout_slots)
    }

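    /// Return the TPU sockets for the estimated current leader and the upcoming leaders within
    /// `fanout_slots`; one entry per leader slot group, so duplicate sockets are possible.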
    pub fn leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_leader_sockets(current_slot, fanout_slots)
    }

    async fn run(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        tokio::try_join!(
            Self::run_slot_watcher(recent_slots.clone(), pubsub_client, exit.clone()),
            Self::run_cache_refresher(rpc_client, recent_slots, leader_tpu_cache, exit),
        )?;

        Ok(())
    }

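    // Once per slot, check whether the cached cluster TPU map, epoch info, or leader schedule
    // needs refreshing over RPC, and apply whatever updates were fetched.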
    async fn run_cache_refresher(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let mut last_cluster_refresh = Instant::now();
        let mut sleep_ms = DEFAULT_MS_PER_SLOT;

        while !exit.load(Ordering::Relaxed) {
            // Sleep a slot before checking if leader cache needs to be refreshed again
            sleep(Duration::from_millis(sleep_ms)).await;
            sleep_ms = DEFAULT_MS_PER_SLOT;

            let cache_update_info = maybe_fetch_cache_info(
                &leader_tpu_cache,
                last_cluster_refresh,
                &rpc_client,
                &recent_slots,
            )
            .await;

            if cache_update_info.has_some() {
                let mut leader_tpu_cache = leader_tpu_cache.write().unwrap();
                let (has_error, cluster_refreshed) = leader_tpu_cache
                    .update_all(recent_slots.estimated_current_slot(), cache_update_info);
                if has_error {
                    sleep_ms = 100;
                }
                if cluster_refreshed {
                    last_cluster_refresh = Instant::now();
                }
            }
        }

        Ok(())
    }

    async fn run_slot_watcher(
        recent_slots: RecentLeaderSlots,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let Some(pubsub_client) = pubsub_client else {
            return Ok(());
        };

        let (mut notifications, unsubscribe) = pubsub_client.slot_updates_subscribe().await?;
        // Time out slot update notification polling at 10ms.
        //
        // Rationale is two-fold:
        // 1. Notifications are an unbounded stream -- polling them will block indefinitely if not
        //    interrupted, and the exit condition will never be checked. 10ms ensures negligible
        //    CPU overhead while keeping notification checking timely.
        // 2. The timeout must be strictly less than the slot time (DEFAULT_MS_PER_SLOT: 400) to
        //    avoid timeout never being reached. For example, if notifications are received every
        //    400ms and the timeout is >= 400ms, notifications may theoretically always be available
        //    before the timeout is reached, resulting in the exit condition never being checked.
        const SLOT_UPDATE_TIMEOUT: Duration = Duration::from_millis(10);

        while !exit.load(Ordering::Relaxed) {
            while let Ok(Some(update)) = timeout(SLOT_UPDATE_TIMEOUT, notifications.next()).await {
                let current_slot = match update {
                    // This update indicates that a full slot was received by the connected
                    // node so we can stop sending transactions to the leader for that slot
                    SlotUpdate::Completed { slot, .. } => slot.saturating_add(1),
                    // This update indicates that we have just received the first shred from
                    // the leader for this slot and they are probably still accepting transactions.
                    SlotUpdate::FirstShredReceived { slot, .. } => slot,
                    _ => continue,
                };
                recent_slots.record_slot(current_slot);
            }
        }

        // `notifications` requires a valid reference to `pubsub_client`, so `notifications` must be
        // dropped before moving `pubsub_client` via `shutdown()`.
        drop(notifications);
        unsubscribe().await;
        pubsub_client.shutdown().await?;

        Ok(())
    }
}

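// Decide which pieces of the leader TPU cache need refreshing and perform the corresponding RPC
// fetches: cluster nodes at most every 5 minutes, the epoch schedule once the estimated current
// slot has passed the cached epoch boundary, and the slot leaders once the estimated slot comes
// within MAX_FANOUT_SLOTS of the end of the cached leader schedule.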
async fn maybe_fetch_cache_info(
    leader_tpu_cache: &Arc<RwLock<LeaderTpuCache>>,
    last_cluster_refresh: Instant,
    rpc_client: &RpcClient,
    recent_slots: &RecentLeaderSlots,
) -> LeaderTpuCacheUpdateInfo {
    // Refresh cluster TPU ports every 5min in case validators restart with new port configuration
    // or new validators come online
    let maybe_cluster_nodes = if last_cluster_refresh.elapsed() > Duration::from_secs(5 * 60) {
        Some(rpc_client.get_cluster_nodes().await)
    } else {
        None
    };

    let estimated_current_slot = recent_slots.estimated_current_slot();
    let (last_slot, last_slot_in_epoch, slots_in_epoch) = {
        let leader_tpu_cache = leader_tpu_cache.read().unwrap();
        leader_tpu_cache.slot_info()
    };
    let maybe_epoch_schedule = if estimated_current_slot > last_slot_in_epoch {
        Some(rpc_client.get_epoch_schedule().await)
    } else {
        None
    };

    let maybe_slot_leaders = if estimated_current_slot >= last_slot.saturating_sub(MAX_FANOUT_SLOTS)
    {
        Some(
            rpc_client
                .get_slot_leaders(
                    estimated_current_slot,
                    LeaderTpuCache::fanout(slots_in_epoch),
                )
                .await,
        )
    } else {
        None
    };
    LeaderTpuCacheUpdateInfo {
        maybe_cluster_nodes,
        maybe_epoch_schedule,
        maybe_slot_leaders,
    }
}

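// Match the JSON-RPC "invalid params" error (code -32602) returned when the leader schedule for
// the requested slot range's epoch is not yet available.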
fn is_invalid_slot_range_error(client_error: &ClientError) -> bool {
    if let ErrorKind::RpcError(RpcError::RpcResponseError { code, message, .. }) =
        client_error.kind()
    {
        return *code == -32602
            && message.contains("Invalid slot range: leader schedule for epoch");
    }
    false
}