soda_pool/
endpoint_template.rs

1use core::fmt;
2use http::HeaderValue;
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::{net::IpAddr, str::FromStr, time::Duration};
6#[cfg(feature = "tls")]
7use tonic::transport::ClientTlsConfig;
8use tonic::transport::{Endpoint, Uri};
9use url::Host;
10use url::Url;
11
12/// Template for creating [`Endpoint`]s.
13///
14/// This structure is used to store all the information necessary to create an [`Endpoint`].
15/// It then creates an [`Endpoint`] to a specific IP address using the [`build`](EndpointTemplate::build) method.
16#[derive(Clone)]
17pub struct EndpointTemplate {
18    url: Url,
19    origin: Option<Uri>,
20    user_agent: Option<HeaderValue>,
21    timeout: Option<Duration>,
22    concurrency_limit: Option<usize>,
23    rate_limit: Option<(u64, Duration)>,
24    #[cfg(feature = "tls")]
25    tls_config: Option<ClientTlsConfig>,
26    buffer_size: Option<usize>,
27    init_stream_window_size: Option<u32>,
28    init_connection_window_size: Option<u32>,
29    tcp_keepalive: Option<Duration>,
30    tcp_nodelay: Option<bool>,
31    http2_keep_alive_interval: Option<Duration>,
32    http2_keep_alive_timeout: Option<Duration>,
33    http2_keep_alive_while_idle: Option<bool>,
34    http2_max_header_list_size: Option<u32>,
35    connect_timeout: Option<Duration>,
36    http2_adaptive_window: Option<bool>,
37    local_address: Option<IpAddr>,
38    // todo: If at all possible, support also setting the executor.
39}
40
41impl EndpointTemplate {
42    /// Creates a new `EndpointTemplate` from the provided URL.
43    ///
44    /// # Errors
45    /// - Will return [`EndpointTemplateError::NotAUrl`] if the provided URL is not a valid URL.
46    /// - Will return [`EndpointTemplateError::HostMissing`] if the provided URL does not contain a host.
47    /// - Will return [`EndpointTemplateError::AlreadyIpAddress`] if the provided URL already contains an IP address.
48    /// - Will return [`EndpointTemplateError::Inconvertible`] if the provided URL cannot be converted to the tonic's internal representation.
49    // Url requires a full Unicode support which, although correct, seems like
50    // an overkill for just substituting hostname with an IP address. Accepts
51    // any type that has a conversion to Url instead of just Url to limit
52    // breaking changes in the future if decide to use another type.
53    pub fn new(url: impl TryInto<Url>) -> Result<Self, EndpointTemplateError> {
54        let url: Url = url.try_into().map_err(|_| EndpointTemplateError::NotAUrl)?;
55
56        // Check if URL contains hostname that can be resolved with DNS
57        match url.host() {
58            Some(host) => match host {
59                Host::Domain(_) => {}
60                _ => return Err(EndpointTemplateError::AlreadyIpAddress),
61            },
62            None => return Err(EndpointTemplateError::HostMissing),
63        }
64
65        // Check if hostname in URL can be substituted by IP address
66        if url.cannot_be_a_base() {
67            // Since we have a host, I can't imagine an address that still
68            // couldn't be a base. If there is one, let's treat it as
69            // Inconvertible error for simplicity.
70            return Err(EndpointTemplateError::Inconvertible);
71        }
72
73        // Check if tonic Uri can be build from Url.
74        if Uri::from_str(url.as_str()).is_err() {
75            // It's hard to prove that any url::Url will also be parsable as
76            // tonic::transport::Uri, but in practice this error should never
77            // happen.
78            return Err(EndpointTemplateError::Inconvertible);
79        }
80
81        Ok(Self {
82            url,
83            origin: None,
84            user_agent: None,
85            timeout: None,
86            #[cfg(feature = "tls")]
87            tls_config: None,
88            concurrency_limit: None,
89            rate_limit: None,
90            buffer_size: None,
91            init_stream_window_size: None,
92            init_connection_window_size: None,
93            tcp_keepalive: None,
94            tcp_nodelay: None,
95            http2_keep_alive_interval: None,
96            http2_keep_alive_timeout: None,
97            http2_keep_alive_while_idle: None,
98            http2_max_header_list_size: None,
99            connect_timeout: None,
100            http2_adaptive_window: None,
101            local_address: None,
102        })
103    }
104
105    /// Builds an [`Endpoint`] to the IP address.
106    ///
107    /// This will substitute the hostname in the URL with the provided IP
108    /// address, create a new [`Endpoint`] from it, and apply all the settings
109    /// set in the builder.
110    #[allow(clippy::missing_panics_doc)]
111    pub fn build(&self, ip_address: impl Into<IpAddr>) -> Endpoint {
112        let mut endpoint = Endpoint::from(self.build_uri(ip_address.into()));
113
114        if let Some(origin) = self.origin.clone() {
115            endpoint = endpoint.origin(origin);
116        }
117
118        if let Some(user_agent) = self.user_agent.clone() {
119            endpoint = endpoint
120                .user_agent(user_agent)
121                .expect("already checked in the setter");
122        }
123
124        if let Some(timeout) = self.timeout {
125            endpoint = endpoint.timeout(timeout);
126        }
127
128        #[cfg(feature = "tls")]
129        if let Some(tls_config) = self.tls_config.clone() {
130            endpoint = endpoint
131                .tls_config(tls_config)
132                .expect("already checked in the setter");
133        }
134
135        if let Some(connect_timeout) = self.connect_timeout {
136            endpoint = endpoint.connect_timeout(connect_timeout);
137        }
138
139        endpoint = endpoint.tcp_keepalive(self.tcp_keepalive);
140
141        if let Some(limit) = self.concurrency_limit {
142            endpoint = endpoint.concurrency_limit(limit);
143        }
144
145        if let Some((limit, duration)) = self.rate_limit {
146            endpoint = endpoint.rate_limit(limit, duration);
147        }
148
149        if let Some(sz) = self.init_stream_window_size {
150            endpoint = endpoint.initial_stream_window_size(sz);
151        }
152
153        if let Some(sz) = self.init_connection_window_size {
154            endpoint = endpoint.initial_connection_window_size(sz);
155        }
156
157        endpoint = endpoint.buffer_size(self.buffer_size);
158
159        if let Some(tcp_nodelay) = self.tcp_nodelay {
160            endpoint = endpoint.tcp_nodelay(tcp_nodelay);
161        }
162
163        if let Some(interval) = self.http2_keep_alive_interval {
164            endpoint = endpoint.http2_keep_alive_interval(interval);
165        }
166
167        if let Some(duration) = self.http2_keep_alive_timeout {
168            endpoint = endpoint.keep_alive_timeout(duration);
169        }
170
171        if let Some(enabled) = self.http2_keep_alive_while_idle {
172            endpoint = endpoint.keep_alive_while_idle(enabled);
173        }
174
175        if let Some(enabled) = self.http2_adaptive_window {
176            endpoint = endpoint.http2_adaptive_window(enabled);
177        }
178
179        if let Some(size) = self.http2_max_header_list_size {
180            endpoint = endpoint.http2_max_header_list_size(size);
181        }
182
183        endpoint = endpoint.local_address(self.local_address);
184
185        endpoint
186    }
187
188    /// Returns the hostname of the URL held in the template.
189    #[allow(clippy::missing_panics_doc)]
190    pub fn domain(&self) -> &str {
191        self.url
192            .domain()
193            .expect("already checked in the constructor")
194    }
195
196    fn build_uri(&self, ip_addr: IpAddr) -> Uri {
197        // We make sure this conversion doesn't return any errors in Self::new
198        // already so it's safe to unwrap here.
199        let mut url = self.url.clone();
200        url.set_ip_host(ip_addr)
201            .expect("already checked in the constructor by trying cannot_be_a_base");
202        Uri::from_str(url.as_str()).expect("starting from Url, this should always be a valid Uri")
203    }
204
205    /// r.f. [`Endpoint::user_agent`].
206    ///
207    /// # Errors
208    ///
209    /// Will return [`EndpointTemplateError::InvalidUserAgent`] if the provided
210    /// value cannot be converted to a [`HeaderValue`] and would cause a failure
211    /// when building an endpoint.
212    pub fn user_agent(
213        self,
214        user_agent: impl TryInto<HeaderValue>,
215    ) -> Result<Self, EndpointTemplateError> {
216        user_agent
217            .try_into()
218            .map(|ua| Self {
219                user_agent: Some(ua),
220                ..self
221            })
222            .map_err(|_| EndpointTemplateError::InvalidUserAgent)
223    }
224
225    /// r.f. [`Endpoint::origin`].
226    #[must_use]
227    pub fn origin(self, origin: Uri) -> Self {
228        Self {
229            origin: Some(origin),
230            ..self
231        }
232    }
233
234    /// r.f. [`Endpoint::timeout`].
235    #[must_use]
236    pub fn timeout(self, dur: Duration) -> Self {
237        Self {
238            timeout: Some(dur),
239            ..self
240        }
241    }
242
243    /// r.f. [`Endpoint::connect_timeout`].
244    #[must_use]
245    pub fn connect_timeout(self, dur: Duration) -> Self {
246        Self {
247            connect_timeout: Some(dur),
248            ..self
249        }
250    }
251
252    /// r.f. [`Endpoint::tcp_keepalive`].
253    #[must_use]
254    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
255        Self {
256            tcp_keepalive,
257            ..self
258        }
259    }
260
261    /// r.f. [`Endpoint::concurrency_limit`]
262    #[must_use]
263    pub fn concurrency_limit(self, limit: usize) -> Self {
264        Self {
265            concurrency_limit: Some(limit),
266            ..self
267        }
268    }
269
270    /// r.f. [`Endpoint::rate_limit`].
271    #[must_use]
272    pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
273        Self {
274            rate_limit: Some((limit, duration)),
275            ..self
276        }
277    }
278
279    /// r.f. [`Endpoint::initial_stream_window_size`].
280    #[must_use]
281    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
282        Self {
283            init_stream_window_size: sz.into(),
284            ..self
285        }
286    }
287
288    /// r.f. [`Endpoint::initial_connection_window_size`].
289    #[must_use]
290    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
291        Self {
292            init_connection_window_size: sz.into(),
293            ..self
294        }
295    }
296
297    /// r.f. [`Endpoint::buffer_size`].
298    #[must_use]
299    pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
300        Self {
301            buffer_size: sz.into(),
302            ..self
303        }
304    }
305
306    /// r.f. [`Endpoint::tls_config`].
307    ///
308    /// # Errors
309    ///
310    /// Will return [`EndpointTemplateError::InvalidTlsConfig`] if the provided
311    /// config cannot be passed to an [`Endpoint`] and would cause a failure
312    /// when building an endpoint.
313    #[cfg(feature = "tls")]
314    pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, EndpointTemplateError> {
315        // Make sure we'll be able to build the Endpoint using this ClientTlsConfig
316        let endpoint = self.build(std::net::Ipv4Addr::new(127, 0, 0, 1));
317        let _ = endpoint
318            .tls_config(tls_config.clone())
319            .map_err(|_| EndpointTemplateError::InvalidTlsConfig)?;
320
321        Ok(Self {
322            tls_config: Some(tls_config),
323            ..self
324        })
325    }
326
327    /// r.f. [`Endpoint::tcp_nodelay`].
328    #[must_use]
329    pub fn tcp_nodelay(self, enabled: bool) -> Self {
330        Self {
331            tcp_nodelay: Some(enabled),
332            ..self
333        }
334    }
335
336    /// r.f. [`Endpoint::http2_keep_alive_interval`].
337    #[must_use]
338    pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
339        Self {
340            http2_keep_alive_interval: Some(interval),
341            ..self
342        }
343    }
344
345    /// r.f. [`Endpoint::keep_alive_timeout`].
346    #[must_use]
347    pub fn keep_alive_timeout(self, duration: Duration) -> Self {
348        Self {
349            http2_keep_alive_timeout: Some(duration),
350            ..self
351        }
352    }
353
354    /// r.f. [`Endpoint::keep_alive_while_idle`].
355    #[must_use]
356    pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
357        Self {
358            http2_keep_alive_while_idle: Some(enabled),
359            ..self
360        }
361    }
362
363    /// r.f. [`Endpoint::http2_adaptive_window`].
364    #[must_use]
365    pub fn http2_adaptive_window(self, enabled: bool) -> Self {
366        Self {
367            http2_adaptive_window: Some(enabled),
368            ..self
369        }
370    }
371
372    /// r.f. [`Endpoint::http2_max_header_list_size`].
373    #[must_use]
374    pub fn http2_max_header_list_size(self, size: u32) -> Self {
375        Self {
376            http2_max_header_list_size: Some(size),
377            ..self
378        }
379    }
380
381    /// r.f. [`Endpoint::local_address`].
382    #[must_use]
383    pub fn local_address(self, ip: Option<IpAddr>) -> Self {
384        Self {
385            local_address: ip,
386            ..self
387        }
388    }
389}
390
391impl Debug for EndpointTemplate {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        f.debug_struct("EndpointTemplate")
394            .field("url", &self.url.as_str())
395            .finish_non_exhaustive()
396    }
397}
398
/// Errors returned while creating or configuring an [`EndpointTemplate`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum EndpointTemplateError {
    /// Provided value is not a valid URL.
    ///
    /// The value could not be converted (parsed) into a URL.
    NotAUrl,

    /// The URL does not contain a host.
    ///
    /// Without a host there is no name that could be resolved with DNS.
    HostMissing,

    /// The URL is already an IP address.
    ///
    /// The host is an IP literal, so there is no hostname to substitute and
    /// the URL cannot serve as a template.
    AlreadyIpAddress,

    /// The URL cannot be converted to an internal type.
    ///
    /// tonic's [`Endpoint`](tonic::transport::Endpoint) represents addresses
    /// with its own [type](tonic::transport::Uri), and the provided URL (after
    /// substituting the hostname with an IP address) could not be converted
    /// into it.
    Inconvertible,

    /// The provided user agent is invalid.
    ///
    /// It cannot be converted to a [`HeaderValue`] and would make building an
    /// endpoint fail.
    InvalidUserAgent,

    /// The provided TLS config is invalid.
    ///
    /// It would make building an endpoint fail. This variant only exists
    /// when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    InvalidTlsConfig,
}
437
438impl TryFrom<Url> for EndpointTemplate {
439    type Error = EndpointTemplateError;
440
441    fn try_from(url: Url) -> Result<Self, Self::Error> {
442        Self::new(url)
443    }
444}
445
446#[cfg_attr(coverage_nightly, coverage(off))]
447impl Display for EndpointTemplateError {
448    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449        match self {
450            EndpointTemplateError::NotAUrl => write!(f, "not a valid URL"),
451            EndpointTemplateError::HostMissing => write!(f, "host missing"),
452            EndpointTemplateError::AlreadyIpAddress => write!(f, "already an IP address"),
453            EndpointTemplateError::Inconvertible => write!(f, "inconvertible URL"),
454            EndpointTemplateError::InvalidUserAgent => write!(f, "invalid user agent"),
455            #[cfg(feature = "tls")]
456            EndpointTemplateError::InvalidTlsConfig => write!(f, "invalid TLS config"),
457        }
458    }
459}
460
461impl Error for EndpointTemplateError {}
462
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::{net::IpAddr, str::FromStr};

    use http::Uri;
    use url::Url;

    use super::*;

    /// Builds a template from a URL string that is known to be valid.
    fn template_for(url: &str) -> EndpointTemplate {
        EndpointTemplate::new(Url::parse(url).unwrap()).unwrap()
    }

    #[test]
    fn can_substitute_domain_for_ipv4_address() {
        let template = template_for("http://example.com:50051/foo");

        let endpoint = template.build("203.0.113.6".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://203.0.113.6:50051/foo").unwrap()
        );
    }

    #[test]
    fn can_substitute_domain_for_ipv6_address() {
        let template = template_for("http://example.com:50051/foo");

        let endpoint = template.build("2001:db8::".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://[2001:db8::]:50051/foo").unwrap()
        );
    }

    #[test]
    fn rejects_input_that_is_not_a_url() {
        let result = EndpointTemplate::new("not a url");
        assert_eq!(result.unwrap_err(), EndpointTemplateError::NotAUrl);
    }

    #[rstest::rstest]
    #[case("http://127.0.0.1:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("http://[::1]:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("mailto:admin@example.com", EndpointTemplateError::HostMissing)]
    fn builder_error(#[case] input: &str, #[case] expected: EndpointTemplateError) {
        let result = EndpointTemplate::new(Url::parse(input).unwrap());
        assert_eq!(result.unwrap_err(), expected);
    }

    #[rstest::rstest]
    #[case("http://example.com:50051/foo", Ok("example.com"))]
    #[case("http://127.0.0.1:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("http://[::1]:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("mailto:admin@example.com", Err(EndpointTemplateError::HostMissing))]
    fn from_trait(#[case] url: &str, #[case] expected: Result<&str, EndpointTemplateError>) {
        let url = Url::parse(url).unwrap();
        let result = EndpointTemplate::try_from(url);
        let domain = result.as_ref().map(EndpointTemplate::domain);
        assert_eq!(domain, expected.as_deref());
    }

    #[test]
    fn setters() {
        let builder = template_for("http://example.com:50051/foo");

        let origin = Uri::from_str("http://example.net:50001").unwrap();
        let builder = builder.origin(origin.clone());
        assert_eq!(builder.origin, Some(origin));

        let user_agent = HeaderValue::from_str("my-user-agent").unwrap();
        let builder = builder.user_agent(user_agent.clone()).unwrap();
        assert_eq!(builder.user_agent, Some(user_agent));

        let duration = Duration::from_secs(10);
        let builder = builder.timeout(duration);
        assert_eq!(builder.timeout, Some(duration));

        let connect_timeout = Duration::from_secs(5);
        let builder = builder.connect_timeout(connect_timeout);
        assert_eq!(builder.connect_timeout, Some(connect_timeout));

        let tcp_keepalive = Some(Duration::from_secs(30));
        let builder = builder.tcp_keepalive(tcp_keepalive);
        assert_eq!(builder.tcp_keepalive, tcp_keepalive);

        let concurrency_limit = 10;
        let builder = builder.concurrency_limit(concurrency_limit);
        assert_eq!(builder.concurrency_limit, Some(concurrency_limit));

        let rate_limit = (100, Duration::from_secs(1));
        let builder = builder.rate_limit(rate_limit.0, rate_limit.1);
        assert_eq!(builder.rate_limit, Some(rate_limit));

        let init_stream_window_size = Some(64);
        let builder = builder.initial_stream_window_size(init_stream_window_size);
        assert_eq!(builder.init_stream_window_size, init_stream_window_size);

        let init_connection_window_size = Some(128);
        let builder = builder.initial_connection_window_size(init_connection_window_size);
        assert_eq!(
            builder.init_connection_window_size,
            init_connection_window_size
        );

        let buffer_size = Some(1024);
        let builder = builder.buffer_size(buffer_size);
        assert_eq!(builder.buffer_size, buffer_size);

        let tcp_nodelay = true;
        let builder = builder.tcp_nodelay(tcp_nodelay);
        assert_eq!(builder.tcp_nodelay, Some(tcp_nodelay));

        let http2_keep_alive_interval = Duration::from_secs(30);
        let builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
        assert_eq!(
            builder.http2_keep_alive_interval,
            Some(http2_keep_alive_interval)
        );

        let keep_alive_timeout = Duration::from_secs(60);
        let builder = builder.keep_alive_timeout(keep_alive_timeout);
        assert_eq!(builder.http2_keep_alive_timeout, Some(keep_alive_timeout));

        let keep_alive_while_idle = true;
        let builder = builder.keep_alive_while_idle(keep_alive_while_idle);
        assert_eq!(
            builder.http2_keep_alive_while_idle,
            Some(keep_alive_while_idle)
        );

        let http2_adaptive_window = true;
        let builder = builder.http2_adaptive_window(http2_adaptive_window);
        assert_eq!(builder.http2_adaptive_window, Some(http2_adaptive_window));

        let http2_max_header_list_size = 8192;
        let builder = builder.http2_max_header_list_size(http2_max_header_list_size);
        assert_eq!(
            builder.http2_max_header_list_size,
            Some(http2_max_header_list_size)
        );

        let local_address = Some(IpAddr::from([127, 0, 0, 2]));
        let builder = builder.local_address(local_address);
        assert_eq!(builder.local_address, local_address);

        // Building with every setting applied must not panic.
        let _ = builder.build([127, 0, 0, 1]);
    }

    #[test]
    fn debug_output() {
        let builder = template_for("http://example.com:50051/foo");

        let debug_output = format!("{builder:?}");
        assert_eq!(
            debug_output,
            "EndpointTemplate { url: \"http://example.com:50051/foo\", .. }"
        );
    }
}