tonic/transport/channel/
endpoint.rs

1#[cfg(feature = "_tls-any")]
2use super::service::TlsConnector;
3use super::service::{self, Executor, SharedExec};
4use super::uds_connector::UdsConnector;
5use super::Channel;
6#[cfg(feature = "_tls-any")]
7use super::ClientTlsConfig;
8#[cfg(feature = "_tls-any")]
9use crate::transport::error;
10use crate::transport::Error;
11use bytes::Bytes;
12use http::{uri::Uri, HeaderValue};
13use hyper::rt;
14use hyper_util::client::legacy::connect::HttpConnector;
15use std::{fmt, future::Future, net::IpAddr, pin::Pin, str, str::FromStr, time::Duration};
16use tower_service::Service;
17
18#[derive(Clone, PartialEq, Eq, Hash)]
19pub(crate) enum EndpointType {
20    Uri(Uri),
21    Uds(String),
22}
23
24/// Channel builder.
25///
26/// This struct is used to build and configure HTTP/2 channels.
27#[derive(Clone)]
28pub struct Endpoint {
29    pub(crate) uri: EndpointType,
30    fallback_uri: Uri,
31    pub(crate) origin: Option<Uri>,
32    pub(crate) user_agent: Option<HeaderValue>,
33    pub(crate) timeout: Option<Duration>,
34    pub(crate) concurrency_limit: Option<usize>,
35    pub(crate) rate_limit: Option<(u64, Duration)>,
36    #[cfg(feature = "_tls-any")]
37    pub(crate) tls: Option<TlsConnector>,
38    pub(crate) buffer_size: Option<usize>,
39    pub(crate) init_stream_window_size: Option<u32>,
40    pub(crate) init_connection_window_size: Option<u32>,
41    pub(crate) tcp_keepalive: Option<Duration>,
42    pub(crate) tcp_nodelay: bool,
43    pub(crate) http2_keep_alive_interval: Option<Duration>,
44    pub(crate) http2_keep_alive_timeout: Option<Duration>,
45    pub(crate) http2_keep_alive_while_idle: Option<bool>,
46    pub(crate) http2_max_header_list_size: Option<u32>,
47    pub(crate) connect_timeout: Option<Duration>,
48    pub(crate) http2_adaptive_window: Option<bool>,
49    pub(crate) local_address: Option<IpAddr>,
50    pub(crate) executor: SharedExec,
51}
52
53impl Endpoint {
54    // FIXME: determine if we want to expose this or not. This is really
55    // just used in codegen for a shortcut.
56    #[doc(hidden)]
57    pub fn new<D>(dst: D) -> Result<Self, Error>
58    where
59        D: TryInto<Self>,
60        D::Error: Into<crate::BoxError>,
61    {
62        let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?;
63        #[cfg(feature = "_tls-any")]
64        if let EndpointType::Uri(uri) = &me.uri {
65            if me.tls.is_none() && uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
66                return me.tls_config(ClientTlsConfig::new().with_enabled_roots());
67            }
68        }
69        Ok(me)
70    }
71
72    fn new_uri(uri: Uri) -> Self {
73        Self {
74            uri: EndpointType::Uri(uri.clone()),
75            fallback_uri: uri,
76            origin: None,
77            user_agent: None,
78            concurrency_limit: None,
79            rate_limit: None,
80            timeout: None,
81            #[cfg(feature = "_tls-any")]
82            tls: None,
83            buffer_size: None,
84            init_stream_window_size: None,
85            init_connection_window_size: None,
86            tcp_keepalive: None,
87            tcp_nodelay: true,
88            http2_keep_alive_interval: None,
89            http2_keep_alive_timeout: None,
90            http2_keep_alive_while_idle: None,
91            http2_max_header_list_size: None,
92            connect_timeout: None,
93            http2_adaptive_window: None,
94            executor: SharedExec::tokio(),
95            local_address: None,
96        }
97    }
98
99    fn new_uds(uds_filepath: &str) -> Self {
100        Self {
101            uri: EndpointType::Uds(uds_filepath.to_string()),
102            fallback_uri: Uri::from_static("http://tonic"),
103            origin: None,
104            user_agent: None,
105            concurrency_limit: None,
106            rate_limit: None,
107            timeout: None,
108            #[cfg(feature = "_tls-any")]
109            tls: None,
110            buffer_size: None,
111            init_stream_window_size: None,
112            init_connection_window_size: None,
113            tcp_keepalive: None,
114            tcp_nodelay: true,
115            http2_keep_alive_interval: None,
116            http2_keep_alive_timeout: None,
117            http2_keep_alive_while_idle: None,
118            http2_max_header_list_size: None,
119            connect_timeout: None,
120            http2_adaptive_window: None,
121            executor: SharedExec::tokio(),
122            local_address: None,
123        }
124    }
125
126    /// Convert an `Endpoint` from a static string.
127    ///
128    /// # Panics
129    ///
130    /// This function panics if the argument is an invalid URI.
131    ///
132    /// ```
133    /// # use tonic::transport::Endpoint;
134    /// Endpoint::from_static("https://example.com");
135    /// ```
136    pub fn from_static(s: &'static str) -> Self {
137        if s.starts_with("unix:") {
138            let uds_filepath = s
139                .strip_prefix("unix://")
140                .or_else(|| s.strip_prefix("unix:"))
141                .expect("Invalid unix domain socket URI");
142            Self::new_uds(uds_filepath)
143        } else {
144            let uri = Uri::from_static(s);
145            Self::new_uri(uri)
146        }
147    }
148
149    /// Convert an `Endpoint` from shared bytes.
150    ///
151    /// ```
152    /// # use tonic::transport::Endpoint;
153    /// Endpoint::from_shared("https://example.com".to_string());
154    /// ```
155    pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, Error> {
156        let s = str::from_utf8(&s.into())
157            .map_err(|e| Error::new_invalid_uri().with(e))?
158            .to_string();
159        if s.starts_with("unix:") {
160            let uds_filepath = s
161                .strip_prefix("unix://")
162                .or_else(|| s.strip_prefix("unix:"))
163                .ok_or(Error::new_invalid_uri())?;
164            Ok(Self::new_uds(uds_filepath))
165        } else {
166            let uri = Uri::from_maybe_shared(s).map_err(|e| Error::new_invalid_uri().with(e))?;
167            Ok(Self::from(uri))
168        }
169    }
170
171    /// Set a custom user-agent header.
172    ///
173    /// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
174    /// It must be a value that can be converted into a valid  `http::HeaderValue` or building
175    /// the endpoint will fail.
176    /// ```
177    /// # use tonic::transport::Endpoint;
178    /// # let mut builder = Endpoint::from_static("https://example.com");
179    /// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
180    /// // user-agent: "Greeter tonic/x.x.x"
181    /// ```
182    pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
183    where
184        T: TryInto<HeaderValue>,
185    {
186        user_agent
187            .try_into()
188            .map(|ua| Endpoint {
189                user_agent: Some(ua),
190                ..self
191            })
192            .map_err(|_| Error::new_invalid_user_agent())
193    }
194
195    /// Set a custom origin.
196    ///
197    /// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
198    /// which serves multiple services at the same time.
199    /// It will play the role of SNI (Server Name Indication).
200    ///
201    /// ```
202    /// # use tonic::transport::Endpoint;
203    /// # let mut builder = Endpoint::from_static("https://proxy.com");
204    /// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
205    /// // origin: "https://example.com"
206    /// ```
207    pub fn origin(self, origin: Uri) -> Self {
208        Endpoint {
209            origin: Some(origin),
210            ..self
211        }
212    }
213
214    /// Apply a timeout to each request.
215    ///
216    /// ```
217    /// # use tonic::transport::Endpoint;
218    /// # use std::time::Duration;
219    /// # let mut builder = Endpoint::from_static("https://example.com");
220    /// builder.timeout(Duration::from_secs(5));
221    /// ```
222    ///
223    /// # Notes
224    ///
225    /// This does **not** set the timeout metadata (`grpc-timeout` header) on
226    /// the request, meaning the server will not be informed of this timeout,
227    /// for that use [`Request::set_timeout`].
228    ///
229    /// [`Request::set_timeout`]: crate::Request::set_timeout
230    pub fn timeout(self, dur: Duration) -> Self {
231        Endpoint {
232            timeout: Some(dur),
233            ..self
234        }
235    }
236
237    /// Apply a timeout to connecting to the uri.
238    ///
239    /// Defaults to no timeout.
240    ///
241    /// ```
242    /// # use tonic::transport::Endpoint;
243    /// # use std::time::Duration;
244    /// # let mut builder = Endpoint::from_static("https://example.com");
245    /// builder.connect_timeout(Duration::from_secs(5));
246    /// ```
247    pub fn connect_timeout(self, dur: Duration) -> Self {
248        Endpoint {
249            connect_timeout: Some(dur),
250            ..self
251        }
252    }
253
254    /// Set whether TCP keepalive messages are enabled on accepted connections.
255    ///
256    /// If `None` is specified, keepalive is disabled, otherwise the duration
257    /// specified will be the time to remain idle before sending TCP keepalive
258    /// probes.
259    ///
260    /// Default is no keepalive (`None`)
261    ///
262    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
263        Endpoint {
264            tcp_keepalive,
265            ..self
266        }
267    }
268
269    /// Apply a concurrency limit to each request.
270    ///
271    /// ```
272    /// # use tonic::transport::Endpoint;
273    /// # let mut builder = Endpoint::from_static("https://example.com");
274    /// builder.concurrency_limit(256);
275    /// ```
276    pub fn concurrency_limit(self, limit: usize) -> Self {
277        Endpoint {
278            concurrency_limit: Some(limit),
279            ..self
280        }
281    }
282
283    /// Apply a rate limit to each request.
284    ///
285    /// ```
286    /// # use tonic::transport::Endpoint;
287    /// # use std::time::Duration;
288    /// # let mut builder = Endpoint::from_static("https://example.com");
289    /// builder.rate_limit(32, Duration::from_secs(1));
290    /// ```
291    pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
292        Endpoint {
293            rate_limit: Some((limit, duration)),
294            ..self
295        }
296    }
297
298    /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
299    /// stream-level flow control.
300    ///
301    /// Default is 65,535
302    ///
303    /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize
304    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
305        Endpoint {
306            init_stream_window_size: sz.into(),
307            ..self
308        }
309    }
310
311    /// Sets the max connection-level flow control for HTTP2
312    ///
313    /// Default is 65,535
314    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
315        Endpoint {
316            init_connection_window_size: sz.into(),
317            ..self
318        }
319    }
320
321    /// Sets the tower service default internal buffer size
322    ///
323    /// Default is 1024
324    pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
325        Endpoint {
326            buffer_size: sz.into(),
327            ..self
328        }
329    }
330
331    /// Configures TLS for the endpoint.
332    #[cfg(feature = "_tls-any")]
333    pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, Error> {
334        match &self.uri {
335            EndpointType::Uri(uri) => Ok(Endpoint {
336                tls: Some(
337                    tls_config
338                        .into_tls_connector(uri)
339                        .map_err(Error::from_source)?,
340                ),
341                ..self
342            }),
343            EndpointType::Uds(_) => Err(Error::new(error::Kind::InvalidTlsConfigForUds)),
344        }
345    }
346
347    /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
348    pub fn tcp_nodelay(self, enabled: bool) -> Self {
349        Endpoint {
350            tcp_nodelay: enabled,
351            ..self
352        }
353    }
354
355    /// Set http2 KEEP_ALIVE_INTERVAL. Uses `hyper`'s default otherwise.
356    pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
357        Endpoint {
358            http2_keep_alive_interval: Some(interval),
359            ..self
360        }
361    }
362
363    /// Set http2 KEEP_ALIVE_TIMEOUT. Uses `hyper`'s default otherwise.
364    pub fn keep_alive_timeout(self, duration: Duration) -> Self {
365        Endpoint {
366            http2_keep_alive_timeout: Some(duration),
367            ..self
368        }
369    }
370
371    /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses `hyper`'s default otherwise.
372    pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
373        Endpoint {
374            http2_keep_alive_while_idle: Some(enabled),
375            ..self
376        }
377    }
378
379    /// Sets whether to use an adaptive flow control. Uses `hyper`'s default otherwise.
380    pub fn http2_adaptive_window(self, enabled: bool) -> Self {
381        Endpoint {
382            http2_adaptive_window: Some(enabled),
383            ..self
384        }
385    }
386
387    /// Sets the max size of received header frames.
388    ///
389    /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB.
390    pub fn http2_max_header_list_size(self, size: u32) -> Self {
391        Endpoint {
392            http2_max_header_list_size: Some(size),
393            ..self
394        }
395    }
396
397    /// Sets the executor used to spawn async tasks.
398    ///
399    /// Uses `tokio::spawn` by default.
400    pub fn executor<E>(mut self, executor: E) -> Self
401    where
402        E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
403    {
404        self.executor = SharedExec::new(executor);
405        self
406    }
407
408    pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
409        service::Connector::new(
410            c,
411            #[cfg(feature = "_tls-any")]
412            self.tls.clone(),
413        )
414    }
415
416    /// Set the local address.
417    ///
418    /// This sets the IP address the client will use. By default we let hyper select the IP address.
419    pub fn local_address(self, addr: Option<IpAddr>) -> Self {
420        Endpoint {
421            local_address: addr,
422            ..self
423        }
424    }
425
426    pub(crate) fn http_connector(&self) -> service::Connector<HttpConnector> {
427        let mut http = HttpConnector::new();
428        http.enforce_http(false);
429        http.set_nodelay(self.tcp_nodelay);
430        http.set_keepalive(self.tcp_keepalive);
431        http.set_connect_timeout(self.connect_timeout);
432        http.set_local_address(self.local_address);
433        self.connector(http)
434    }
435
436    pub(crate) fn uds_connector(&self, uds_filepath: &str) -> service::Connector<UdsConnector> {
437        self.connector(UdsConnector::new(uds_filepath))
438    }
439
440    /// Create a channel from this config.
441    pub async fn connect(&self) -> Result<Channel, Error> {
442        match &self.uri {
443            EndpointType::Uri(_) => Channel::connect(self.http_connector(), self.clone()).await,
444            EndpointType::Uds(uds_filepath) => {
445                Channel::connect(self.uds_connector(uds_filepath.as_str()), self.clone()).await
446            }
447        }
448    }
449
450    /// Create a channel from this config.
451    ///
452    /// The channel returned by this method does not attempt to connect to the endpoint until first
453    /// use.
454    pub fn connect_lazy(&self) -> Channel {
455        match &self.uri {
456            EndpointType::Uri(_) => Channel::new(self.http_connector(), self.clone()),
457            EndpointType::Uds(uds_filepath) => {
458                Channel::new(self.uds_connector(uds_filepath.as_str()), self.clone())
459            }
460        }
461    }
462
463    /// Connect with a custom connector.
464    ///
465    /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport.
466    /// See the `uds` example for an example on how to use this function to build channel that
467    /// uses a Unix socket transport.
468    ///
469    /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied.
470    pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, Error>
471    where
472        C: Service<Uri> + Send + 'static,
473        C::Response: rt::Read + rt::Write + Send + Unpin,
474        C::Future: Send,
475        crate::BoxError: From<C::Error> + Send,
476    {
477        let connector = self.connector(connector);
478
479        if let Some(connect_timeout) = self.connect_timeout {
480            let mut connector = hyper_timeout::TimeoutConnector::new(connector);
481            connector.set_connect_timeout(Some(connect_timeout));
482            Channel::connect(connector, self.clone()).await
483        } else {
484            Channel::connect(connector, self.clone()).await
485        }
486    }
487
488    /// Connect with a custom connector lazily.
489    ///
490    /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport
491    /// connect to it lazily.
492    ///
493    /// See the `uds` example for an example on how to use this function to build channel that
494    /// uses a Unix socket transport.
495    pub fn connect_with_connector_lazy<C>(&self, connector: C) -> Channel
496    where
497        C: Service<Uri> + Send + 'static,
498        C::Response: rt::Read + rt::Write + Send + Unpin,
499        C::Future: Send,
500        crate::BoxError: From<C::Error> + Send,
501    {
502        let connector = self.connector(connector);
503        if let Some(connect_timeout) = self.connect_timeout {
504            let mut connector = hyper_timeout::TimeoutConnector::new(connector);
505            connector.set_connect_timeout(Some(connect_timeout));
506            Channel::new(connector, self.clone())
507        } else {
508            Channel::new(connector, self.clone())
509        }
510    }
511
512    /// Get the endpoint uri.
513    ///
514    /// ```
515    /// # use tonic::transport::Endpoint;
516    /// # use http::Uri;
517    /// let endpoint = Endpoint::from_static("https://example.com");
518    ///
519    /// assert_eq!(endpoint.uri(), &Uri::from_static("https://example.com"));
520    /// ```
521    pub fn uri(&self) -> &Uri {
522        match &self.uri {
523            EndpointType::Uri(uri) => uri,
524            EndpointType::Uds(_) => &self.fallback_uri,
525        }
526    }
527
528    /// Get the value of `TCP_NODELAY` option for accepted connections.
529    pub fn get_tcp_nodelay(&self) -> bool {
530        self.tcp_nodelay
531    }
532
533    /// Get the connect timeout.
534    pub fn get_connect_timeout(&self) -> Option<Duration> {
535        self.connect_timeout
536    }
537
538    /// Get whether TCP keepalive messages are enabled on accepted connections.
539    ///
540    /// If `None` is specified, keepalive is disabled, otherwise the duration
541    /// specified will be the time to remain idle before sending TCP keepalive
542    /// probes.
543    pub fn get_tcp_keepalive(&self) -> Option<Duration> {
544        self.tcp_keepalive
545    }
546}
547
548impl From<Uri> for Endpoint {
549    fn from(uri: Uri) -> Self {
550        Self::new_uri(uri)
551    }
552}
553
554impl TryFrom<Bytes> for Endpoint {
555    type Error = Error;
556
557    fn try_from(t: Bytes) -> Result<Self, Self::Error> {
558        Self::from_shared(t)
559    }
560}
561
562impl TryFrom<String> for Endpoint {
563    type Error = Error;
564
565    fn try_from(t: String) -> Result<Self, Self::Error> {
566        Self::from_shared(t.into_bytes())
567    }
568}
569
570impl TryFrom<&'static str> for Endpoint {
571    type Error = Error;
572
573    fn try_from(t: &'static str) -> Result<Self, Self::Error> {
574        Self::from_shared(t.as_bytes())
575    }
576}
577
578impl fmt::Debug for Endpoint {
579    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
580        f.debug_struct("Endpoint").finish()
581    }
582}
583
584impl FromStr for Endpoint {
585    type Err = Error;
586
587    fn from_str(s: &str) -> Result<Self, Self::Err> {
588        Self::try_from(s.to_string())
589    }
590}