solana_tpu_client/nonblocking/tpu_client.rs

pub use crate::tpu_client::Result;
use {
    crate::tpu_client::{RecentLeaderSlots, TpuClientConfig, MAX_FANOUT_SLOTS},
    bincode::serialize,
    futures_util::{future::join_all, stream::StreamExt},
    log::*,
    solana_clock::{Slot, DEFAULT_MS_PER_SLOT, NUM_CONSECUTIVE_LEADER_SLOTS},
    solana_commitment_config::CommitmentConfig,
    solana_connection_cache::{
        connection_cache::{
            ConnectionCache, ConnectionManager, ConnectionPool, NewConnectionConfig, Protocol,
            DEFAULT_CONNECTION_POOL_SIZE,
        },
        nonblocking::client_connection::ClientConnection,
    },
    solana_epoch_schedule::EpochSchedule,
    solana_pubkey::Pubkey,
    solana_pubsub_client::nonblocking::pubsub_client::{PubsubClient, PubsubClientError},
    solana_quic_definitions::QUIC_PORT_OFFSET,
    solana_rpc_client::nonblocking::rpc_client::RpcClient,
    solana_rpc_client_api::{
        client_error::{Error as ClientError, ErrorKind, Result as ClientResult},
        request::RpcError,
        response::{RpcContactInfo, SlotUpdate},
    },
    solana_signer::SignerError,
    solana_transaction::Transaction,
    solana_transaction_error::{TransportError, TransportResult},
    std::{
        collections::{HashMap, HashSet},
        net::SocketAddr,
        str::FromStr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, RwLock,
        },
    },
    thiserror::Error,
    tokio::{
        task::JoinHandle,
        time::{sleep, timeout, Duration, Instant},
    },
};
#[cfg(feature = "spinner")]
use {
    crate::tpu_client::{SEND_TRANSACTION_INTERVAL, TRANSACTION_RESEND_INTERVAL},
    futures_util::FutureExt,
    indicatif::ProgressBar,
    solana_message::Message,
    solana_rpc_client::spinner::{self, SendTransactionProgress},
    solana_rpc_client_api::request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS,
    solana_signer::signers::Signers,
    solana_transaction_error::TransactionError,
    std::{future::Future, iter},
};

#[derive(Error, Debug)]
pub enum TpuSenderError {
    #[error("Pubsub error: {0:?}")]
    PubsubError(#[from] PubsubClientError),
    #[error("RPC error: {0:?}")]
    RpcError(#[from] ClientError),
    #[error("IO error: {0:?}")]
    IoError(#[from] std::io::Error),
    #[error("Signer error: {0:?}")]
    SignerError(#[from] SignerError),
    #[error("Custom error: {0}")]
    Custom(String),
}

struct LeaderTpuCacheUpdateInfo {
    pub(super) maybe_cluster_nodes: Option<ClientResult<Vec<RpcContactInfo>>>,
    pub(super) maybe_epoch_schedule: Option<ClientResult<EpochSchedule>>,
    pub(super) maybe_slot_leaders: Option<ClientResult<Vec<Pubkey>>>,
    pub(super) first_slot: Slot,
}
impl LeaderTpuCacheUpdateInfo {
    pub fn has_some(&self) -> bool {
        self.maybe_cluster_nodes.is_some()
            || self.maybe_epoch_schedule.is_some()
            || self.maybe_slot_leaders.is_some()
    }
}

struct LeaderTpuCache {
    protocol: Protocol,
    first_slot: Slot,
    leaders: Vec<Pubkey>,
    leader_tpu_map: HashMap<Pubkey, SocketAddr>,
    slots_in_epoch: Slot,
    last_slot_in_epoch: Slot,
}

impl LeaderTpuCache {
    pub fn new(
        first_slot: Slot,
        slots_in_epoch: Slot,
        last_slot_in_epoch: Slot,
        leaders: Vec<Pubkey>,
        cluster_nodes: Vec<RpcContactInfo>,
        protocol: Protocol,
    ) -> Self {
        let leader_tpu_map = Self::extract_cluster_tpu_sockets(protocol, cluster_nodes);
        Self {
            protocol,
            first_slot,
            leaders,
            leader_tpu_map,
            slots_in_epoch,
            last_slot_in_epoch,
        }
    }

    // Last slot that has a cached leader pubkey
    pub fn last_slot(&self) -> Slot {
        self.first_slot + self.leaders.len().saturating_sub(1) as u64
    }

    pub fn slot_info(&self) -> (Slot, Slot, Slot) {
        (
            self.last_slot(),
            self.last_slot_in_epoch,
            self.slots_in_epoch,
        )
    }

    // Get the TPU sockets for the current leader and upcoming *unique* leaders according to fanout size.
    fn get_unique_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let all_leader_sockets = self.get_leader_sockets(estimated_current_slot, fanout_slots);

        let mut unique_sockets = Vec::new();
        let mut seen = HashSet::new();

        for socket in all_leader_sockets {
            if seen.insert(socket) {
                unique_sockets.push(socket);
            }
        }

        unique_sockets
    }

    // Get the TPU sockets for the current leader and upcoming leaders according to fanout size.
    fn get_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let mut leader_sockets = Vec::new();
        // `first_slot` might have been advanced since the caller last read the
        // `estimated_current_slot` value. Take the greater of the two values to ensure we are
        // reading from the latest leader schedule.
        let current_slot = std::cmp::max(estimated_current_slot, self.first_slot);
        for leader_slot in (current_slot..current_slot + fanout_slots)
            .step_by(NUM_CONSECUTIVE_LEADER_SLOTS as usize)
        {
            if let Some(leader) = self.get_slot_leader(leader_slot) {
                if let Some(tpu_socket) = self.leader_tpu_map.get(leader) {
                    leader_sockets.push(*tpu_socket);
                } else {
                    // The leader is probably delinquent
                    trace!("TPU not available for leader {leader}");
                }
            } else {
                // Overran the local leader schedule cache
                warn!(
                    "Leader not known for slot {}; cache holds slots [{},{}]",
                    leader_slot,
                    self.first_slot,
                    self.last_slot()
                );
            }
        }
        leader_sockets
    }

    pub fn get_slot_leader(&self, slot: Slot) -> Option<&Pubkey> {
        if slot >= self.first_slot {
            let index = slot - self.first_slot;
            self.leaders.get(index as usize)
        } else {
            None
        }
    }

    fn extract_cluster_tpu_sockets(
        protocol: Protocol,
        cluster_contact_info: Vec<RpcContactInfo>,
    ) -> HashMap<Pubkey, SocketAddr> {
        cluster_contact_info
            .into_iter()
            .filter_map(|contact_info| {
                let pubkey = Pubkey::from_str(&contact_info.pubkey).ok()?;
                let socket = match protocol {
                    Protocol::QUIC => contact_info.tpu_quic.or_else(|| {
                        let mut socket = contact_info.tpu?;
                        let port = socket.port().checked_add(QUIC_PORT_OFFSET)?;
                        socket.set_port(port);
                        Some(socket)
                    }),
                    Protocol::UDP => contact_info.tpu,
                }?;
                Some((pubkey, socket))
            })
            .collect()
    }

    pub fn fanout(slots_in_epoch: Slot) -> Slot {
        (2 * MAX_FANOUT_SLOTS).min(slots_in_epoch)
    }

    pub fn update_all(&mut self, cache_update_info: LeaderTpuCacheUpdateInfo) -> (bool, bool) {
        let mut has_error = false;
        let mut cluster_refreshed = false;
        if let Some(cluster_nodes) = cache_update_info.maybe_cluster_nodes {
            match cluster_nodes {
                Ok(cluster_nodes) => {
                    self.leader_tpu_map =
                        Self::extract_cluster_tpu_sockets(self.protocol, cluster_nodes);
                    cluster_refreshed = true;
                }
                Err(err) => {
                    warn!("Failed to fetch cluster tpu sockets: {err}");
                    has_error = true;
                }
            }
        }

        if let Some(Ok(epoch_schedule)) = cache_update_info.maybe_epoch_schedule {
            let epoch = epoch_schedule.get_epoch(cache_update_info.first_slot);
            self.slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
            self.last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);
        }

        if let Some(slot_leaders) = cache_update_info.maybe_slot_leaders {
            match slot_leaders {
                Ok(slot_leaders) => {
                    self.first_slot = cache_update_info.first_slot;
                    self.leaders = slot_leaders;
                }
                Err(err) => {
                    warn!(
                        "Failed to fetch slot leaders (first_slot: \
                         {}): {err}",
                        cache_update_info.first_slot
                    );
                    has_error = true;
                }
            }
        }
        (has_error, cluster_refreshed)
    }
}
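
// An illustrative sketch exercising the slot-indexing behavior of `LeaderTpuCache`:
// `last_slot` and `get_slot_leader` are relative to `first_slot`. The concrete values
// below are arbitrary.
#[cfg(test)]
mod leader_tpu_cache_sketch {
    use super::*;

    #[test]
    fn slot_indexing_is_relative_to_first_slot() {
        let cache = LeaderTpuCache {
            protocol: Protocol::UDP,
            first_slot: 100,
            leaders: vec![Pubkey::default(); 4],
            leader_tpu_map: HashMap::new(),
            slots_in_epoch: 432_000,
            last_slot_in_epoch: 431_999,
        };
        // Four cached leaders starting at slot 100 cover slots 100..=103.
        assert_eq!(cache.last_slot(), 103);
        assert!(cache.get_slot_leader(99).is_none());
        assert!(cache.get_slot_leader(103).is_some());
        assert!(cache.get_slot_leader(104).is_none());
    }
}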

/// Client which sends transactions directly to the current leader's TPU port, over UDP or QUIC
/// depending on the connection cache protocol. The client uses RPC to determine the current
/// leader and to fetch node contact info.
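///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, not compiled here): it assumes the QUIC connection
/// manager and config types from the `solana_quic_client` crate and an already signed
/// `transaction` in scope.
///
/// ```ignore
/// use std::sync::Arc;
/// use solana_rpc_client::nonblocking::rpc_client::RpcClient;
///
/// let rpc_client = Arc::new(RpcClient::new("http://localhost:8899".to_string()));
/// let tpu_client = TpuClient::new(
///     "tpu_client_example",
///     rpc_client,
///     "ws://localhost:8900",
///     TpuClientConfig::default(),
///     QuicConnectionManager::new_with_connection_config(QuicConfig::new()?),
/// )
/// .await?;
/// // Returns `true` if at least one send to a leader TPU succeeded.
/// let sent = tpu_client.send_transaction(&transaction).await;
/// ```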
pub struct TpuClient<
    P, // ConnectionPool
    M, // ConnectionManager
    C, // NewConnectionConfig
> {
    fanout_slots: u64,
    leader_tpu_service: LeaderTpuService,
    exit: Arc<AtomicBool>,
    rpc_client: Arc<RpcClient>,
    connection_cache: Arc<ConnectionCache<P, M, C>>,
}

/// Helper function that generates futures to be awaited together for maximum
/// throughput
#[cfg(feature = "spinner")]
fn send_wire_transaction_futures<'a, P, M, C>(
    progress_bar: &'a ProgressBar,
    progress: &'a SendTransactionProgress,
    index: usize,
    num_transactions: usize,
    wire_transaction: Vec<u8>,
    leaders: Vec<SocketAddr>,
    connection_cache: &'a ConnectionCache<P, M, C>,
) -> Vec<impl Future<Output = TransportResult<()>> + 'a>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    const SEND_TIMEOUT_INTERVAL: Duration = Duration::from_secs(5);
    let sleep_duration = SEND_TRANSACTION_INTERVAL.saturating_mul(index as u32);
    let send_timeout = SEND_TIMEOUT_INTERVAL.saturating_add(sleep_duration);
    leaders
        .into_iter()
        .map(|addr| {
            timeout_future(
                send_timeout,
                sleep_and_send_wire_transaction_to_addr(
                    sleep_duration,
                    connection_cache,
                    addr,
                    wire_transaction.clone(),
                ),
            )
            .boxed_local() // required to make types work simply
        })
        .chain(iter::once(
            timeout_future(
                send_timeout,
                sleep_and_set_message(
                    sleep_duration,
                    progress_bar,
                    progress,
                    index,
                    num_transactions,
                ),
            )
            .boxed_local(), // required to make types work simply
        ))
        .collect::<Vec<_>>()
}

// Wrap an existing future with a timeout.
//
// Useful for end-users who don't need a persistent connection to each validator,
// and want to abort more quickly.
#[cfg(feature = "spinner")]
async fn timeout_future<Fut: Future<Output = TransportResult<()>>>(
    timeout_duration: Duration,
    future: Fut,
) -> TransportResult<()> {
    timeout(timeout_duration, future)
        .await
        .unwrap_or_else(|_| Err(TransportError::Custom("Timed out".to_string())))
}

#[cfg(feature = "spinner")]
async fn sleep_and_set_message(
    sleep_duration: Duration,
    progress_bar: &ProgressBar,
    progress: &SendTransactionProgress,
    index: usize,
    num_transactions: usize,
) -> TransportResult<()> {
    sleep(sleep_duration).await;
    progress.set_message_for_confirmed_transactions(
        progress_bar,
        &format!("Sending {}/{} transactions", index + 1, num_transactions,),
    );
    Ok(())
}

#[cfg(feature = "spinner")]
async fn sleep_and_send_wire_transaction_to_addr<P, M, C>(
    sleep_duration: Duration,
    connection_cache: &ConnectionCache<P, M, C>,
    addr: SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    sleep(sleep_duration).await;
    send_wire_transaction_to_addr(connection_cache, &addr, wire_transaction).await
}

async fn send_wire_transaction_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data(&wire_transaction).await
}

async fn send_wire_transaction_batch_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transactions: &[Vec<u8>],
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data_batch(wire_transactions).await
}

impl<P, M, C> TpuClient<P, M, C>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    /// Serialize and send a transaction to the current and upcoming leader TPUs according to
    /// fanout size
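    ///
    /// A minimal usage sketch (marked `ignore`); `tpu_client` and a signed `transaction` are
    /// assumed to be in scope:
    ///
    /// ```ignore
    /// // Returns `true` if at least one send to a leader TPU succeeded.
    /// let sent = tpu_client.send_transaction(&transaction).await;
    /// ```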
    pub async fn send_transaction(&self, transaction: &Transaction) -> bool {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    pub async fn send_wire_transaction(&self, wire_transaction: Vec<u8>) -> bool {
        self.try_send_wire_transaction(wire_transaction)
            .await
            .is_ok()
    }

    /// Serialize and send a transaction to the current and upcoming leader TPUs according to
    /// fanout size
    /// Returns the last error if all sends fail
    pub async fn try_send_transaction(&self, transaction: &Transaction) -> TransportResult<()> {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.try_send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    /// Returns the last error if all sends fail
    pub async fn try_send_wire_transaction(
        &self,
        wire_transaction: Vec<u8>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_to_addr(
                    &self.connection_cache,
                    addr,
                    wire_transaction.clone(),
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Send a batch of wire transactions to the current and upcoming leader TPUs according to
    /// fanout size
    /// Returns the last error if all sends fail
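    ///
    /// A minimal usage sketch (marked `ignore`); `tpu_client` and a slice of signed
    /// `transactions` are assumed to be in scope, serialized with `bincode` as elsewhere in
    /// this module:
    ///
    /// ```ignore
    /// let wire_transactions: Vec<Vec<u8>> = transactions
    ///     .iter()
    ///     .map(|tx| bincode::serialize(tx).expect("serialization should succeed"))
    ///     .collect();
    /// tpu_client
    ///     .try_send_wire_transaction_batch(wire_transactions)
    ///     .await?;
    /// ```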
    pub async fn try_send_wire_transaction_batch(
        &self,
        wire_transactions: Vec<Vec<u8>>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_batch_to_addr(
                    &self.connection_cache,
                    addr,
                    &wire_transactions,
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Create a new client that disconnects when dropped
    pub async fn new(
        name: &'static str,
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_manager: M,
    ) -> Result<Self> {
        let connection_cache = Arc::new(
            ConnectionCache::new(name, connection_manager, DEFAULT_CONNECTION_POOL_SIZE).unwrap(),
        ); // TODO: Handle error properly, as the ConnectionCache ctor is now fallible.
        Self::new_with_connection_cache(rpc_client, websocket_url, config, connection_cache).await
    }

    /// Create a new client that disconnects when dropped
    pub async fn new_with_connection_cache(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_cache: Arc<ConnectionCache<P, M, C>>,
    ) -> Result<Self> {
        let exit = Arc::new(AtomicBool::new(false));
        let leader_tpu_service =
            LeaderTpuService::new(rpc_client.clone(), websocket_url, M::PROTOCOL, exit.clone())
                .await?;

        Ok(Self {
            fanout_slots: config.fanout_slots.clamp(1, MAX_FANOUT_SLOTS),
            leader_tpu_service,
            exit,
            rpc_client,
            connection_cache,
        })
    }

    #[cfg(feature = "spinner")]
    pub async fn send_and_confirm_messages_with_spinner<T: Signers + ?Sized>(
        &self,
        messages: &[Message],
        signers: &T,
    ) -> Result<Vec<Option<TransactionError>>> {
        let mut progress = SendTransactionProgress::default();
        let progress_bar = spinner::new_progress_bar();
        progress_bar.set_message("Setting up...");

        let mut transactions = messages
            .iter()
            .enumerate()
            .map(|(i, message)| (i, Transaction::new_unsigned(message.clone())))
            .collect::<Vec<_>>();
        progress.total_transactions = transactions.len();
        let mut transaction_errors = vec![None; transactions.len()];
        progress.block_height = self.rpc_client.get_block_height().await?;
        for expired_blockhash_retries in (0..5).rev() {
            let (blockhash, last_valid_block_height) = self
                .rpc_client
                .get_latest_blockhash_with_commitment(self.rpc_client.commitment())
                .await?;
            progress.last_valid_block_height = last_valid_block_height;

            let mut pending_transactions = HashMap::new();
            for (i, mut transaction) in transactions {
                transaction.try_sign(signers, blockhash)?;
                pending_transactions.insert(transaction.signatures[0], (i, transaction));
            }

            let mut last_resend = Instant::now() - TRANSACTION_RESEND_INTERVAL;
            while progress.block_height <= progress.last_valid_block_height {
                let num_transactions = pending_transactions.len();

                // Periodically re-send all pending transactions
                if Instant::now().duration_since(last_resend) > TRANSACTION_RESEND_INTERVAL {
                    // Prepare futures for all transactions
                    let mut futures = vec![];
                    for (index, (_i, transaction)) in pending_transactions.values().enumerate() {
                        let wire_transaction = serialize(transaction).unwrap();
                        let leaders = self
                            .leader_tpu_service
                            .unique_leader_tpu_sockets(self.fanout_slots);
                        futures.extend(send_wire_transaction_futures(
                            &progress_bar,
                            &progress,
                            index,
                            num_transactions,
                            wire_transaction,
                            leaders,
                            &self.connection_cache,
                        ));
                    }

                    // Start the process of sending them all
                    let results = join_all(futures).await;

                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking sent transactions",
                    );
                    for (index, (tx_results, (_i, transaction))) in results
                        .chunks(self.fanout_slots as usize)
                        .zip(pending_transactions.values())
                        .enumerate()
                    {
                        // Only report an error if every future in the chunk errored
                        if tx_results.iter().all(|r| r.is_err()) {
                            progress.set_message_for_confirmed_transactions(
                                &progress_bar,
                                &format!(
                                    "Resending failed transaction {} of {}",
                                    index + 1,
                                    num_transactions,
                                ),
                            );
                            let _result = self.rpc_client.send_transaction(transaction).await.ok();
                        }
                    }
                    last_resend = Instant::now();
                }

                // Wait for the next block before checking for transaction statuses
                let mut block_height_refreshes = 10;
                progress.set_message_for_confirmed_transactions(
                    &progress_bar,
                    &format!("Waiting for next block, {num_transactions} transactions pending..."),
                );
                let mut new_block_height = progress.block_height;
                while progress.block_height == new_block_height && block_height_refreshes > 0 {
                    sleep(Duration::from_millis(500)).await;
                    new_block_height = self.rpc_client.get_block_height().await?;
                    block_height_refreshes -= 1;
                }
                progress.block_height = new_block_height;

                // Collect statuses for the transactions, drop those that are confirmed
                let pending_signatures = pending_transactions.keys().cloned().collect::<Vec<_>>();
                for pending_signatures_chunk in
                    pending_signatures.chunks(MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS)
                {
                    if let Ok(result) = self
                        .rpc_client
                        .get_signature_statuses(pending_signatures_chunk)
                        .await
                    {
                        let statuses = result.value;
                        for (signature, status) in
                            pending_signatures_chunk.iter().zip(statuses.into_iter())
                        {
                            if let Some(status) = status {
                                if status.satisfies_commitment(self.rpc_client.commitment()) {
                                    if let Some((i, _)) = pending_transactions.remove(signature) {
                                        progress.confirmed_transactions += 1;
                                        if status.err.is_some() {
                                            progress_bar
                                                .println(format!("Failed transaction: {status:?}"));
                                        }
                                        transaction_errors[i] = status.err;
                                    }
                                }
                            }
                        }
                    }
                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking transaction status...",
                    );
                }

                if pending_transactions.is_empty() {
                    return Ok(transaction_errors);
                }
            }

            transactions = pending_transactions.into_values().collect();
            progress_bar.println(format!(
                "Blockhash expired. {expired_blockhash_retries} retries remaining"
            ));
        }
        Err(TpuSenderError::Custom("Max retries exceeded".into()))
    }

    pub fn rpc_client(&self) -> &RpcClient {
        &self.rpc_client
    }

    pub async fn shutdown(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
        self.leader_tpu_service.join().await;
    }

    pub fn get_connection_cache(&self) -> &Arc<ConnectionCache<P, M, C>>
    where
        P: ConnectionPool<NewConnectionConfig = C>,
        M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
        C: NewConnectionConfig,
    {
        &self.connection_cache
    }

    pub fn get_leader_tpu_service(&self) -> &LeaderTpuService {
        &self.leader_tpu_service
    }

    pub fn get_fanout_slots(&self) -> u64 {
        self.fanout_slots
    }
}

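// Dropping the client only signals `exit` to the background leader TPU task; the task is not
// awaited here. Call `shutdown()` when the service task must be joined before proceeding.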
impl<P, M, C> Drop for TpuClient<P, M, C> {
    fn drop(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
    }
}

/// Service that tracks upcoming leaders and maintains an up-to-date mapping
/// of leader id to TPU socket address.
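///
/// # Example
///
/// A minimal usage sketch (marked `ignore`); `rpc_client` is assumed to be an
/// `Arc<RpcClient>` already in scope:
///
/// ```ignore
/// let exit = Arc::new(AtomicBool::new(false));
/// let leader_tpu_service = LeaderTpuService::new(
///     rpc_client.clone(),
///     "ws://localhost:8900",
///     Protocol::QUIC,
///     exit.clone(),
/// )
/// .await?;
/// // TPU sockets of the current leader plus the next few unique upcoming leaders.
/// let sockets = leader_tpu_service.unique_leader_tpu_sockets(4);
/// ```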
pub struct LeaderTpuService {
    recent_slots: RecentLeaderSlots,
    leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
    t_leader_tpu_service: Option<JoinHandle<Result<()>>>,
}

impl LeaderTpuService {
    pub async fn new(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        protocol: Protocol,
        exit: Arc<AtomicBool>,
    ) -> Result<Self> {
        let epoch_schedule = rpc_client.get_epoch_schedule().await?;
        let start_slot = rpc_client
            .get_slot_with_commitment(CommitmentConfig::processed())
            .await?;

        let recent_slots = RecentLeaderSlots::new(start_slot);
        let epoch = epoch_schedule.get_epoch(start_slot);
        let slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
        let last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);

        // When a cluster is starting, we observe an invalid slot range failure that goes away after a
        // retry. It seems as if the leader schedule is not available, but it should be. The logic
        // below retries the RPC call in case of an invalid slot range error.
        let tpu_leader_service_creation_timeout = Duration::from_secs(20);
        let retry_interval = Duration::from_secs(1);
        let leaders = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                // TODO: The root cause appears to lie within the `rpc_client.get_slot_leaders()`.
                // It might be worth debugging further and trying to understand why the RPC
                // call fails. There may be a bug in the `get_slot_leaders()` logic or in the
                // RPC implementation
                match rpc_client
                    .get_slot_leaders(start_slot, LeaderTpuCache::fanout(slots_in_epoch))
                    .await
                {
                    Ok(leaders) => return Ok(leaders),
                    Err(client_error) => {
                        if is_invalid_slot_range_error(&client_error) {
                            sleep(retry_interval).await;
                            continue;
                        } else {
                            return Err(client_error);
                        }
                    }
                }
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
                "Failed to get slot leaders connecting to: {websocket_url}, timeout: \
                 {tpu_leader_service_creation_timeout:?}. Invalid slot range"
            ))
        })??;

        let cluster_nodes = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                let cluster_nodes = rpc_client.get_cluster_nodes().await?;
                // Stop once we find at least one leader's contact info
                if cluster_nodes.iter().any(|rpc_contact_info| {
                    Pubkey::from_str(&rpc_contact_info.pubkey)
                        .map(|pubkey| leaders.contains(&pubkey))
                        .unwrap_or(false)
                }) {
                    return Ok::<_, ClientError>(cluster_nodes);
                }
                sleep(retry_interval).await;
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
                "Failed to find any cluster node info for upcoming leaders, timeout: \
                 {tpu_leader_service_creation_timeout:?}."
            ))
        })??;
        let leader_tpu_cache = Arc::new(RwLock::new(LeaderTpuCache::new(
            start_slot,
            slots_in_epoch,
            last_slot_in_epoch,
            leaders,
            cluster_nodes,
            protocol,
        )));

        let pubsub_client = if !websocket_url.is_empty() {
            Some(PubsubClient::new(websocket_url).await?)
        } else {
            None
        };

        let t_leader_tpu_service = Some({
            let recent_slots = recent_slots.clone();
            let leader_tpu_cache = leader_tpu_cache.clone();
            tokio::spawn(Self::run(
                rpc_client,
                recent_slots,
                leader_tpu_cache,
                pubsub_client,
                exit,
            ))
        });

        Ok(LeaderTpuService {
            recent_slots,
            leader_tpu_cache,
            t_leader_tpu_service,
        })
    }

    pub async fn join(&mut self) {
        if let Some(t_handle) = self.t_leader_tpu_service.take() {
            t_handle.await.unwrap().unwrap();
        }
    }

    pub fn estimated_current_slot(&self) -> Slot {
        self.recent_slots.estimated_current_slot()
    }

    pub fn unique_leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_unique_leader_sockets(current_slot, fanout_slots)
    }

    pub fn leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_leader_sockets(current_slot, fanout_slots)
    }

    async fn run(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        tokio::try_join!(
            Self::run_slot_watcher(recent_slots.clone(), pubsub_client, exit.clone()),
            Self::run_cache_refresher(rpc_client, recent_slots, leader_tpu_cache, exit),
        )?;

        Ok(())
    }

    async fn run_cache_refresher(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let mut last_cluster_refresh = Instant::now();
        let mut sleep_ms = DEFAULT_MS_PER_SLOT;

        while !exit.load(Ordering::Relaxed) {
            // Sleep a slot before checking if leader cache needs to be refreshed again
            sleep(Duration::from_millis(sleep_ms)).await;
            sleep_ms = DEFAULT_MS_PER_SLOT;

            let cache_update_info = maybe_fetch_cache_info(
                &leader_tpu_cache,
                last_cluster_refresh,
                &rpc_client,
                &recent_slots,
            )
            .await;

            if cache_update_info.has_some() {
                let mut leader_tpu_cache = leader_tpu_cache.write().unwrap();
                let (has_error, cluster_refreshed) = leader_tpu_cache.update_all(cache_update_info);
                if has_error {
                    sleep_ms = 100;
                }
                if cluster_refreshed {
                    last_cluster_refresh = Instant::now();
                }
            }
        }

        Ok(())
    }

    async fn run_slot_watcher(
        recent_slots: RecentLeaderSlots,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let Some(pubsub_client) = pubsub_client else {
            return Ok(());
        };

        let (mut notifications, unsubscribe) = pubsub_client.slot_updates_subscribe().await?;
        // Time out slot update notification polling at 10ms.
        //
        // The rationale is two-fold:
        // 1. Notifications are an unbounded stream -- polling them blocks indefinitely unless
        //    interrupted, so the exit condition would never be checked. 10ms keeps CPU overhead
        //    negligible while still checking for notifications promptly.
        // 2. The timeout must be strictly less than the slot time (DEFAULT_MS_PER_SLOT: 400ms) so
        //    that the timeout is always eventually reached. For example, if notifications arrive
        //    every 400ms and the timeout is >= 400ms, a notification may always be available before
        //    the timeout fires, and the exit condition would never be checked.
        const SLOT_UPDATE_TIMEOUT: Duration = Duration::from_millis(10);

        while !exit.load(Ordering::Relaxed) {
            while let Ok(Some(update)) = timeout(SLOT_UPDATE_TIMEOUT, notifications.next()).await {
                let current_slot = match update {
                    // This update indicates that a full slot was received by the connected
                    // node so we can stop sending transactions to the leader for that slot
                    SlotUpdate::Completed { slot, .. } => slot.saturating_add(1),
                    // This update indicates that we have just received the first shred from
                    // the leader for this slot and they are probably still accepting transactions.
                    SlotUpdate::FirstShredReceived { slot, .. } => slot,
                    _ => continue,
                };
                recent_slots.record_slot(current_slot);
            }
        }

        // `notifications` requires a valid reference to `pubsub_client`, so `notifications` must be
        // dropped before moving `pubsub_client` via `shutdown()`.
        drop(notifications);
        unsubscribe().await;
        pubsub_client.shutdown().await?;

        Ok(())
    }
}

async fn maybe_fetch_cache_info(
    leader_tpu_cache: &Arc<RwLock<LeaderTpuCache>>,
    last_cluster_refresh: Instant,
    rpc_client: &RpcClient,
    recent_slots: &RecentLeaderSlots,
) -> LeaderTpuCacheUpdateInfo {
    // Refresh cluster TPU ports every 5min in case validators restart with new port configuration
    // or new validators come online
    let maybe_cluster_nodes = if last_cluster_refresh.elapsed() > Duration::from_secs(5 * 60) {
        Some(rpc_client.get_cluster_nodes().await)
    } else {
        None
    };

    // Grab information about the slot leaders currently in the cache.
    let estimated_current_slot = recent_slots.estimated_current_slot();
    let (last_slot, last_slot_in_epoch, slots_in_epoch) = {
        let leader_tpu_cache = leader_tpu_cache.read().unwrap();
        leader_tpu_cache.slot_info()
    };

    // If we're crossing into a new epoch, fetch the updated epoch schedule.
    let maybe_epoch_schedule = if estimated_current_slot > last_slot_in_epoch {
        Some(rpc_client.get_epoch_schedule().await)
    } else {
        None
    };

    // If we are within the fanout range of the last slot in the cache, fetch more slot
    // leaders. We pull down a big batch at a time to amortize the cost of the RPC call,
    // and we fetch proactively so that transaction sends do not stall on the fetch.
    let maybe_slot_leaders = if estimated_current_slot >= last_slot.saturating_sub(MAX_FANOUT_SLOTS)
    {
        Some(
            rpc_client
                .get_slot_leaders(
                    estimated_current_slot,
                    LeaderTpuCache::fanout(slots_in_epoch),
                )
                .await,
        )
    } else {
        None
    };
    LeaderTpuCacheUpdateInfo {
        maybe_cluster_nodes,
        maybe_epoch_schedule,
        maybe_slot_leaders,
        first_slot: estimated_current_slot,
    }
}

fn is_invalid_slot_range_error(client_error: &ClientError) -> bool {
    if let ErrorKind::RpcError(RpcError::RpcResponseError { code, message, .. }) =
        client_error.kind()
    {
        return *code == -32602
            && message.contains("Invalid slot range: leader schedule for epoch");
    }
    false
}