soda_pool/
endpoint_template.rs

1use core::fmt;
2use http::HeaderValue;
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::{net::IpAddr, str::FromStr, time::Duration};
6#[cfg(feature = "tls")]
7use tonic::transport::ClientTlsConfig;
8use tonic::transport::{Endpoint, Uri};
9use url::Host;
10use url::Url;
11
12/// Template for creating [`Endpoint`]s.
13///
14/// This structure is used to store all the information necessary to create an [`Endpoint`].
15/// It then creates an [`Endpoint`] to a specific IP address using the [`build`](EndpointTemplate::build) method.
16#[derive(Clone)]
17pub struct EndpointTemplate {
18    url: Url,
19    origin: Option<Uri>,
20    user_agent: Option<HeaderValue>,
21    timeout: Option<Duration>,
22    concurrency_limit: Option<usize>,
23    rate_limit: Option<(u64, Duration)>,
24    #[cfg(feature = "tls")]
25    tls_config: Option<ClientTlsConfig>,
26    buffer_size: Option<usize>,
27    init_stream_window_size: Option<u32>,
28    init_connection_window_size: Option<u32>,
29    tcp_keepalive: Option<Duration>,
30    tcp_keepalive_interval: Option<Duration>,
31    tcp_keepalive_retries: Option<u32>,
32    tcp_nodelay: Option<bool>,
33    http2_keep_alive_interval: Option<Duration>,
34    http2_keep_alive_timeout: Option<Duration>,
35    http2_keep_alive_while_idle: Option<bool>,
36    http2_max_header_list_size: Option<u32>,
37    connect_timeout: Option<Duration>,
38    http2_adaptive_window: Option<bool>,
39    local_address: Option<IpAddr>,
40    // todo: If at all possible, support also setting the executor.
41}
42
43impl EndpointTemplate {
44    /// Creates a new `EndpointTemplate` from the provided URL.
45    ///
46    /// # Errors
47    /// - Will return [`EndpointTemplateError::NotAUrl`] if the provided URL is not a valid URL.
48    /// - Will return [`EndpointTemplateError::HostMissing`] if the provided URL does not contain a host.
49    /// - Will return [`EndpointTemplateError::AlreadyIpAddress`] if the provided URL already contains an IP address.
50    /// - Will return [`EndpointTemplateError::Inconvertible`] if the provided URL cannot be converted to the tonic's internal representation.
51    // Url requires a full Unicode support which, although correct, seems like
52    // an overkill for just substituting hostname with an IP address. Accepts
53    // any type that has a conversion to Url instead of just Url to limit
54    // breaking changes in the future if decide to use another type.
55    pub fn new(url: impl TryInto<Url>) -> Result<Self, EndpointTemplateError> {
56        let url: Url = url.try_into().map_err(|_| EndpointTemplateError::NotAUrl)?;
57
58        // Check if URL contains hostname that can be resolved with DNS
59        match url.host() {
60            Some(host) => match host {
61                Host::Domain(_) => {}
62                _ => return Err(EndpointTemplateError::AlreadyIpAddress),
63            },
64            None => return Err(EndpointTemplateError::HostMissing),
65        }
66
67        // Check if hostname in URL can be substituted by IP address
68        if url.cannot_be_a_base() {
69            // Since we have a host, I can't imagine an address that still
70            // couldn't be a base. If there is one, let's treat it as
71            // Inconvertible error for simplicity.
72            return Err(EndpointTemplateError::Inconvertible);
73        }
74
75        // Check if tonic Uri can be build from Url.
76        if Uri::from_str(url.as_str()).is_err() {
77            // It's hard to prove that any url::Url will also be parsable as
78            // tonic::transport::Uri, but in practice this error should never
79            // happen.
80            return Err(EndpointTemplateError::Inconvertible);
81        }
82
83        Ok(Self {
84            url,
85            origin: None,
86            user_agent: None,
87            timeout: None,
88            #[cfg(feature = "tls")]
89            tls_config: None,
90            concurrency_limit: None,
91            rate_limit: None,
92            buffer_size: None,
93            init_stream_window_size: None,
94            init_connection_window_size: None,
95            tcp_keepalive: None,
96            tcp_keepalive_interval: None,
97            tcp_keepalive_retries: None,
98            tcp_nodelay: None,
99            http2_keep_alive_interval: None,
100            http2_keep_alive_timeout: None,
101            http2_keep_alive_while_idle: None,
102            http2_max_header_list_size: None,
103            connect_timeout: None,
104            http2_adaptive_window: None,
105            local_address: None,
106        })
107    }
108
109    /// Builds an [`Endpoint`] to the IP address.
110    ///
111    /// This will substitute the hostname in the URL with the provided IP
112    /// address, create a new [`Endpoint`] from it, and apply all the settings
113    /// set in the builder.
114    #[allow(clippy::missing_panics_doc)]
115    pub fn build(&self, ip_address: impl Into<IpAddr>) -> Endpoint {
116        let mut endpoint = Endpoint::from(self.build_uri(ip_address.into()));
117
118        if let Some(origin) = self.origin.clone() {
119            endpoint = endpoint.origin(origin);
120        }
121
122        if let Some(user_agent) = self.user_agent.clone() {
123            endpoint = endpoint
124                .user_agent(user_agent)
125                .expect("already checked in the setter");
126        }
127
128        if let Some(timeout) = self.timeout {
129            endpoint = endpoint.timeout(timeout);
130        }
131
132        #[cfg(feature = "tls")]
133        if let Some(tls_config) = self.tls_config.clone() {
134            endpoint = endpoint
135                .tls_config(tls_config)
136                .expect("already checked in the setter");
137        }
138
139        if let Some(connect_timeout) = self.connect_timeout {
140            endpoint = endpoint.connect_timeout(connect_timeout);
141        }
142
143        endpoint = endpoint
144            .tcp_keepalive(self.tcp_keepalive)
145            .tcp_keepalive_interval(self.tcp_keepalive_interval)
146            .tcp_keepalive_retries(self.tcp_keepalive_retries);
147
148        if let Some(limit) = self.concurrency_limit {
149            endpoint = endpoint.concurrency_limit(limit);
150        }
151
152        if let Some((limit, duration)) = self.rate_limit {
153            endpoint = endpoint.rate_limit(limit, duration);
154        }
155
156        if let Some(sz) = self.init_stream_window_size {
157            endpoint = endpoint.initial_stream_window_size(sz);
158        }
159
160        if let Some(sz) = self.init_connection_window_size {
161            endpoint = endpoint.initial_connection_window_size(sz);
162        }
163
164        endpoint = endpoint.buffer_size(self.buffer_size);
165
166        if let Some(tcp_nodelay) = self.tcp_nodelay {
167            endpoint = endpoint.tcp_nodelay(tcp_nodelay);
168        }
169
170        if let Some(interval) = self.http2_keep_alive_interval {
171            endpoint = endpoint.http2_keep_alive_interval(interval);
172        }
173
174        if let Some(duration) = self.http2_keep_alive_timeout {
175            endpoint = endpoint.keep_alive_timeout(duration);
176        }
177
178        if let Some(enabled) = self.http2_keep_alive_while_idle {
179            endpoint = endpoint.keep_alive_while_idle(enabled);
180        }
181
182        if let Some(enabled) = self.http2_adaptive_window {
183            endpoint = endpoint.http2_adaptive_window(enabled);
184        }
185
186        if let Some(size) = self.http2_max_header_list_size {
187            endpoint = endpoint.http2_max_header_list_size(size);
188        }
189
190        endpoint = endpoint.local_address(self.local_address);
191
192        endpoint
193    }
194
195    /// Returns the hostname of the URL held in the template.
196    #[allow(clippy::missing_panics_doc)]
197    pub fn domain(&self) -> &str {
198        self.url
199            .domain()
200            .expect("already checked in the constructor")
201    }
202
203    fn build_uri(&self, ip_addr: IpAddr) -> Uri {
204        // We make sure this conversion doesn't return any errors in Self::new
205        // already so it's safe to unwrap here.
206        let mut url = self.url.clone();
207        url.set_ip_host(ip_addr)
208            .expect("already checked in the constructor by trying cannot_be_a_base");
209        Uri::from_str(url.as_str()).expect("starting from Url, this should always be a valid Uri")
210    }
211
212    /// r.f. [`Endpoint::user_agent`].
213    ///
214    /// # Errors
215    ///
216    /// Will return [`EndpointTemplateError::InvalidUserAgent`] if the provided
217    /// value cannot be converted to a [`HeaderValue`] and would cause a failure
218    /// when building an endpoint.
219    pub fn user_agent(
220        self,
221        user_agent: impl TryInto<HeaderValue>,
222    ) -> Result<Self, EndpointTemplateError> {
223        user_agent
224            .try_into()
225            .map(|ua| Self {
226                user_agent: Some(ua),
227                ..self
228            })
229            .map_err(|_| EndpointTemplateError::InvalidUserAgent)
230    }
231
232    /// r.f. [`Endpoint::origin`].
233    #[must_use]
234    pub fn origin(self, origin: Uri) -> Self {
235        Self {
236            origin: Some(origin),
237            ..self
238        }
239    }
240
241    /// r.f. [`Endpoint::timeout`].
242    #[must_use]
243    pub fn timeout(self, dur: Duration) -> Self {
244        Self {
245            timeout: Some(dur),
246            ..self
247        }
248    }
249
250    /// r.f. [`Endpoint::connect_timeout`].
251    #[must_use]
252    pub fn connect_timeout(self, dur: Duration) -> Self {
253        Self {
254            connect_timeout: Some(dur),
255            ..self
256        }
257    }
258
259    /// r.f. [`Endpoint::tcp_keepalive`].
260    #[must_use]
261    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
262        Self {
263            tcp_keepalive,
264            ..self
265        }
266    }
267
268    /// r.f. [`Endpoint::tcp_keepalive_interval`].
269    #[must_use]
270    pub fn tcp_keepalive_interval(self, interval: Duration) -> Self {
271        Self {
272            tcp_keepalive_interval: Some(interval),
273            ..self
274        }
275    }
276
277    /// r.f. [`Endpoint::tcp_keepalive_retries`].
278    #[must_use]
279    pub fn tcp_keepalive_retries(self, retries: u32) -> Self {
280        Self {
281            tcp_keepalive_retries: Some(retries),
282            ..self
283        }
284    }
285
286    /// r.f. [`Endpoint::concurrency_limit`]
287    #[must_use]
288    pub fn concurrency_limit(self, limit: usize) -> Self {
289        Self {
290            concurrency_limit: Some(limit),
291            ..self
292        }
293    }
294
295    /// r.f. [`Endpoint::rate_limit`].
296    #[must_use]
297    pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
298        Self {
299            rate_limit: Some((limit, duration)),
300            ..self
301        }
302    }
303
304    /// r.f. [`Endpoint::initial_stream_window_size`].
305    #[must_use]
306    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
307        Self {
308            init_stream_window_size: sz.into(),
309            ..self
310        }
311    }
312
313    /// r.f. [`Endpoint::initial_connection_window_size`].
314    #[must_use]
315    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
316        Self {
317            init_connection_window_size: sz.into(),
318            ..self
319        }
320    }
321
322    /// r.f. [`Endpoint::buffer_size`].
323    #[must_use]
324    pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
325        Self {
326            buffer_size: sz.into(),
327            ..self
328        }
329    }
330
331    /// r.f. [`Endpoint::tls_config`].
332    ///
333    /// # Errors
334    ///
335    /// Will return [`EndpointTemplateError::InvalidTlsConfig`] if the provided
336    /// config cannot be passed to an [`Endpoint`] and would cause a failure
337    /// when building an endpoint.
338    #[cfg(feature = "tls")]
339    pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, EndpointTemplateError> {
340        // Make sure we'll be able to build the Endpoint using this ClientTlsConfig
341        let endpoint = self.build(std::net::Ipv4Addr::LOCALHOST);
342        let _ = endpoint
343            .tls_config(tls_config.clone())
344            .map_err(|_| EndpointTemplateError::InvalidTlsConfig)?;
345
346        Ok(Self {
347            tls_config: Some(tls_config),
348            ..self
349        })
350    }
351
352    /// r.f. [`Endpoint::tcp_nodelay`].
353    #[must_use]
354    pub fn tcp_nodelay(self, enabled: bool) -> Self {
355        Self {
356            tcp_nodelay: Some(enabled),
357            ..self
358        }
359    }
360
361    /// r.f. [`Endpoint::http2_keep_alive_interval`].
362    #[must_use]
363    pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
364        Self {
365            http2_keep_alive_interval: Some(interval),
366            ..self
367        }
368    }
369
370    /// r.f. [`Endpoint::keep_alive_timeout`].
371    #[must_use]
372    pub fn keep_alive_timeout(self, duration: Duration) -> Self {
373        Self {
374            http2_keep_alive_timeout: Some(duration),
375            ..self
376        }
377    }
378
379    /// r.f. [`Endpoint::keep_alive_while_idle`].
380    #[must_use]
381    pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
382        Self {
383            http2_keep_alive_while_idle: Some(enabled),
384            ..self
385        }
386    }
387
388    /// r.f. [`Endpoint::http2_adaptive_window`].
389    #[must_use]
390    pub fn http2_adaptive_window(self, enabled: bool) -> Self {
391        Self {
392            http2_adaptive_window: Some(enabled),
393            ..self
394        }
395    }
396
397    /// r.f. [`Endpoint::http2_max_header_list_size`].
398    #[must_use]
399    pub fn http2_max_header_list_size(self, size: u32) -> Self {
400        Self {
401            http2_max_header_list_size: Some(size),
402            ..self
403        }
404    }
405
406    /// r.f. [`Endpoint::local_address`].
407    #[must_use]
408    pub fn local_address(self, ip: Option<IpAddr>) -> Self {
409        Self {
410            local_address: ip,
411            ..self
412        }
413    }
414}
415
416impl Debug for EndpointTemplate {
417    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418        f.debug_struct("EndpointTemplate")
419            .field("url", &self.url.as_str())
420            .finish_non_exhaustive()
421    }
422}
423
/// Errors reported while creating or configuring an [`EndpointTemplate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EndpointTemplateError {
    /// The provided value is not a valid URL.
    ///
    /// Converting the provided value into a URL failed.
    NotAUrl,

    /// The URL has no host.
    ///
    /// The provided URL does not contain a host that could be resolved with
    /// DNS.
    HostMissing,

    /// The URL already points at an IP address.
    ///
    /// A template needs a hostname to substitute, so a URL whose host is
    /// already an IP address cannot be used.
    AlreadyIpAddress,

    /// The URL cannot be converted to an internal type.
    ///
    /// tonic's [`Endpoint`](tonic::transport::Endpoint) represents addresses
    /// with its own [type](tonic::transport::Uri), and the provided URL (after
    /// substituting the hostname for an IP address) could not be converted
    /// into it.
    Inconvertible,

    /// The provided user agent is invalid.
    ///
    /// The value cannot be converted to a [`HeaderValue`] and would cause a
    /// failure when building an endpoint.
    InvalidUserAgent,

    /// The provided TLS config is invalid.
    ///
    /// The config would cause a failure when building an endpoint.
    #[cfg(feature = "tls")]
    InvalidTlsConfig,
}
462
463impl TryFrom<Url> for EndpointTemplate {
464    type Error = EndpointTemplateError;
465
466    fn try_from(url: Url) -> Result<Self, Self::Error> {
467        Self::new(url)
468    }
469}
470
471#[cfg_attr(coverage_nightly, coverage(off))]
472impl Display for EndpointTemplateError {
473    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474        match self {
475            EndpointTemplateError::NotAUrl => write!(f, "not a valid URL"),
476            EndpointTemplateError::HostMissing => write!(f, "host missing"),
477            EndpointTemplateError::AlreadyIpAddress => write!(f, "already an IP address"),
478            EndpointTemplateError::Inconvertible => write!(f, "inconvertible URL"),
479            EndpointTemplateError::InvalidUserAgent => write!(f, "invalid user agent"),
480            #[cfg(feature = "tls")]
481            EndpointTemplateError::InvalidTlsConfig => write!(f, "invalid TLS config"),
482        }
483    }
484}
485
486impl Error for EndpointTemplateError {}
487
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::{net::IpAddr, str::FromStr};

    use http::Uri;
    use url::Url;

    use super::*;

    /// URL used by every test that needs a valid, domain-based template.
    const TEMPLATE_URL: &str = "http://example.com:50051/foo";

    /// Creates a template from [`TEMPLATE_URL`].
    fn template() -> EndpointTemplate {
        EndpointTemplate::new(Url::parse(TEMPLATE_URL).unwrap()).unwrap()
    }

    #[test]
    fn can_substitute_domain_for_ipv4_address() {
        let endpoint = template().build("203.0.113.6".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://203.0.113.6:50051/foo").unwrap()
        );
    }

    #[test]
    fn can_substitute_domain_for_ipv6_address() {
        let endpoint = template().build("2001:db8::".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://[2001:db8::]:50051/foo").unwrap()
        );
    }

    #[rstest::rstest]
    #[case("http://127.0.0.1:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("http://[::1]:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("mailto:admin@example.com", EndpointTemplateError::HostMissing)]
    fn builder_error(#[case] input: &str, #[case] expected: EndpointTemplateError) {
        let result = EndpointTemplate::new(Url::parse(input).unwrap());
        // `unwrap_err` already fails the test if the result is `Ok`.
        assert_eq!(result.unwrap_err(), expected);
    }

    #[rstest::rstest]
    #[case("http://example.com:50051/foo", Ok("example.com"))]
    #[case("http://127.0.0.1:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("http://[::1]:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("mailto:admin@example.com", Err(EndpointTemplateError::HostMissing))]
    fn from_trait(#[case] url: &str, #[case] expected: Result<&str, EndpointTemplateError>) {
        let url = Url::parse(url).unwrap();
        let result = EndpointTemplate::try_from(url);
        let domain = result.as_ref().map(EndpointTemplate::domain);
        assert_eq!(domain, expected.as_deref());
    }

    #[test]
    fn user_agent_rejects_invalid_header_value() {
        // A line feed is not allowed inside a header value.
        let result = template().user_agent("bad\nagent");
        assert_eq!(result.unwrap_err(), EndpointTemplateError::InvalidUserAgent);
    }

    #[test]
    fn setters() {
        let builder = template();

        let origin = Uri::from_str("http://example.net:50001").unwrap();
        let builder = builder.origin(origin.clone());
        assert_eq!(builder.origin, Some(origin));

        let user_agent = HeaderValue::from_str("my-user-agent").unwrap();
        let builder = builder.user_agent(user_agent.clone()).unwrap();
        assert_eq!(builder.user_agent, Some(user_agent));

        let duration = Duration::from_secs(10);
        let builder = builder.timeout(duration);
        assert_eq!(builder.timeout, Some(duration));

        let connect_timeout = Duration::from_secs(5);
        let builder = builder.connect_timeout(connect_timeout);
        assert_eq!(builder.connect_timeout, Some(connect_timeout));

        let tcp_keepalive = Some(Duration::from_secs(30));
        let builder = builder.tcp_keepalive(tcp_keepalive);
        assert_eq!(builder.tcp_keepalive, tcp_keepalive);

        let tcp_keepalive_interval = Duration::from_secs(15);
        let builder = builder.tcp_keepalive_interval(tcp_keepalive_interval);
        assert_eq!(
            builder.tcp_keepalive_interval,
            Some(tcp_keepalive_interval)
        );

        let tcp_keepalive_retries = 3;
        let builder = builder.tcp_keepalive_retries(tcp_keepalive_retries);
        assert_eq!(builder.tcp_keepalive_retries, Some(tcp_keepalive_retries));

        let concurrency_limit = 10;
        let builder = builder.concurrency_limit(concurrency_limit);
        assert_eq!(builder.concurrency_limit, Some(concurrency_limit));

        let rate_limit = (100, Duration::from_secs(1));
        let builder = builder.rate_limit(rate_limit.0, rate_limit.1);
        assert_eq!(builder.rate_limit, Some(rate_limit));

        let init_stream_window_size = Some(64);
        let builder = builder.initial_stream_window_size(init_stream_window_size);
        assert_eq!(builder.init_stream_window_size, init_stream_window_size);

        let init_connection_window_size = Some(128);
        let builder = builder.initial_connection_window_size(init_connection_window_size);
        assert_eq!(
            builder.init_connection_window_size,
            init_connection_window_size
        );

        let buffer_size = Some(1024);
        let builder = builder.buffer_size(buffer_size);
        assert_eq!(builder.buffer_size, buffer_size);

        let tcp_nodelay = true;
        let builder = builder.tcp_nodelay(tcp_nodelay);
        assert_eq!(builder.tcp_nodelay, Some(tcp_nodelay));

        let http2_keep_alive_interval = Duration::from_secs(30);
        let builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
        assert_eq!(
            builder.http2_keep_alive_interval,
            Some(http2_keep_alive_interval)
        );

        let keep_alive_timeout = Duration::from_secs(60);
        let builder = builder.keep_alive_timeout(keep_alive_timeout);
        assert_eq!(builder.http2_keep_alive_timeout, Some(keep_alive_timeout));

        let keep_alive_while_idle = true;
        let builder = builder.keep_alive_while_idle(keep_alive_while_idle);
        assert_eq!(
            builder.http2_keep_alive_while_idle,
            Some(keep_alive_while_idle)
        );

        let http2_adaptive_window = true;
        let builder = builder.http2_adaptive_window(http2_adaptive_window);
        assert_eq!(builder.http2_adaptive_window, Some(http2_adaptive_window));

        let http2_max_header_list_size = 8192;
        let builder = builder.http2_max_header_list_size(http2_max_header_list_size);
        assert_eq!(
            builder.http2_max_header_list_size,
            Some(http2_max_header_list_size)
        );

        let local_address = Some(IpAddr::from([127, 0, 0, 2]));
        let builder = builder.local_address(local_address);
        assert_eq!(builder.local_address, local_address);

        // Building with every setting applied must not panic.
        let _ = builder.build([127, 0, 0, 1]);
    }

    #[test]
    fn debug_output() {
        let debug_output = format!("{:?}", template());
        assert_eq!(
            debug_output,
            "EndpointTemplate { url: \"http://example.com:50051/foo\", .. }"
        );
    }
}