solana_tpu_client/nonblocking/tpu_client.rs

pub use crate::tpu_client::Result;
use {
    crate::tpu_client::{RecentLeaderSlots, TpuClientConfig, MAX_FANOUT_SLOTS},
    bincode::serialize,
    futures_util::{future::join_all, stream::StreamExt},
    log::*,
    solana_clock::{Slot, DEFAULT_MS_PER_SLOT, NUM_CONSECUTIVE_LEADER_SLOTS},
    solana_commitment_config::CommitmentConfig,
    solana_connection_cache::{
        connection_cache::{
            ConnectionCache, ConnectionManager, ConnectionPool, NewConnectionConfig, Protocol,
            DEFAULT_CONNECTION_POOL_SIZE,
        },
        nonblocking::client_connection::ClientConnection,
    },
    solana_epoch_schedule::EpochSchedule,
    solana_pubkey::Pubkey,
    solana_pubsub_client::nonblocking::pubsub_client::{PubsubClient, PubsubClientError},
    solana_quic_definitions::QUIC_PORT_OFFSET,
    solana_rpc_client::nonblocking::rpc_client::RpcClient,
    solana_rpc_client_api::{
        client_error::{Error as ClientError, ErrorKind, Result as ClientResult},
        request::RpcError,
        response::{RpcContactInfo, SlotUpdate},
    },
    solana_signer::SignerError,
    solana_transaction::Transaction,
    solana_transaction_error::{TransportError, TransportResult},
    std::{
        collections::{HashMap, HashSet},
        net::SocketAddr,
        str::FromStr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, RwLock,
        },
    },
    thiserror::Error,
    tokio::{
        task::JoinHandle,
        time::{sleep, timeout, Duration, Instant},
    },
};
#[cfg(feature = "spinner")]
use {
    crate::tpu_client::{SEND_TRANSACTION_INTERVAL, TRANSACTION_RESEND_INTERVAL},
    futures_util::FutureExt,
    indicatif::ProgressBar,
    solana_message::Message,
    solana_rpc_client::spinner::{self, SendTransactionProgress},
    solana_rpc_client_api::request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS,
    solana_signer::signers::Signers,
    solana_transaction_error::TransactionError,
    std::{future::Future, iter},
};

#[derive(Error, Debug)]
pub enum TpuSenderError {
    #[error("Pubsub error: {0:?}")]
    PubsubError(#[from] PubsubClientError),
    #[error("RPC error: {0:?}")]
    RpcError(#[from] ClientError),
    #[error("IO error: {0:?}")]
    IoError(#[from] std::io::Error),
    #[error("Signer error: {0:?}")]
    SignerError(#[from] SignerError),
    #[error("Custom error: {0}")]
    Custom(String),
}

struct LeaderTpuCacheUpdateInfo {
    pub(super) maybe_cluster_nodes: Option<ClientResult<Vec<RpcContactInfo>>>,
    pub(super) maybe_epoch_schedule: Option<ClientResult<EpochSchedule>>,
    pub(super) maybe_slot_leaders: Option<ClientResult<Vec<Pubkey>>>,
    pub(super) first_slot: Slot,
}
impl LeaderTpuCacheUpdateInfo {
    pub fn has_some(&self) -> bool {
        self.maybe_cluster_nodes.is_some()
            || self.maybe_epoch_schedule.is_some()
            || self.maybe_slot_leaders.is_some()
    }
}

struct LeaderTpuCache {
    protocol: Protocol,
    first_slot: Slot,
    leaders: Vec<Pubkey>,
    leader_tpu_map: HashMap<Pubkey, SocketAddr>,
    slots_in_epoch: Slot,
    last_slot_in_epoch: Slot,
}

impl LeaderTpuCache {
    pub fn new(
        first_slot: Slot,
        slots_in_epoch: Slot,
        last_slot_in_epoch: Slot,
        leaders: Vec<Pubkey>,
        cluster_nodes: Vec<RpcContactInfo>,
        protocol: Protocol,
    ) -> Self {
        let leader_tpu_map = Self::extract_cluster_tpu_sockets(protocol, cluster_nodes);
        Self {
            protocol,
            first_slot,
            leaders,
            leader_tpu_map,
            slots_in_epoch,
            last_slot_in_epoch,
        }
    }

    // Last slot that has a cached leader pubkey
    pub fn last_slot(&self) -> Slot {
        self.first_slot + self.leaders.len().saturating_sub(1) as u64
    }

    pub fn slot_info(&self) -> (Slot, Slot, Slot) {
        (
            self.last_slot(),
            self.last_slot_in_epoch,
            self.slots_in_epoch,
        )
    }

    // Get the TPU sockets for the current leader and upcoming *unique* leaders according to fanout size.
    fn get_unique_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let all_leader_sockets = self.get_leader_sockets(estimated_current_slot, fanout_slots);

        let mut unique_sockets = Vec::new();
        let mut seen = HashSet::new();

        for socket in all_leader_sockets {
            if seen.insert(socket) {
                unique_sockets.push(socket);
            }
        }

        unique_sockets
    }

    // Get the TPU sockets for the current leader and upcoming leaders according to fanout size.
    fn get_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let mut leader_sockets = Vec::new();
        // `first_slot` might have been advanced since caller last read the `estimated_current_slot`
        // value. Take the greater of the two values to ensure we are reading from the latest
        // leader schedule.
        let current_slot = std::cmp::max(estimated_current_slot, self.first_slot);
        for leader_slot in (current_slot..current_slot + fanout_slots)
            .step_by(NUM_CONSECUTIVE_LEADER_SLOTS as usize)
        {
            if let Some(leader) = self.get_slot_leader(leader_slot) {
                if let Some(tpu_socket) = self.leader_tpu_map.get(leader) {
                    leader_sockets.push(*tpu_socket);
                } else {
                    // The leader is probably delinquent
                    trace!("TPU not available for leader {leader}");
                }
            } else {
                // Overran the local leader schedule cache
                warn!(
                    "Leader not known for slot {}; cache holds slots [{},{}]",
                    leader_slot,
                    self.first_slot,
                    self.last_slot()
                );
            }
        }
        leader_sockets
    }
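
    // Worked example for the two helpers above (slot numbers illustrative, assuming
    // NUM_CONSECUTIVE_LEADER_SLOTS = 4): with `estimated_current_slot` = 100 and
    // `fanout_slots` = 12, `get_leader_sockets` samples the leaders of slots 100, 104, and 108,
    // i.e. one lookup per leader rotation rather than one per slot. If slots 104 and 108 belong
    // to the same leader, `get_unique_leader_sockets` collapses them to a single socket so the
    // same validator is not sent duplicate copies of a transaction.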

    pub fn get_slot_leader(&self, slot: Slot) -> Option<&Pubkey> {
        if slot >= self.first_slot {
            let index = slot - self.first_slot;
            self.leaders.get(index as usize)
        } else {
            None
        }
    }

    fn extract_cluster_tpu_sockets(
        protocol: Protocol,
        cluster_contact_info: Vec<RpcContactInfo>,
    ) -> HashMap<Pubkey, SocketAddr> {
        cluster_contact_info
            .into_iter()
            .filter_map(|contact_info| {
                let pubkey = Pubkey::from_str(&contact_info.pubkey).ok()?;
                let socket = match protocol {
                    Protocol::QUIC => contact_info.tpu_quic.or_else(|| {
                        let mut socket = contact_info.tpu?;
                        let port = socket.port().checked_add(QUIC_PORT_OFFSET)?;
                        socket.set_port(port);
                        Some(socket)
                    }),
                    Protocol::UDP => contact_info.tpu,
                }?;
                Some((pubkey, socket))
            })
            .collect()
    }
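
    // Example of the socket derivation above (addresses illustrative, assuming QUIC_PORT_OFFSET
    // is 6): a node advertising `tpu_quic: Some(10.0.0.1:8009)` is used as-is under
    // `Protocol::QUIC`; a node advertising only `tpu: Some(10.0.0.1:8003)` maps to
    // `10.0.0.1:8009` by adding the offset; nodes advertising neither are skipped.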

    pub fn fanout(slots_in_epoch: Slot) -> Slot {
        (2 * MAX_FANOUT_SLOTS).min(slots_in_epoch)
    }

    pub fn update_all(&mut self, cache_update_info: LeaderTpuCacheUpdateInfo) -> (bool, bool) {
        let mut has_error = false;
        let mut cluster_refreshed = false;
        if let Some(cluster_nodes) = cache_update_info.maybe_cluster_nodes {
            match cluster_nodes {
                Ok(cluster_nodes) => {
                    self.leader_tpu_map =
                        Self::extract_cluster_tpu_sockets(self.protocol, cluster_nodes);
                    cluster_refreshed = true;
                }
                Err(err) => {
                    warn!("Failed to fetch cluster tpu sockets: {err}");
                    has_error = true;
                }
            }
        }

        if let Some(Ok(epoch_schedule)) = cache_update_info.maybe_epoch_schedule {
            let epoch = epoch_schedule.get_epoch(cache_update_info.first_slot);
            self.slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
            self.last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);
        }

        if let Some(slot_leaders) = cache_update_info.maybe_slot_leaders {
            match slot_leaders {
                Ok(slot_leaders) => {
                    self.first_slot = cache_update_info.first_slot;
                    self.leaders = slot_leaders;
                }
                Err(err) => {
                    warn!(
                        "Failed to fetch slot leaders (first_slot: {}): {err}",
                        cache_update_info.first_slot
                    );
                    has_error = true;
                }
            }
        }
        (has_error, cluster_refreshed)
    }
}

/// Client which sends transactions directly to the current leader's TPU port over UDP or QUIC,
/// depending on the connection cache's protocol. The client uses RPC to determine the current
/// leader and to fetch node contact info.
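///
/// The example below is a minimal construction sketch, not a compiled doctest. It assumes a
/// `ConnectionManager` implementation (here called `connection_manager`) and reachable
/// RPC/WebSocket endpoints; the endpoint URLs, the client name, and the manager value are
/// illustrative placeholders.
///
/// ```ignore
/// use std::sync::Arc;
/// use solana_rpc_client::nonblocking::rpc_client::RpcClient;
/// use crate::tpu_client::TpuClientConfig;
///
/// let rpc_client = Arc::new(RpcClient::new("http://127.0.0.1:8899".to_string()));
/// let tpu_client = TpuClient::new(
///     "my-client",
///     rpc_client,
///     "ws://127.0.0.1:8900",
///     TpuClientConfig::default(),
///     connection_manager, // any type implementing `ConnectionManager`
/// )
/// .await?;
///
/// // Sends to the current and upcoming leaders according to the configured fanout.
/// let sent = tpu_client.send_transaction(&transaction).await;
/// ```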
pub struct TpuClient<
    P, // ConnectionPool
    M, // ConnectionManager
    C, // NewConnectionConfig
> {
    fanout_slots: u64,
    leader_tpu_service: LeaderTpuService,
    exit: Arc<AtomicBool>,
    rpc_client: Arc<RpcClient>,
    connection_cache: Arc<ConnectionCache<P, M, C>>,
}

/// Helper function which generates futures to all be awaited together for maximum
/// throughput
#[cfg(feature = "spinner")]
fn send_wire_transaction_futures<'a, P, M, C>(
    progress_bar: &'a ProgressBar,
    progress: &'a SendTransactionProgress,
    index: usize,
    num_transactions: usize,
    wire_transaction: Vec<u8>,
    leaders: Vec<SocketAddr>,
    connection_cache: &'a ConnectionCache<P, M, C>,
) -> Vec<impl Future<Output = TransportResult<()>> + 'a>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    const SEND_TIMEOUT_INTERVAL: Duration = Duration::from_secs(5);
    let sleep_duration = SEND_TRANSACTION_INTERVAL.saturating_mul(index as u32);
    let send_timeout = SEND_TIMEOUT_INTERVAL.saturating_add(sleep_duration);
    leaders
        .into_iter()
        .map(|addr| {
            timeout_future(
                send_timeout,
                sleep_and_send_wire_transaction_to_addr(
                    sleep_duration,
                    connection_cache,
                    addr,
                    wire_transaction.clone(),
                ),
            )
            .boxed_local() // required to make types work simply
        })
        .chain(iter::once(
            timeout_future(
                send_timeout,
                sleep_and_set_message(
                    sleep_duration,
                    progress_bar,
                    progress,
                    index,
                    num_transactions,
                ),
            )
            .boxed_local(), // required to make types work simply
        ))
        .collect::<Vec<_>>()
}
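
// Timing sketch for the helper above (values illustrative): each transaction at position `index`
// in the pending batch is delayed by `SEND_TRANSACTION_INTERVAL * index` before sending, and its
// overall timeout grows by the same amount so later transactions are not penalized for the
// deliberate stagger. Assuming SEND_TRANSACTION_INTERVAL is 10ms (see `crate::tpu_client`), a
// hypothetical walk-through looks like:
//
//     index = 0  -> sleep   0 ms, timeout 5.000 s
//     index = 3  -> sleep  30 ms, timeout 5.030 s
//     index = 50 -> sleep 500 ms, timeout 5.500 s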

// Wrap an existing future with a timeout.
//
// Useful for end-users who don't need a persistent connection to each validator,
// and want to abort more quickly.
#[cfg(feature = "spinner")]
async fn timeout_future<Fut: Future<Output = TransportResult<()>>>(
    timeout_duration: Duration,
    future: Fut,
) -> TransportResult<()> {
    timeout(timeout_duration, future)
        .await
        .unwrap_or_else(|_| Err(TransportError::Custom("Timed out".to_string())))
}

#[cfg(feature = "spinner")]
async fn sleep_and_set_message(
    sleep_duration: Duration,
    progress_bar: &ProgressBar,
    progress: &SendTransactionProgress,
    index: usize,
    num_transactions: usize,
) -> TransportResult<()> {
    sleep(sleep_duration).await;
    progress.set_message_for_confirmed_transactions(
        progress_bar,
        &format!("Sending {}/{} transactions", index + 1, num_transactions,),
    );
    Ok(())
}

#[cfg(feature = "spinner")]
async fn sleep_and_send_wire_transaction_to_addr<P, M, C>(
    sleep_duration: Duration,
    connection_cache: &ConnectionCache<P, M, C>,
    addr: SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    sleep(sleep_duration).await;
    send_wire_transaction_to_addr(connection_cache, &addr, wire_transaction).await
}

async fn send_wire_transaction_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data(&wire_transaction).await
}

async fn send_wire_transaction_batch_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transactions: &[Vec<u8>],
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data_batch(wire_transactions).await
}

impl<P, M, C> TpuClient<P, M, C>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    pub async fn send_transaction(&self, transaction: &Transaction) -> bool {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    pub async fn send_wire_transaction(&self, wire_transaction: Vec<u8>) -> bool {
        self.try_send_wire_transaction(wire_transaction)
            .await
            .is_ok()
    }

    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    /// Returns the first error encountered if all sends fail
    pub async fn try_send_transaction(&self, transaction: &Transaction) -> TransportResult<()> {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.try_send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    /// Returns the first error encountered if all sends fail
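    ///
    /// A hedged usage sketch (not a compiled doctest); `tpu_client` and the serialized
    /// `wire_transaction` bytes are assumed to exist already:
    ///
    /// ```ignore
    /// match tpu_client.try_send_wire_transaction(wire_transaction).await {
    ///     Ok(()) => println!("at least one fanout send succeeded"),
    ///     Err(err) => eprintln!("every fanout send failed: {err}"),
    /// }
    /// ```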
    pub async fn try_send_wire_transaction(
        &self,
        wire_transaction: Vec<u8>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_to_addr(
                    &self.connection_cache,
                    addr,
                    wire_transaction.clone(),
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Send a batch of wire transactions to the current and upcoming leader TPUs according to
    /// fanout size
    /// Returns the first error encountered if all sends fail
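    ///
    /// Minimal sketch (not a compiled doctest); the pre-built `transactions` are assumed:
    ///
    /// ```ignore
    /// let wire_transactions: Vec<Vec<u8>> = transactions
    ///     .iter()
    ///     .map(|tx| bincode::serialize(tx).expect("serialization should succeed"))
    ///     .collect();
    /// tpu_client.try_send_wire_transaction_batch(wire_transactions).await?;
    /// ```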
    pub async fn try_send_wire_transaction_batch(
        &self,
        wire_transactions: Vec<Vec<u8>>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_batch_to_addr(
                    &self.connection_cache,
                    addr,
                    &wire_transactions,
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::other("No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Create a new client that disconnects when dropped
    pub async fn new(
        name: &'static str,
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_manager: M,
    ) -> Result<Self> {
        let connection_cache = Arc::new(
            ConnectionCache::new(name, connection_manager, DEFAULT_CONNECTION_POOL_SIZE).unwrap(),
        ); // TODO: Handle error properly, as the ConnectionCache ctor is now fallible.
        Self::new_with_connection_cache(rpc_client, websocket_url, config, connection_cache).await
    }

    /// Create a new client that disconnects when dropped
    pub async fn new_with_connection_cache(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_cache: Arc<ConnectionCache<P, M, C>>,
    ) -> Result<Self> {
        let exit = Arc::new(AtomicBool::new(false));
        let leader_tpu_service =
            LeaderTpuService::new(rpc_client.clone(), websocket_url, M::PROTOCOL, exit.clone())
                .await?;

        Ok(Self {
            fanout_slots: config.fanout_slots.clamp(1, MAX_FANOUT_SLOTS),
            leader_tpu_service,
            exit,
            rpc_client,
            connection_cache,
        })
    }

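    /// Sign each message with `signers` using a freshly fetched blockhash, send the resulting
    /// transactions to upcoming leader TPUs (resending periodically), and poll signature
    /// statuses until every transaction is confirmed or the blockhash expires; up to five
    /// attempts with fresh blockhashes are made before giving up. Progress is reported on a
    /// spinner. Returns the per-transaction errors in the same order as `messages`.
    ///
    /// A hedged usage sketch (not a compiled doctest); `messages` and `keypair` are assumed:
    ///
    /// ```ignore
    /// let errors = tpu_client
    ///     .send_and_confirm_messages_with_spinner(&messages, &[&keypair])
    ///     .await?;
    /// ```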
    #[cfg(feature = "spinner")]
    pub async fn send_and_confirm_messages_with_spinner<T: Signers + ?Sized>(
        &self,
        messages: &[Message],
        signers: &T,
    ) -> Result<Vec<Option<TransactionError>>> {
        let mut progress = SendTransactionProgress::default();
        let progress_bar = spinner::new_progress_bar();
        progress_bar.set_message("Setting up...");

        let mut transactions = messages
            .iter()
            .enumerate()
            .map(|(i, message)| (i, Transaction::new_unsigned(message.clone())))
            .collect::<Vec<_>>();
        progress.total_transactions = transactions.len();
        let mut transaction_errors = vec![None; transactions.len()];
        progress.block_height = self.rpc_client.get_block_height().await?;
        for expired_blockhash_retries in (0..5).rev() {
            let (blockhash, last_valid_block_height) = self
                .rpc_client
                .get_latest_blockhash_with_commitment(self.rpc_client.commitment())
                .await?;
            progress.last_valid_block_height = last_valid_block_height;

            let mut pending_transactions = HashMap::new();
            for (i, mut transaction) in transactions {
                transaction.try_sign(signers, blockhash)?;
                pending_transactions.insert(transaction.signatures[0], (i, transaction));
            }

            let mut last_resend = Instant::now() - TRANSACTION_RESEND_INTERVAL;
            while progress.block_height <= progress.last_valid_block_height {
                let num_transactions = pending_transactions.len();

                // Periodically re-send all pending transactions
                if Instant::now().duration_since(last_resend) > TRANSACTION_RESEND_INTERVAL {
                    // Prepare futures for all transactions
                    let mut futures = vec![];
                    for (index, (_i, transaction)) in pending_transactions.values().enumerate() {
                        let wire_transaction = serialize(transaction).unwrap();
                        let leaders = self
                            .leader_tpu_service
                            .unique_leader_tpu_sockets(self.fanout_slots);
                        futures.extend(send_wire_transaction_futures(
                            &progress_bar,
                            &progress,
                            index,
                            num_transactions,
                            wire_transaction,
                            leaders,
                            &self.connection_cache,
                        ));
                    }

                    // Start the process of sending them all
                    let results = join_all(futures).await;

                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking sent transactions",
                    );
                    for (index, (tx_results, (_i, transaction))) in results
                        .chunks(self.fanout_slots as usize)
                        .zip(pending_transactions.values())
                        .enumerate()
                    {
                        // Only report an error if every future in the chunk errored
                        if tx_results.iter().all(|r| r.is_err()) {
                            progress.set_message_for_confirmed_transactions(
                                &progress_bar,
                                &format!(
                                    "Resending failed transaction {} of {}",
                                    index + 1,
                                    num_transactions,
                                ),
                            );
                            let _result = self.rpc_client.send_transaction(transaction).await.ok();
                        }
                    }
                    last_resend = Instant::now();
                }

                // Wait for the next block before checking for transaction statuses
                let mut block_height_refreshes = 10;
                progress.set_message_for_confirmed_transactions(
                    &progress_bar,
                    &format!("Waiting for next block, {num_transactions} transactions pending..."),
                );
                let mut new_block_height = progress.block_height;
                while progress.block_height == new_block_height && block_height_refreshes > 0 {
                    sleep(Duration::from_millis(500)).await;
                    new_block_height = self.rpc_client.get_block_height().await?;
                    block_height_refreshes -= 1;
                }
                progress.block_height = new_block_height;

                // Collect statuses for the transactions, drop those that are confirmed
                let pending_signatures = pending_transactions.keys().cloned().collect::<Vec<_>>();
                for pending_signatures_chunk in
                    pending_signatures.chunks(MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS)
                {
                    if let Ok(result) = self
                        .rpc_client
                        .get_signature_statuses(pending_signatures_chunk)
                        .await
                    {
                        let statuses = result.value;
                        for (signature, status) in
                            pending_signatures_chunk.iter().zip(statuses.into_iter())
                        {
                            if let Some(status) = status {
                                if status.satisfies_commitment(self.rpc_client.commitment()) {
                                    if let Some((i, _)) = pending_transactions.remove(signature) {
                                        progress.confirmed_transactions += 1;
                                        if status.err.is_some() {
                                            progress_bar
                                                .println(format!("Failed transaction: {status:?}"));
                                        }
                                        transaction_errors[i] = status.err;
                                    }
                                }
                            }
                        }
                    }
                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking transaction status...",
                    );
                }

                if pending_transactions.is_empty() {
                    return Ok(transaction_errors);
                }
            }

            transactions = pending_transactions.into_values().collect();
            progress_bar.println(format!(
                "Blockhash expired. {expired_blockhash_retries} retries remaining"
            ));
        }
        Err(TpuSenderError::Custom("Max retries exceeded".into()))
    }

    pub fn rpc_client(&self) -> &RpcClient {
        &self.rpc_client
    }

    pub async fn shutdown(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
        self.leader_tpu_service.join().await;
    }

    pub fn get_connection_cache(&self) -> &Arc<ConnectionCache<P, M, C>>
    where
        P: ConnectionPool<NewConnectionConfig = C>,
        M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
        C: NewConnectionConfig,
    {
        &self.connection_cache
    }

    pub fn get_leader_tpu_service(&self) -> &LeaderTpuService {
        &self.leader_tpu_service
    }

    pub fn get_fanout_slots(&self) -> u64 {
        self.fanout_slots
    }
}

impl<P, M, C> Drop for TpuClient<P, M, C> {
    fn drop(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
    }
}

/// Service that tracks upcoming leaders and maintains an up-to-date mapping
/// of leader id to TPU socket address.
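///
/// Standalone usage sketch (not a compiled doctest); the RPC/WebSocket endpoints are
/// placeholders and error handling is elided:
///
/// ```ignore
/// let exit = Arc::new(AtomicBool::new(false));
/// let mut service = LeaderTpuService::new(
///     rpc_client.clone(),
///     "ws://127.0.0.1:8900",
///     Protocol::QUIC,
///     exit.clone(),
/// )
/// .await?;
///
/// // TPU sockets of the current leader and the next few leaders, deduplicated.
/// let sockets = service.unique_leader_tpu_sockets(4);
///
/// exit.store(true, Ordering::Relaxed);
/// service.join().await;
/// ```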
pub struct LeaderTpuService {
    recent_slots: RecentLeaderSlots,
    leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
    t_leader_tpu_service: Option<JoinHandle<Result<()>>>,
}

impl LeaderTpuService {
    pub async fn new(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        protocol: Protocol,
        exit: Arc<AtomicBool>,
    ) -> Result<Self> {
        let epoch_schedule = rpc_client.get_epoch_schedule().await?;
        let start_slot = rpc_client
            .get_slot_with_commitment(CommitmentConfig::processed())
            .await?;

        let recent_slots = RecentLeaderSlots::new(start_slot);
        let epoch = epoch_schedule.get_epoch(start_slot);
        let slots_in_epoch = epoch_schedule.get_slots_in_epoch(epoch);
        let last_slot_in_epoch = epoch_schedule.get_last_slot_in_epoch(epoch);

        // When a cluster is starting, we observe an invalid slot range failure that goes away after a
        // retry. It seems as if the leader schedule is not available, but it should be. The logic
        // below retries the RPC call in case of an invalid slot range error.
        let tpu_leader_service_creation_timeout = Duration::from_secs(20);
        let retry_interval = Duration::from_secs(1);
        let leaders = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                // TODO: The root cause appears to lie within the `rpc_client.get_slot_leaders()`.
                // It might be worth debugging further and trying to understand why the RPC
                // call fails. There may be a bug in the `get_slot_leaders()` logic or in the
                // RPC implementation
                match rpc_client
                    .get_slot_leaders(start_slot, LeaderTpuCache::fanout(slots_in_epoch))
                    .await
                {
                    Ok(leaders) => return Ok(leaders),
                    Err(client_error) => {
                        if is_invalid_slot_range_error(&client_error) {
                            sleep(retry_interval).await;
                            continue;
                        } else {
                            return Err(client_error);
                        }
                    }
                }
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
                "Failed to get slot leaders connecting to: {websocket_url}, timeout: \
                 {tpu_leader_service_creation_timeout:?}. Invalid slot range"
            ))
        })??;

        let cluster_nodes = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                let cluster_nodes = rpc_client.get_cluster_nodes().await?;
                // Stop once we find at least one leader's contact info
                if cluster_nodes.iter().any(|rpc_contact_info| {
                    Pubkey::from_str(&rpc_contact_info.pubkey)
                        .map(|pubkey| leaders.contains(&pubkey))
                        .unwrap_or(false)
                }) {
                    return Ok::<_, ClientError>(cluster_nodes);
                }
                sleep(retry_interval).await;
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
800                "Failed find any cluster node info for upcoming leaders, timeout: \
801                 {tpu_leader_service_creation_timeout:?}."
            ))
        })??;
        let leader_tpu_cache = Arc::new(RwLock::new(LeaderTpuCache::new(
            start_slot,
            slots_in_epoch,
            last_slot_in_epoch,
            leaders,
            cluster_nodes,
            protocol,
        )));

        let pubsub_client = if !websocket_url.is_empty() {
            Some(PubsubClient::new(websocket_url).await?)
        } else {
            None
        };

        let t_leader_tpu_service = Some({
            let recent_slots = recent_slots.clone();
            let leader_tpu_cache = leader_tpu_cache.clone();
            tokio::spawn(Self::run(
                rpc_client,
                recent_slots,
                leader_tpu_cache,
                pubsub_client,
                exit,
            ))
        });

        Ok(LeaderTpuService {
            recent_slots,
            leader_tpu_cache,
            t_leader_tpu_service,
        })
    }

    pub async fn join(&mut self) {
        if let Some(t_handle) = self.t_leader_tpu_service.take() {
            t_handle.await.unwrap().unwrap();
        }
    }

    pub fn estimated_current_slot(&self) -> Slot {
        self.recent_slots.estimated_current_slot()
    }

    pub fn unique_leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_unique_leader_sockets(current_slot, fanout_slots)
    }

    pub fn leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_leader_sockets(current_slot, fanout_slots)
    }

    async fn run(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        tokio::try_join!(
            Self::run_slot_watcher(recent_slots.clone(), pubsub_client, exit.clone()),
            Self::run_cache_refresher(rpc_client, recent_slots, leader_tpu_cache, exit),
        )?;

        Ok(())
    }

    async fn run_cache_refresher(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let mut last_cluster_refresh = Instant::now();
        let mut sleep_ms = DEFAULT_MS_PER_SLOT;

        while !exit.load(Ordering::Relaxed) {
            // Sleep a slot before checking if leader cache needs to be refreshed again
            sleep(Duration::from_millis(sleep_ms)).await;
            sleep_ms = DEFAULT_MS_PER_SLOT;

            let cache_update_info = maybe_fetch_cache_info(
                &leader_tpu_cache,
                last_cluster_refresh,
                &rpc_client,
                &recent_slots,
            )
            .await;

            if cache_update_info.has_some() {
                let mut leader_tpu_cache = leader_tpu_cache.write().unwrap();
                let (has_error, cluster_refreshed) = leader_tpu_cache.update_all(cache_update_info);
                if has_error {
                    sleep_ms = 100;
                }
                if cluster_refreshed {
                    last_cluster_refresh = Instant::now();
                }
            }
        }

        Ok(())
    }

    async fn run_slot_watcher(
        recent_slots: RecentLeaderSlots,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let Some(pubsub_client) = pubsub_client else {
            return Ok(());
        };

        let (mut notifications, unsubscribe) = pubsub_client.slot_updates_subscribe().await?;
        // Time out slot update notification polling at 10ms.
        //
        // The rationale is two-fold:
        // 1. Notifications are an unbounded stream -- polling them blocks indefinitely unless
        //    interrupted, so the exit condition would never be checked. A 10ms timeout keeps
        //    notification handling timely with negligible CPU overhead.
        // 2. The timeout must be strictly less than the slot time (DEFAULT_MS_PER_SLOT: 400ms),
        //    otherwise it may never fire. For example, if notifications arrive every 400ms and
        //    the timeout is >= 400ms, a notification may always be available before the timeout
        //    is reached, and the exit condition would never be checked.
        const SLOT_UPDATE_TIMEOUT: Duration = Duration::from_millis(10);

        while !exit.load(Ordering::Relaxed) {
            while let Ok(Some(update)) = timeout(SLOT_UPDATE_TIMEOUT, notifications.next()).await {
                let current_slot = match update {
                    // This update indicates that a full slot was received by the connected
                    // node so we can stop sending transactions to the leader for that slot
                    SlotUpdate::Completed { slot, .. } => slot.saturating_add(1),
                    // This update indicates that we have just received the first shred from
                    // the leader for this slot and they are probably still accepting transactions.
                    SlotUpdate::FirstShredReceived { slot, .. } => slot,
                    _ => continue,
                };
                recent_slots.record_slot(current_slot);
            }
        }

        // `notifications` requires a valid reference to `pubsub_client`, so `notifications` must be
        // dropped before moving `pubsub_client` via `shutdown()`.
        drop(notifications);
        unsubscribe().await;
        pubsub_client.shutdown().await?;

        Ok(())
    }
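
    // Worked example for the slot watcher above (slot numbers illustrative): receiving
    // `SlotUpdate::Completed { slot: 100, .. }` records slot 101, since the leader for slot 100
    // has already produced a full block and new transactions should target the next leader;
    // receiving `SlotUpdate::FirstShredReceived { slot: 101, .. }` records slot 101 directly,
    // since that leader is still accepting transactions. `RecentLeaderSlots` then uses these
    // recorded slots to estimate the current slot for fanout targeting.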
}

async fn maybe_fetch_cache_info(
    leader_tpu_cache: &Arc<RwLock<LeaderTpuCache>>,
    last_cluster_refresh: Instant,
    rpc_client: &RpcClient,
    recent_slots: &RecentLeaderSlots,
) -> LeaderTpuCacheUpdateInfo {
    // Refresh cluster TPU ports every 5min in case validators restart with new port configuration
    // or new validators come online
    let maybe_cluster_nodes = if last_cluster_refresh.elapsed() > Duration::from_secs(5 * 60) {
        Some(rpc_client.get_cluster_nodes().await)
    } else {
        None
    };

    // Grab information about the slot leaders currently in the cache.
    let estimated_current_slot = recent_slots.estimated_current_slot();
    let (last_slot, last_slot_in_epoch, slots_in_epoch) = {
        let leader_tpu_cache = leader_tpu_cache.read().unwrap();
        leader_tpu_cache.slot_info()
    };

    // If we're crossing into a new epoch, fetch the updated epoch schedule.
    let maybe_epoch_schedule = if estimated_current_slot > last_slot_in_epoch {
        Some(rpc_client.get_epoch_schedule().await)
    } else {
        None
    };

    // If we are within the fanout range of the last slot in the cache, fetch
    // more slot leaders. We pull down a big batch at a time to amortize the
    // cost of the RPC call, and we fetch proactively so transactions are never
    // stalled waiting on this data.
    let maybe_slot_leaders = if estimated_current_slot >= last_slot.saturating_sub(MAX_FANOUT_SLOTS)
    {
        Some(
            rpc_client
                .get_slot_leaders(
                    estimated_current_slot,
                    LeaderTpuCache::fanout(slots_in_epoch),
                )
                .await,
        )
    } else {
        None
    };
    LeaderTpuCacheUpdateInfo {
        maybe_cluster_nodes,
        maybe_epoch_schedule,
        maybe_slot_leaders,
        first_slot: estimated_current_slot,
    }
}
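
// Worked example for `maybe_fetch_cache_info` (numbers illustrative, assuming MAX_FANOUT_SLOTS is
// 100 as defined in `crate::tpu_client`): with `last_slot` = 1_000 in the cache, slot leaders are
// re-fetched once the estimated current slot reaches 900; with `last_slot_in_epoch` = 431_999, the
// epoch schedule is re-fetched once the estimated slot passes that boundary; and cluster contact
// info is refreshed whenever the previous refresh is more than five minutes old.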

fn is_invalid_slot_range_error(client_error: &ClientError) -> bool {
    if let ErrorKind::RpcError(RpcError::RpcResponseError { code, message, .. }) =
        client_error.kind()
    {
        return *code == -32602
            && message.contains("Invalid slot range: leader schedule for epoch");
    }
    false
}