solana_quic_client/nonblocking/
quic_client.rs

1//! Simple nonblocking client that connects to a given UDP port with the QUIC protocol
2//! and provides an interface for sending data which is restricted by the
3//! server's flow control.
4use {
5    async_lock::Mutex,
6    async_trait::async_trait,
7    futures::future::TryFutureExt,
8    log::*,
9    quinn::{
10        crypto::rustls::QuicClientConfig, ClientConfig, ClosedStream, ConnectError, Connection,
11        ConnectionError, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig,
12        WriteError,
13    },
14    solana_connection_cache::{
15        client_connection::ClientStats, connection_cache_stats::ConnectionCacheStats,
16        nonblocking::client_connection::ClientConnection,
17    },
18    solana_keypair::Keypair,
19    solana_measure::measure::Measure,
20    solana_net_utils::sockets,
21    solana_quic_definitions::{
22        QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_KEEP_ALIVE, QUIC_MAX_TIMEOUT, QUIC_SEND_FAIRNESS,
23    },
24    solana_rpc_client_api::client_error::ErrorKind as ClientErrorKind,
25    solana_streamer::nonblocking::quic::ALPN_TPU_PROTOCOL_ID,
26    solana_tls_utils::{
27        new_dummy_x509_certificate, socket_addr_to_quic_server_name, tls_client_config_builder,
28        QuicClientCertificate,
29    },
30    solana_transaction_error::TransportResult,
31    std::{
32        net::{SocketAddr, UdpSocket},
33        sync::{atomic::Ordering, Arc},
34        thread,
35    },
36    thiserror::Error,
37    tokio::{sync::OnceCell, time::timeout},
38};
39
40/// A lazy-initialized Quic Endpoint
41pub struct QuicLazyInitializedEndpoint {
42    endpoint: OnceCell<Arc<Endpoint>>,
43    client_certificate: Arc<QuicClientCertificate>,
44    client_endpoint: Option<Endpoint>,
45}
46
47#[derive(Error, Debug)]
48pub enum QuicError {
49    #[error(transparent)]
50    WriteError(#[from] WriteError),
51    #[error(transparent)]
52    ConnectionError(#[from] ConnectionError),
53    #[error(transparent)]
54    ConnectError(#[from] ConnectError),
55    #[error(transparent)]
56    ClosedStream(#[from] ClosedStream),
57}
58
59impl From<QuicError> for ClientErrorKind {
60    fn from(quic_error: QuicError) -> Self {
61        Self::Custom(format!("{quic_error:?}"))
62    }
63}
64
65impl QuicLazyInitializedEndpoint {
66    pub fn new(
67        client_certificate: Arc<QuicClientCertificate>,
68        client_endpoint: Option<Endpoint>,
69    ) -> Self {
70        Self {
71            endpoint: OnceCell::<Arc<Endpoint>>::new(),
72            client_certificate,
73            client_endpoint,
74        }
75    }
76
77    fn create_endpoint(&self) -> Endpoint {
78        let mut endpoint = if let Some(endpoint) = &self.client_endpoint {
79            endpoint.clone()
80        } else {
81            // This will bind to random ports, but VALIDATOR_PORT_RANGE is outside
82            // of the range for CI tests when this is running in CI
83            let client_socket = sockets::bind_in_range_with_config(
84                std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
85                solana_net_utils::VALIDATOR_PORT_RANGE,
86                sockets::SocketConfiguration::default(),
87            )
88            .expect("QuicLazyInitializedEndpoint::create_endpoint bind_in_range")
89            .1;
90            info!("Local endpoint is : {client_socket:?}");
91
92            QuicNewConnection::create_endpoint(EndpointConfig::default(), client_socket)
93        };
94
95        let mut crypto = tls_client_config_builder()
96            .with_client_auth_cert(
97                vec![self.client_certificate.certificate.clone()],
98                self.client_certificate.key.clone_key(),
99            )
100            .expect("Failed to set QUIC client certificates");
101        crypto.enable_early_data = true;
102        crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()];
103
104        let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).unwrap()));
105        let mut transport_config = TransportConfig::default();
106
107        let timeout = IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap();
108        transport_config.max_idle_timeout(Some(timeout));
109        transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE));
110        transport_config.send_fairness(QUIC_SEND_FAIRNESS);
111        config.transport_config(Arc::new(transport_config));
112
113        endpoint.set_default_client_config(config);
114
115        endpoint
116    }
117
118    async fn get_endpoint(&self) -> Arc<Endpoint> {
119        self.endpoint
120            .get_or_init(|| async { Arc::new(self.create_endpoint()) })
121            .await
122            .clone()
123    }
124}
125
126impl Default for QuicLazyInitializedEndpoint {
127    fn default() -> Self {
128        let (cert, priv_key) = new_dummy_x509_certificate(&Keypair::new());
129        Self::new(
130            Arc::new(QuicClientCertificate {
131                certificate: cert,
132                key: priv_key,
133            }),
134            None,
135        )
136    }
137}
138
139/// A wrapper over NewConnection with additional capability to create the endpoint as part
140/// of creating a new connection.
141#[derive(Clone)]
142struct QuicNewConnection {
143    endpoint: Arc<Endpoint>,
144    connection: Arc<Connection>,
145}
146
147impl QuicNewConnection {
148    /// Create a QuicNewConnection given the remote address 'addr'.
149    async fn make_connection(
150        endpoint: Arc<QuicLazyInitializedEndpoint>,
151        addr: SocketAddr,
152        stats: &ClientStats,
153    ) -> Result<Self, QuicError> {
154        let mut make_connection_measure = Measure::start("make_connection_measure");
155        let endpoint = endpoint.get_endpoint().await;
156        let server_name = socket_addr_to_quic_server_name(addr);
157        let connecting = endpoint.connect(addr, &server_name)?;
158        stats.total_connections.fetch_add(1, Ordering::Relaxed);
159        if let Ok(connecting_result) = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await
160        {
161            if connecting_result.is_err() {
162                stats.connection_errors.fetch_add(1, Ordering::Relaxed);
163            }
164            make_connection_measure.stop();
165            stats
166                .make_connection_ms
167                .fetch_add(make_connection_measure.as_ms(), Ordering::Relaxed);
168
169            let connection = connecting_result?;
170
171            Ok(Self {
172                endpoint,
173                connection: Arc::new(connection),
174            })
175        } else {
176            Err(ConnectionError::TimedOut.into())
177        }
178    }
179
180    fn create_endpoint(config: EndpointConfig, client_socket: UdpSocket) -> Endpoint {
181        quinn::Endpoint::new(config, None, client_socket, Arc::new(TokioRuntime))
182            .expect("QuicNewConnection::create_endpoint quinn::Endpoint::new")
183    }
184
185    // Attempts to make a faster connection by taking advantage of pre-existing key material.
186    // Only works if connection to this endpoint was previously established.
187    async fn make_connection_0rtt(
188        &mut self,
189        addr: SocketAddr,
190        stats: &ClientStats,
191    ) -> Result<Arc<Connection>, QuicError> {
192        let server_name = socket_addr_to_quic_server_name(addr);
193        let connecting = self.endpoint.connect(addr, &server_name)?;
194        stats.total_connections.fetch_add(1, Ordering::Relaxed);
195        let connection = match connecting.into_0rtt() {
196            Ok((connection, zero_rtt)) => {
197                if let Ok(zero_rtt) = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, zero_rtt).await {
198                    if zero_rtt {
199                        stats.zero_rtt_accepts.fetch_add(1, Ordering::Relaxed);
200                    } else {
201                        stats.zero_rtt_rejects.fetch_add(1, Ordering::Relaxed);
202                    }
203                    connection
204                } else {
205                    return Err(ConnectionError::TimedOut.into());
206                }
207            }
208            Err(connecting) => {
209                stats.connection_errors.fetch_add(1, Ordering::Relaxed);
210
211                if let Ok(connecting_result) =
212                    timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await
213                {
214                    connecting_result?
215                } else {
216                    return Err(ConnectionError::TimedOut.into());
217                }
218            }
219        };
220        self.connection = Arc::new(connection);
221        Ok(self.connection.clone())
222    }
223}
224
225pub struct QuicClient {
226    endpoint: Arc<QuicLazyInitializedEndpoint>,
227    connection: Arc<Mutex<Option<QuicNewConnection>>>,
228    addr: SocketAddr,
229    stats: Arc<ClientStats>,
230}
231
/// Application error code sent when this client closes its connection.
const CONNECTION_CLOSE_CODE_APPLICATION_CLOSE: u32 = 0;
/// Reason phrase sent alongside `CONNECTION_CLOSE_CODE_APPLICATION_CLOSE`.
const CONNECTION_CLOSE_REASON_APPLICATION_CLOSE: &[u8] = b"dropped";
234
235impl QuicClient {
236    /// Explicitly close the connection. Must be called manually if cleanup is needed.
237    pub async fn close(&self) {
238        let mut conn_guard = self.connection.lock().await;
239        if let Some(conn) = conn_guard.take() {
240            debug!(
241                "Closing connection to {} connection_id: {:?}",
242                self.addr, conn.connection
243            );
244            conn.connection.close(
245                CONNECTION_CLOSE_CODE_APPLICATION_CLOSE.into(),
246                CONNECTION_CLOSE_REASON_APPLICATION_CLOSE,
247            );
248        }
249    }
250}
251
252impl QuicClient {
253    pub fn new(endpoint: Arc<QuicLazyInitializedEndpoint>, addr: SocketAddr) -> Self {
254        Self {
255            endpoint,
256            connection: Arc::new(Mutex::new(None)),
257            addr,
258            stats: Arc::new(ClientStats::default()),
259        }
260    }
261
262    async fn _send_buffer_using_conn(
263        data: &[u8],
264        connection: &Connection,
265    ) -> Result<(), QuicError> {
266        let mut send_stream = connection.open_uni().await?;
267        send_stream.write_all(data).await?;
268        Ok(())
269    }
270
271    // Attempts to send data, connecting/reconnecting as necessary
272    // On success, returns the connection used to successfully send the data
273    async fn _send_buffer(
274        &self,
275        data: &[u8],
276        stats: &ClientStats,
277        connection_stats: Arc<ConnectionCacheStats>,
278    ) -> Result<Arc<Connection>, QuicError> {
279        let mut measure_send_packet = Measure::start("send_packet_us");
280        let mut measure_prepare_connection = Measure::start("prepare_connection");
281        let mut connection_try_count = 0;
282        let mut last_connection_id = 0;
283        let mut last_error = None;
284        while connection_try_count < 2 {
285            let connection = {
286                let mut conn_guard = self.connection.lock().await;
287
288                let maybe_conn = conn_guard.as_mut();
289                match maybe_conn {
290                    Some(conn) => {
291                        if conn.connection.stable_id() == last_connection_id {
292                            // this is the problematic connection we had used before, create a new one
293                            let conn = conn.make_connection_0rtt(self.addr, stats).await;
294                            match conn {
295                                Ok(conn) => {
296                                    info!(
297                                        "Made 0rtt connection to {} with id {} try_count {}, \
298                                         last_connection_id: {}, last_error: {:?}",
299                                        self.addr,
300                                        conn.stable_id(),
301                                        connection_try_count,
302                                        last_connection_id,
303                                        last_error,
304                                    );
305                                    connection_try_count += 1;
306                                    conn
307                                }
308                                Err(err) => {
309                                    info!(
310                                        "Cannot make 0rtt connection to {}, error {:}",
311                                        self.addr, err
312                                    );
313                                    return Err(err);
314                                }
315                            }
316                        } else {
317                            stats.connection_reuse.fetch_add(1, Ordering::Relaxed);
318                            conn.connection.clone()
319                        }
320                    }
321                    None => {
322                        let conn = QuicNewConnection::make_connection(
323                            self.endpoint.clone(),
324                            self.addr,
325                            stats,
326                        )
327                        .await;
328                        match conn {
329                            Ok(conn) => {
330                                *conn_guard = Some(conn.clone());
331                                info!(
332                                    "Made connection to {} id {} try_count {}, from connection \
333                                     cache warming?: {}",
334                                    self.addr,
335                                    conn.connection.stable_id(),
336                                    connection_try_count,
337                                    data.is_empty(),
338                                );
339                                connection_try_count += 1;
340                                conn.connection.clone()
341                            }
342                            Err(err) => {
343                                info!(
344                                    "Cannot make connection to {}, error {:}, from connection \
345                                     cache warming?: {}",
346                                    self.addr,
347                                    err,
348                                    data.is_empty()
349                                );
350                                return Err(err);
351                            }
352                        }
353                    }
354                }
355            };
356
357            let new_stats = connection.stats();
358
359            connection_stats
360                .total_client_stats
361                .congestion_events
362                .update_stat(
363                    &self.stats.congestion_events,
364                    new_stats.path.congestion_events,
365                );
366
367            connection_stats
368                .total_client_stats
369                .streams_blocked_uni
370                .update_stat(
371                    &self.stats.streams_blocked_uni,
372                    new_stats.frame_tx.streams_blocked_uni,
373                );
374
375            connection_stats
376                .total_client_stats
377                .data_blocked
378                .update_stat(&self.stats.data_blocked, new_stats.frame_tx.data_blocked);
379
380            connection_stats
381                .total_client_stats
382                .acks
383                .update_stat(&self.stats.acks, new_stats.frame_tx.acks);
384
385            if data.is_empty() {
386                // no need to send packet as it is only for warming connections
387                return Ok(connection);
388            }
389
390            last_connection_id = connection.stable_id();
391            measure_prepare_connection.stop();
392
393            match Self::_send_buffer_using_conn(data, &connection).await {
394                Ok(()) => {
395                    measure_send_packet.stop();
396                    stats.successful_packets.fetch_add(1, Ordering::Relaxed);
397                    stats
398                        .send_packets_us
399                        .fetch_add(measure_send_packet.as_us(), Ordering::Relaxed);
400                    stats
401                        .prepare_connection_us
402                        .fetch_add(measure_prepare_connection.as_us(), Ordering::Relaxed);
403                    trace!(
404                        "Succcessfully sent to {} with id {}, thread: {:?}, data len: {}, \
405                         send_packet_us: {} prepare_connection_us: {}",
406                        self.addr,
407                        connection.stable_id(),
408                        thread::current().id(),
409                        data.len(),
410                        measure_send_packet.as_us(),
411                        measure_prepare_connection.as_us(),
412                    );
413
414                    return Ok(connection);
415                }
416                Err(err) => match err {
417                    QuicError::ConnectionError(_) => {
418                        last_error = Some(err);
419                    }
420                    _ => {
421                        info!(
422                            "Error sending to {} with id {}, error {:?} thread: {:?}",
423                            self.addr,
424                            connection.stable_id(),
425                            err,
426                            thread::current().id(),
427                        );
428                        return Err(err);
429                    }
430                },
431            }
432        }
433
434        // if we come here, that means we have exhausted maximum retries, return the error
435        info!(
436            "Ran into an error sending data {:?}, exhausted retries to {}",
437            last_error, self.addr
438        );
439        // If we get here but last_error is None, then we have a logic error
440        // in this function, so panic here with an expect to help debugging
441        Err(last_error.expect("QuicClient::_send_buffer last_error.expect"))
442    }
443
444    pub async fn send_buffer<T>(
445        &self,
446        data: T,
447        stats: &ClientStats,
448        connection_stats: Arc<ConnectionCacheStats>,
449    ) -> Result<(), ClientErrorKind>
450    where
451        T: AsRef<[u8]>,
452    {
453        self._send_buffer(data.as_ref(), stats, connection_stats)
454            .await
455            .map_err(Into::<ClientErrorKind>::into)?;
456        Ok(())
457    }
458
459    pub async fn send_batch<T>(
460        &self,
461        buffers: &[T],
462        stats: &ClientStats,
463        connection_stats: Arc<ConnectionCacheStats>,
464    ) -> Result<(), ClientErrorKind>
465    where
466        T: AsRef<[u8]>,
467    {
468        // Start off by "testing" the connection by sending the first buffer
469        // This will also connect to the server if not already connected
470        // and reconnect and retry if the first send attempt failed
471        // (for example due to a timed out connection), returning an error
472        // or the connection that was used to successfully send the buffer.
473        // We will use the returned connection to send the rest of the buffers in the batch
474        // to avoid touching the mutex in self, and not bother reconnecting if we fail along the way
475        // since testing even in the ideal GCE environment has found no cases
476        // where reconnecting and retrying in the middle of a batch send
477        // (i.e. we encounter a connection error in the middle of a batch send, which presumably cannot
478        // be due to a timed out connection) has succeeded
479        if buffers.is_empty() {
480            return Ok(());
481        }
482        let connection = self
483            ._send_buffer(buffers[0].as_ref(), stats, connection_stats)
484            .await
485            .map_err(Into::<ClientErrorKind>::into)?;
486
487        for data in buffers[1..buffers.len()].iter() {
488            Self::_send_buffer_using_conn(data.as_ref(), &connection).await?;
489        }
490        Ok(())
491    }
492
493    pub fn server_addr(&self) -> &SocketAddr {
494        &self.addr
495    }
496
497    pub fn stats(&self) -> Arc<ClientStats> {
498        self.stats.clone()
499    }
500}
501
502pub struct QuicClientConnection {
503    pub client: Arc<QuicClient>,
504    pub connection_stats: Arc<ConnectionCacheStats>,
505}
506
507impl QuicClientConnection {
508    pub fn base_stats(&self) -> Arc<ClientStats> {
509        self.client.stats()
510    }
511
512    pub fn connection_stats(&self) -> Arc<ConnectionCacheStats> {
513        self.connection_stats.clone()
514    }
515
516    pub fn new(
517        endpoint: Arc<QuicLazyInitializedEndpoint>,
518        addr: SocketAddr,
519        connection_stats: Arc<ConnectionCacheStats>,
520    ) -> Self {
521        let client = Arc::new(QuicClient::new(endpoint, addr));
522        Self::new_with_client(client, connection_stats)
523    }
524
525    pub fn new_with_client(
526        client: Arc<QuicClient>,
527        connection_stats: Arc<ConnectionCacheStats>,
528    ) -> Self {
529        Self {
530            client,
531            connection_stats,
532        }
533    }
534}
535
536#[async_trait]
537impl ClientConnection for QuicClientConnection {
538    fn server_addr(&self) -> &SocketAddr {
539        self.client.server_addr()
540    }
541
542    async fn send_data_batch(&self, buffers: &[Vec<u8>]) -> TransportResult<()> {
543        let stats = ClientStats::default();
544        let len = buffers.len();
545        let res = self
546            .client
547            .send_batch(buffers, &stats, self.connection_stats.clone())
548            .await;
549        self.connection_stats
550            .add_client_stats(&stats, len, res.is_ok());
551        res?;
552        Ok(())
553    }
554
555    async fn send_data(&self, data: &[u8]) -> TransportResult<()> {
556        let stats = Arc::new(ClientStats::default());
557        // When data is empty which is from cache warmer, we are not sending packets actually, do not count it in
558        let num_packets = if data.is_empty() { 0 } else { 1 };
559        self.client
560            .send_buffer(data, &stats, self.connection_stats.clone())
561            .map_ok(|v| {
562                self.connection_stats
563                    .add_client_stats(&stats, num_packets, true);
564                v
565            })
566            .map_err(|e| {
567                warn!(
568                    "Failed to send data async to {}, error: {:?} ",
569                    self.server_addr(),
570                    e
571                );
572                datapoint_warn!("send-wire-async", ("failure", 1, i64),);
573                self.connection_stats
574                    .add_client_stats(&stats, num_packets, false);
575                e.into()
576            })
577            .await
578    }
579}