tonic_rustls/channel/
endpoint.rs

1#[cfg(feature = "tls")]
2use super::service::TlsConnector;
3use super::service::{self, Executor, SharedExec};
4use super::Channel;
5use crate::Error;
6use bytes::Bytes;
7use http::{uri::Uri, HeaderValue};
8use hyper::rt;
9use hyper_util::client::legacy::connect::HttpConnector;
10use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration};
11use tower_service::Service;
12
13/// Channel builder.
14///
15/// This struct is used to build and configure HTTP/2 channels.
16#[derive(Clone)]
17pub struct Endpoint {
18    pub(crate) uri: Uri,
19    pub(crate) origin: Option<Uri>,
20    pub(crate) user_agent: Option<HeaderValue>,
21    pub(crate) timeout: Option<Duration>,
22    pub(crate) concurrency_limit: Option<usize>,
23    pub(crate) rate_limit: Option<(u64, Duration)>,
24    #[cfg(feature = "tls")]
25    pub(crate) tls: Option<TlsConnector>,
26    pub(crate) buffer_size: Option<usize>,
27    pub(crate) init_stream_window_size: Option<u32>,
28    pub(crate) init_connection_window_size: Option<u32>,
29    pub(crate) tcp_keepalive: Option<Duration>,
30    pub(crate) tcp_nodelay: bool,
31    pub(crate) http2_keep_alive_interval: Option<Duration>,
32    pub(crate) http2_keep_alive_timeout: Option<Duration>,
33    pub(crate) http2_keep_alive_while_idle: Option<bool>,
34    pub(crate) http2_max_header_list_size: Option<u32>,
35    pub(crate) connect_timeout: Option<Duration>,
36    pub(crate) http2_adaptive_window: Option<bool>,
37    pub(crate) executor: SharedExec,
38}
39
40impl Endpoint {
41    // FIXME: determine if we want to expose this or not. This is really
42    // just used in codegen for a shortcut.
43    #[doc(hidden)]
44    pub fn new<D>(dst: D) -> Result<Self, Error>
45    where
46        D: TryInto<Self>,
47        D::Error: Into<crate::BoxError>,
48    {
49        let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?;
50        //TODO Should we provide a default or force users to provide a TLS config?
51        // #[cfg(feature = "tls")]
52        // if me.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
53        //     return me.tls_config(ClientTlsConfig::new().with_enabled_roots());
54        // }
55
56        Ok(me)
57    }
58
59    /// Convert an `Endpoint` from a static string.
60    ///
61    /// # Panics
62    ///
63    /// This function panics if the argument is an invalid URI.
64    ///
65    /// ```
66    /// # use tonic_rustls::Endpoint;
67    /// Endpoint::from_static("https://example.com");
68    /// ```
69    pub fn from_static(s: &'static str) -> Self {
70        let uri = Uri::from_static(s);
71        Self::from(uri)
72    }
73
74    /// Convert an `Endpoint` from shared bytes.
75    ///
76    /// ```
77    /// # use tonic_rustls::Endpoint;
78    /// Endpoint::from_shared("https://example.com".to_string());
79    /// ```
80    pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, Error> {
81        let uri = Uri::from_maybe_shared(s.into()).map_err(|e| Error::new_invalid_uri().with(e))?;
82        Ok(Self::from(uri))
83    }
84
85    /// Set a custom user-agent header.
86    ///
87    /// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
88    /// It must be a value that can be converted into a valid  `http::HeaderValue` or building
89    /// the endpoint will fail.
90    /// ```
91    /// # use tonic_rustls::Endpoint;
92    /// # let mut builder = Endpoint::from_static("https://example.com");
93    /// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
94    /// // user-agent: "Greeter tonic/x.x.x"
95    /// ```
96    pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
97    where
98        T: TryInto<HeaderValue>,
99    {
100        user_agent
101            .try_into()
102            .map(|ua| Endpoint {
103                user_agent: Some(ua),
104                ..self
105            })
106            .map_err(|_| Error::new_invalid_user_agent())
107    }
108
109    /// Set a custom origin.
110    ///
111    /// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
112    /// which serves multiple services at the same time.
113    /// It will play the role of SNI (Server Name Indication).
114    ///
115    /// ```
116    /// # use tonic_rustls::Endpoint;
117    /// # let mut builder = Endpoint::from_static("https://proxy.com");
118    /// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
119    /// // origin: "https://example.com"
120    /// ```
121    pub fn origin(self, origin: Uri) -> Self {
122        Endpoint {
123            origin: Some(origin),
124            ..self
125        }
126    }
127
128    /// Apply a timeout to each request.
129    ///
130    /// ```
131    /// # use tonic_rustls::Endpoint;
132    /// # use std::time::Duration;
133    /// # let mut builder = Endpoint::from_static("https://example.com");
134    /// builder.timeout(Duration::from_secs(5));
135    /// ```
136    ///
137    /// # Notes
138    ///
139    /// This does **not** set the timeout metadata (`grpc-timeout` header) on
140    /// the request, meaning the server will not be informed of this timeout,
141    /// for that use [`Request::set_timeout`].
142    ///
143    /// [`Request::set_timeout`]: crate::Request::set_timeout
144    pub fn timeout(self, dur: Duration) -> Self {
145        Endpoint {
146            timeout: Some(dur),
147            ..self
148        }
149    }
150
151    /// Apply a timeout to connecting to the uri.
152    ///
153    /// Defaults to no timeout.
154    ///
155    /// ```
156    /// # use tonic_rustls::Endpoint;
157    /// # use std::time::Duration;
158    /// # let mut builder = Endpoint::from_static("https://example.com");
159    /// builder.connect_timeout(Duration::from_secs(5));
160    /// ```
161    pub fn connect_timeout(self, dur: Duration) -> Self {
162        Endpoint {
163            connect_timeout: Some(dur),
164            ..self
165        }
166    }
167
168    /// Set whether TCP keepalive messages are enabled on accepted connections.
169    ///
170    /// If `None` is specified, keepalive is disabled, otherwise the duration
171    /// specified will be the time to remain idle before sending TCP keepalive
172    /// probes.
173    ///
174    /// Default is no keepalive (`None`)
175    ///
176    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
177        Endpoint {
178            tcp_keepalive,
179            ..self
180        }
181    }
182
183    /// Apply a concurrency limit to each request.
184    ///
185    /// ```
186    /// # use tonic_rustls::Endpoint;
187    /// # let mut builder = Endpoint::from_static("https://example.com");
188    /// builder.concurrency_limit(256);
189    /// ```
190    pub fn concurrency_limit(self, limit: usize) -> Self {
191        Endpoint {
192            concurrency_limit: Some(limit),
193            ..self
194        }
195    }
196
197    /// Apply a rate limit to each request.
198    ///
199    /// ```
200    /// # use tonic_rustls::Endpoint;
201    /// # use std::time::Duration;
202    /// # let mut builder = Endpoint::from_static("https://example.com");
203    /// builder.rate_limit(32, Duration::from_secs(1));
204    /// ```
205    pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
206        Endpoint {
207            rate_limit: Some((limit, duration)),
208            ..self
209        }
210    }
211
212    /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
213    /// stream-level flow control.
214    ///
215    /// Default is 65,535
216    ///
217    /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize
218    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
219        Endpoint {
220            init_stream_window_size: sz.into(),
221            ..self
222        }
223    }
224
225    /// Sets the max connection-level flow control for HTTP2
226    ///
227    /// Default is 65,535
228    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
229        Endpoint {
230            init_connection_window_size: sz.into(),
231            ..self
232        }
233    }
234
235    /// Sets the tower service default internal buffer size
236    ///
237    /// Default is 1024
238    pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
239        Endpoint {
240            buffer_size: sz.into(),
241            ..self
242        }
243    }
244
245    /// Configures TLS for the endpoint.
246    #[cfg(feature = "tls")]
247    pub fn tls_config(self, tls_config: tokio_rustls::rustls::ClientConfig) -> Result<Self, Error> {
248        let domain = self.uri.host().ok_or_else(Error::new_invalid_uri)?;
249        let tls_connector =
250            TlsConnector::new(tls_config, domain, false).map_err(Error::from_source)?;
251
252        Ok(Endpoint {
253            tls: Some(tls_connector),
254            ..self
255        })
256    }
257
258    /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
259    pub fn tcp_nodelay(self, enabled: bool) -> Self {
260        Endpoint {
261            tcp_nodelay: enabled,
262            ..self
263        }
264    }
265
266    /// Set http2 KEEP_ALIVE_INTERVAL. Uses `hyper`'s default otherwise.
267    pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
268        Endpoint {
269            http2_keep_alive_interval: Some(interval),
270            ..self
271        }
272    }
273
274    /// Set http2 KEEP_ALIVE_TIMEOUT. Uses `hyper`'s default otherwise.
275    pub fn keep_alive_timeout(self, duration: Duration) -> Self {
276        Endpoint {
277            http2_keep_alive_timeout: Some(duration),
278            ..self
279        }
280    }
281
282    /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses `hyper`'s default otherwise.
283    pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
284        Endpoint {
285            http2_keep_alive_while_idle: Some(enabled),
286            ..self
287        }
288    }
289
290    /// Sets whether to use an adaptive flow control. Uses `hyper`'s default otherwise.
291    pub fn http2_adaptive_window(self, enabled: bool) -> Self {
292        Endpoint {
293            http2_adaptive_window: Some(enabled),
294            ..self
295        }
296    }
297
298    /// Sets the max size of received header frames.
299    ///
300    /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB.
301    pub fn http2_max_header_list_size(self, size: u32) -> Self {
302        Endpoint {
303            http2_max_header_list_size: Some(size),
304            ..self
305        }
306    }
307
308    /// Sets the executor used to spawn async tasks.
309    ///
310    /// Uses `tokio::spawn` by default.
311    pub fn executor<E>(mut self, executor: E) -> Self
312    where
313        E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
314    {
315        self.executor = SharedExec::new(executor);
316        self
317    }
318
319    pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
320        service::Connector::new(
321            c,
322            #[cfg(feature = "tls")]
323            self.tls.clone(),
324        )
325    }
326
327    /// Create a channel from this config.
328    pub async fn connect(&self) -> Result<Channel, Error> {
329        let mut http = HttpConnector::new();
330        http.enforce_http(false);
331        http.set_nodelay(self.tcp_nodelay);
332        http.set_keepalive(self.tcp_keepalive);
333        http.set_connect_timeout(self.connect_timeout);
334
335        let connector = self.connector(http);
336
337        Channel::connect(connector, self.clone()).await
338    }
339
340    /// Create a channel from this config.
341    ///
342    /// The channel returned by this method does not attempt to connect to the endpoint until first
343    /// use.
344    pub fn connect_lazy(&self) -> Channel {
345        let mut http = HttpConnector::new();
346        http.enforce_http(false);
347        http.set_nodelay(self.tcp_nodelay);
348        http.set_keepalive(self.tcp_keepalive);
349        http.set_connect_timeout(self.connect_timeout);
350
351        let connector = self.connector(http);
352
353        Channel::new(connector, self.clone())
354    }
355
356    /// Connect with a custom connector.
357    ///
358    /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport.
359    /// See the `uds` example for an example on how to use this function to build channel that
360    /// uses a Unix socket transport.
361    ///
362    /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied.
363    pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, Error>
364    where
365        C: Service<Uri> + Send + 'static,
366        C::Response: rt::Read + rt::Write + Send + Unpin,
367        C::Future: Send,
368        crate::BoxError: From<C::Error> + Send,
369    {
370        let connector = self.connector(connector);
371
372        if let Some(connect_timeout) = self.connect_timeout {
373            let mut connector = hyper_timeout::TimeoutConnector::new(connector);
374            connector.set_connect_timeout(Some(connect_timeout));
375            Channel::connect(connector, self.clone()).await
376        } else {
377            Channel::connect(connector, self.clone()).await
378        }
379    }
380
381    /// Connect with a custom connector lazily.
382    ///
383    /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport
384    /// connect to it lazily.
385    ///
386    /// See the `uds` example for an example on how to use this function to build channel that
387    /// uses a Unix socket transport.
388    pub fn connect_with_connector_lazy<C>(&self, connector: C) -> Channel
389    where
390        C: Service<Uri> + Send + 'static,
391        C::Response: rt::Read + rt::Write + Send + Unpin,
392        C::Future: Send,
393        crate::BoxError: From<C::Error> + Send,
394    {
395        let connector = self.connector(connector);
396        if let Some(connect_timeout) = self.connect_timeout {
397            let mut connector = hyper_timeout::TimeoutConnector::new(connector);
398            connector.set_connect_timeout(Some(connect_timeout));
399            Channel::new(connector, self.clone())
400        } else {
401            Channel::new(connector, self.clone())
402        }
403    }
404
405    /// Get the endpoint uri.
406    ///
407    /// ```
408    /// # use tonic_rustls::Endpoint;
409    /// # use http::Uri;
410    /// let endpoint = Endpoint::from_static("https://example.com");
411    ///
412    /// assert_eq!(endpoint.uri(), &Uri::from_static("https://example.com"));
413    /// ```
414    pub fn uri(&self) -> &Uri {
415        &self.uri
416    }
417
418    /// Get the value of `TCP_NODELAY` option for accepted connections.
419    pub fn get_tcp_nodelay(&self) -> bool {
420        self.tcp_nodelay
421    }
422
423    /// Get the connect timeout.
424    pub fn get_connect_timeout(&self) -> Option<Duration> {
425        self.connect_timeout
426    }
427
428    /// Get whether TCP keepalive messages are enabled on accepted connections.
429    ///
430    /// If `None` is specified, keepalive is disabled, otherwise the duration
431    /// specified will be the time to remain idle before sending TCP keepalive
432    /// probes.
433    pub fn get_tcp_keepalive(&self) -> Option<Duration> {
434        self.tcp_keepalive
435    }
436}
437
438impl From<Uri> for Endpoint {
439    fn from(uri: Uri) -> Self {
440        Self {
441            uri,
442            origin: None,
443            user_agent: None,
444            concurrency_limit: None,
445            rate_limit: None,
446            timeout: None,
447            #[cfg(feature = "tls")]
448            tls: None,
449            buffer_size: None,
450            init_stream_window_size: None,
451            init_connection_window_size: None,
452            tcp_keepalive: None,
453            tcp_nodelay: true,
454            http2_keep_alive_interval: None,
455            http2_keep_alive_timeout: None,
456            http2_keep_alive_while_idle: None,
457            http2_max_header_list_size: None,
458            connect_timeout: None,
459            http2_adaptive_window: None,
460            executor: SharedExec::tokio(),
461        }
462    }
463}
464
465impl TryFrom<Bytes> for Endpoint {
466    type Error = Error;
467
468    fn try_from(t: Bytes) -> Result<Self, Self::Error> {
469        Self::from_shared(t)
470    }
471}
472
473impl TryFrom<String> for Endpoint {
474    type Error = Error;
475
476    fn try_from(t: String) -> Result<Self, Self::Error> {
477        Self::from_shared(t.into_bytes())
478    }
479}
480
481impl TryFrom<&'static str> for Endpoint {
482    type Error = Error;
483
484    fn try_from(t: &'static str) -> Result<Self, Self::Error> {
485        Self::from_shared(t.as_bytes())
486    }
487}
488
489impl fmt::Debug for Endpoint {
490    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
491        f.debug_struct("Endpoint").finish()
492    }
493}
494
495impl FromStr for Endpoint {
496    type Err = Error;
497
498    fn from_str(s: &str) -> Result<Self, Self::Err> {
499        Self::try_from(s.to_string())
500    }
501}