1use core::fmt;
2use http::HeaderValue;
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::{net::IpAddr, str::FromStr, time::Duration};
6#[cfg(feature = "tls")]
7use tonic::transport::ClientTlsConfig;
8use tonic::transport::{Endpoint, Uri};
9use url::Host;
10use url::Url;
11
12#[derive(Clone)]
17pub struct EndpointTemplate {
18 url: Url,
19 origin: Option<Uri>,
20 user_agent: Option<HeaderValue>,
21 timeout: Option<Duration>,
22 concurrency_limit: Option<usize>,
23 rate_limit: Option<(u64, Duration)>,
24 #[cfg(feature = "tls")]
25 tls_config: Option<ClientTlsConfig>,
26 buffer_size: Option<usize>,
27 init_stream_window_size: Option<u32>,
28 init_connection_window_size: Option<u32>,
29 tcp_keepalive: Option<Duration>,
30 tcp_keepalive_interval: Option<Duration>,
31 tcp_keepalive_retries: Option<u32>,
32 tcp_nodelay: Option<bool>,
33 http2_keep_alive_interval: Option<Duration>,
34 http2_keep_alive_timeout: Option<Duration>,
35 http2_keep_alive_while_idle: Option<bool>,
36 http2_max_header_list_size: Option<u32>,
37 connect_timeout: Option<Duration>,
38 http2_adaptive_window: Option<bool>,
39 local_address: Option<IpAddr>,
40 }
42
43impl EndpointTemplate {
44 pub fn new(url: impl TryInto<Url>) -> Result<Self, EndpointTemplateError> {
56 let url: Url = url.try_into().map_err(|_| EndpointTemplateError::NotAUrl)?;
57
58 match url.host() {
60 Some(host) => match host {
61 Host::Domain(_) => {}
62 _ => return Err(EndpointTemplateError::AlreadyIpAddress),
63 },
64 None => return Err(EndpointTemplateError::HostMissing),
65 }
66
67 if url.cannot_be_a_base() {
69 return Err(EndpointTemplateError::Inconvertible);
73 }
74
75 if Uri::from_str(url.as_str()).is_err() {
77 return Err(EndpointTemplateError::Inconvertible);
81 }
82
83 Ok(Self {
84 url,
85 origin: None,
86 user_agent: None,
87 timeout: None,
88 #[cfg(feature = "tls")]
89 tls_config: None,
90 concurrency_limit: None,
91 rate_limit: None,
92 buffer_size: None,
93 init_stream_window_size: None,
94 init_connection_window_size: None,
95 tcp_keepalive: None,
96 tcp_keepalive_interval: None,
97 tcp_keepalive_retries: None,
98 tcp_nodelay: None,
99 http2_keep_alive_interval: None,
100 http2_keep_alive_timeout: None,
101 http2_keep_alive_while_idle: None,
102 http2_max_header_list_size: None,
103 connect_timeout: None,
104 http2_adaptive_window: None,
105 local_address: None,
106 })
107 }
108
109 #[allow(clippy::missing_panics_doc)]
115 pub fn build(&self, ip_address: impl Into<IpAddr>) -> Endpoint {
116 let mut endpoint = Endpoint::from(self.build_uri(ip_address.into()));
117
118 if let Some(origin) = self.origin.clone() {
119 endpoint = endpoint.origin(origin);
120 }
121
122 if let Some(user_agent) = self.user_agent.clone() {
123 endpoint = endpoint
124 .user_agent(user_agent)
125 .expect("already checked in the setter");
126 }
127
128 if let Some(timeout) = self.timeout {
129 endpoint = endpoint.timeout(timeout);
130 }
131
132 #[cfg(feature = "tls")]
133 if let Some(tls_config) = self.tls_config.clone() {
134 endpoint = endpoint
135 .tls_config(tls_config)
136 .expect("already checked in the setter");
137 }
138
139 if let Some(connect_timeout) = self.connect_timeout {
140 endpoint = endpoint.connect_timeout(connect_timeout);
141 }
142
143 endpoint = endpoint
144 .tcp_keepalive(self.tcp_keepalive)
145 .tcp_keepalive_interval(self.tcp_keepalive_interval)
146 .tcp_keepalive_retries(self.tcp_keepalive_retries);
147
148 if let Some(limit) = self.concurrency_limit {
149 endpoint = endpoint.concurrency_limit(limit);
150 }
151
152 if let Some((limit, duration)) = self.rate_limit {
153 endpoint = endpoint.rate_limit(limit, duration);
154 }
155
156 if let Some(sz) = self.init_stream_window_size {
157 endpoint = endpoint.initial_stream_window_size(sz);
158 }
159
160 if let Some(sz) = self.init_connection_window_size {
161 endpoint = endpoint.initial_connection_window_size(sz);
162 }
163
164 endpoint = endpoint.buffer_size(self.buffer_size);
165
166 if let Some(tcp_nodelay) = self.tcp_nodelay {
167 endpoint = endpoint.tcp_nodelay(tcp_nodelay);
168 }
169
170 if let Some(interval) = self.http2_keep_alive_interval {
171 endpoint = endpoint.http2_keep_alive_interval(interval);
172 }
173
174 if let Some(duration) = self.http2_keep_alive_timeout {
175 endpoint = endpoint.keep_alive_timeout(duration);
176 }
177
178 if let Some(enabled) = self.http2_keep_alive_while_idle {
179 endpoint = endpoint.keep_alive_while_idle(enabled);
180 }
181
182 if let Some(enabled) = self.http2_adaptive_window {
183 endpoint = endpoint.http2_adaptive_window(enabled);
184 }
185
186 if let Some(size) = self.http2_max_header_list_size {
187 endpoint = endpoint.http2_max_header_list_size(size);
188 }
189
190 endpoint = endpoint.local_address(self.local_address);
191
192 endpoint
193 }
194
195 #[allow(clippy::missing_panics_doc)]
197 pub fn domain(&self) -> &str {
198 self.url
199 .domain()
200 .expect("already checked in the constructor")
201 }
202
203 fn build_uri(&self, ip_addr: IpAddr) -> Uri {
204 let mut url = self.url.clone();
207 url.set_ip_host(ip_addr)
208 .expect("already checked in the constructor by trying cannot_be_a_base");
209 Uri::from_str(url.as_str()).expect("starting from Url, this should always be a valid Uri")
210 }
211
212 pub fn user_agent(
220 self,
221 user_agent: impl TryInto<HeaderValue>,
222 ) -> Result<Self, EndpointTemplateError> {
223 user_agent
224 .try_into()
225 .map(|ua| Self {
226 user_agent: Some(ua),
227 ..self
228 })
229 .map_err(|_| EndpointTemplateError::InvalidUserAgent)
230 }
231
232 #[must_use]
234 pub fn origin(self, origin: Uri) -> Self {
235 Self {
236 origin: Some(origin),
237 ..self
238 }
239 }
240
241 #[must_use]
243 pub fn timeout(self, dur: Duration) -> Self {
244 Self {
245 timeout: Some(dur),
246 ..self
247 }
248 }
249
250 #[must_use]
252 pub fn connect_timeout(self, dur: Duration) -> Self {
253 Self {
254 connect_timeout: Some(dur),
255 ..self
256 }
257 }
258
259 #[must_use]
261 pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
262 Self {
263 tcp_keepalive,
264 ..self
265 }
266 }
267
268 #[must_use]
270 pub fn tcp_keepalive_interval(self, interval: Duration) -> Self {
271 Self {
272 tcp_keepalive_interval: Some(interval),
273 ..self
274 }
275 }
276
277 #[must_use]
279 pub fn tcp_keepalive_retries(self, retries: u32) -> Self {
280 Self {
281 tcp_keepalive_retries: Some(retries),
282 ..self
283 }
284 }
285
286 #[must_use]
288 pub fn concurrency_limit(self, limit: usize) -> Self {
289 Self {
290 concurrency_limit: Some(limit),
291 ..self
292 }
293 }
294
295 #[must_use]
297 pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
298 Self {
299 rate_limit: Some((limit, duration)),
300 ..self
301 }
302 }
303
304 #[must_use]
306 pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
307 Self {
308 init_stream_window_size: sz.into(),
309 ..self
310 }
311 }
312
313 #[must_use]
315 pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
316 Self {
317 init_connection_window_size: sz.into(),
318 ..self
319 }
320 }
321
322 #[must_use]
324 pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
325 Self {
326 buffer_size: sz.into(),
327 ..self
328 }
329 }
330
331 #[cfg(feature = "tls")]
339 pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, EndpointTemplateError> {
340 let endpoint = self.build(std::net::Ipv4Addr::LOCALHOST);
342 let _ = endpoint
343 .tls_config(tls_config.clone())
344 .map_err(|_| EndpointTemplateError::InvalidTlsConfig)?;
345
346 Ok(Self {
347 tls_config: Some(tls_config),
348 ..self
349 })
350 }
351
352 #[must_use]
354 pub fn tcp_nodelay(self, enabled: bool) -> Self {
355 Self {
356 tcp_nodelay: Some(enabled),
357 ..self
358 }
359 }
360
361 #[must_use]
363 pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
364 Self {
365 http2_keep_alive_interval: Some(interval),
366 ..self
367 }
368 }
369
370 #[must_use]
372 pub fn keep_alive_timeout(self, duration: Duration) -> Self {
373 Self {
374 http2_keep_alive_timeout: Some(duration),
375 ..self
376 }
377 }
378
379 #[must_use]
381 pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
382 Self {
383 http2_keep_alive_while_idle: Some(enabled),
384 ..self
385 }
386 }
387
388 #[must_use]
390 pub fn http2_adaptive_window(self, enabled: bool) -> Self {
391 Self {
392 http2_adaptive_window: Some(enabled),
393 ..self
394 }
395 }
396
397 #[must_use]
399 pub fn http2_max_header_list_size(self, size: u32) -> Self {
400 Self {
401 http2_max_header_list_size: Some(size),
402 ..self
403 }
404 }
405
406 #[must_use]
408 pub fn local_address(self, ip: Option<IpAddr>) -> Self {
409 Self {
410 local_address: ip,
411 ..self
412 }
413 }
414}
415
416impl Debug for EndpointTemplate {
417 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418 f.debug_struct("EndpointTemplate")
419 .field("url", &self.url.as_str())
420 .finish_non_exhaustive()
421 }
422}
423
/// Reasons why [`EndpointTemplate::new`] or one of the fallible setters rejects its input.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum EndpointTemplateError {
    /// The value passed to the constructor could not be converted into a `Url`.
    NotAUrl,

    /// The URL has no host component.
    HostMissing,

    /// The URL's host is an IP address instead of a domain name.
    AlreadyIpAddress,

    /// The URL cannot be a base URL, or its text does not parse as a `Uri`.
    Inconvertible,

    /// The user-agent value could not be converted into a header value.
    InvalidUserAgent,

    /// tonic rejected the TLS configuration when it was applied to a probe endpoint.
    #[cfg(feature = "tls")]
    InvalidTlsConfig,
}
462
463impl TryFrom<Url> for EndpointTemplate {
464 type Error = EndpointTemplateError;
465
466 fn try_from(url: Url) -> Result<Self, Self::Error> {
467 Self::new(url)
468 }
469}
470
471#[cfg_attr(coverage_nightly, coverage(off))]
472impl Display for EndpointTemplateError {
473 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474 match self {
475 EndpointTemplateError::NotAUrl => write!(f, "not a valid URL"),
476 EndpointTemplateError::HostMissing => write!(f, "host missing"),
477 EndpointTemplateError::AlreadyIpAddress => write!(f, "already an IP address"),
478 EndpointTemplateError::Inconvertible => write!(f, "inconvertible URL"),
479 EndpointTemplateError::InvalidUserAgent => write!(f, "invalid user agent"),
480 #[cfg(feature = "tls")]
481 EndpointTemplateError::InvalidTlsConfig => write!(f, "invalid TLS config"),
482 }
483 }
484}
485
486impl Error for EndpointTemplateError {}
487
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::{net::IpAddr, str::FromStr};

    use http::Uri;
    use url::Url;

    use super::*;

    /// Template URL shared by most tests.
    const TEMPLATE_URL: &str = "http://example.com:50051/foo";

    /// Builds a template from a URL literal that is known to be valid.
    fn template(url: &str) -> EndpointTemplate {
        EndpointTemplate::new(Url::parse(url).unwrap()).unwrap()
    }

    #[test]
    fn can_substitute_domain_for_ipv4_address() {
        let endpoint = template(TEMPLATE_URL).build("203.0.113.6".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://203.0.113.6:50051/foo").unwrap()
        );
    }

    #[test]
    fn can_substitute_domain_for_ipv6_address() {
        let endpoint = template(TEMPLATE_URL).build("2001:db8::".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://[2001:db8::]:50051/foo").unwrap()
        );
    }

    #[test]
    fn substitution_keeps_default_port_and_path() {
        let endpoint =
            template("http://example.com/foo").build("203.0.113.6".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://203.0.113.6/foo").unwrap()
        );
    }

    #[rstest::rstest]
    #[case("http://127.0.0.1:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("http://[::1]:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("mailto:admin@example.com", EndpointTemplateError::HostMissing)]
    fn builder_error(#[case] input: &str, #[case] expected: EndpointTemplateError) {
        // `unwrap_err` already fails the test if construction unexpectedly succeeds.
        let error = EndpointTemplate::new(Url::parse(input).unwrap()).unwrap_err();
        assert_eq!(error, expected);
    }

    #[rstest::rstest]
    #[case("http://example.com:50051/foo", Ok("example.com"))]
    #[case("http://127.0.0.1:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("http://[::1]:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("mailto:admin@example.com", Err(EndpointTemplateError::HostMissing))]
    fn from_trait(#[case] url: &str, #[case] expected: Result<&str, EndpointTemplateError>) {
        let url = Url::parse(url).unwrap();
        let result = EndpointTemplate::try_from(url);
        let domain = result.as_ref().map(EndpointTemplate::domain);
        assert_eq!(domain, expected.as_deref());
    }

    #[test]
    fn invalid_user_agent_is_rejected() {
        // A newline is not allowed in an HTTP header value.
        let result = template(TEMPLATE_URL).user_agent("bad\nagent");
        assert_eq!(
            result.unwrap_err(),
            EndpointTemplateError::InvalidUserAgent
        );
    }

    #[test]
    fn setters() {
        let builder = template(TEMPLATE_URL);

        let origin = Uri::from_str("http://example.net:50001").unwrap();
        let builder = builder.origin(origin.clone());
        assert_eq!(builder.origin, Some(origin));

        let user_agent = HeaderValue::from_str("my-user-agent").unwrap();
        let builder = builder.user_agent(user_agent.clone()).unwrap();
        assert_eq!(builder.user_agent, Some(user_agent));

        let duration = Duration::from_secs(10);
        let builder = builder.timeout(duration);
        assert_eq!(builder.timeout, Some(duration));

        let connect_timeout = Duration::from_secs(5);
        let builder = builder.connect_timeout(connect_timeout);
        assert_eq!(builder.connect_timeout, Some(connect_timeout));

        let tcp_keepalive = Some(Duration::from_secs(30));
        let builder = builder.tcp_keepalive(tcp_keepalive);
        assert_eq!(builder.tcp_keepalive, tcp_keepalive);

        let tcp_keepalive_interval = Duration::from_secs(15);
        let builder = builder.tcp_keepalive_interval(tcp_keepalive_interval);
        assert_eq!(builder.tcp_keepalive_interval, Some(tcp_keepalive_interval));

        let tcp_keepalive_retries = 3;
        let builder = builder.tcp_keepalive_retries(tcp_keepalive_retries);
        assert_eq!(builder.tcp_keepalive_retries, Some(tcp_keepalive_retries));

        let concurrency_limit = 10;
        let builder = builder.concurrency_limit(concurrency_limit);
        assert_eq!(builder.concurrency_limit, Some(concurrency_limit));

        let rate_limit = (100, Duration::from_secs(1));
        let builder = builder.rate_limit(rate_limit.0, rate_limit.1);
        assert_eq!(builder.rate_limit, Some(rate_limit));

        let init_stream_window_size = Some(64);
        let builder = builder.initial_stream_window_size(init_stream_window_size);
        assert_eq!(builder.init_stream_window_size, init_stream_window_size);

        let init_connection_window_size = Some(128);
        let builder = builder.initial_connection_window_size(init_connection_window_size);
        assert_eq!(
            builder.init_connection_window_size,
            init_connection_window_size
        );

        let buffer_size = Some(1024);
        let builder = builder.buffer_size(buffer_size);
        assert_eq!(builder.buffer_size, buffer_size);

        let tcp_nodelay = true;
        let builder = builder.tcp_nodelay(tcp_nodelay);
        assert_eq!(builder.tcp_nodelay, Some(tcp_nodelay));

        let http2_keep_alive_interval = Duration::from_secs(30);
        let builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
        assert_eq!(
            builder.http2_keep_alive_interval,
            Some(http2_keep_alive_interval)
        );

        let keep_alive_timeout = Duration::from_secs(60);
        let builder = builder.keep_alive_timeout(keep_alive_timeout);
        assert_eq!(builder.http2_keep_alive_timeout, Some(keep_alive_timeout));

        let keep_alive_while_idle = true;
        let builder = builder.keep_alive_while_idle(keep_alive_while_idle);
        assert_eq!(
            builder.http2_keep_alive_while_idle,
            Some(keep_alive_while_idle)
        );

        let http2_adaptive_window = true;
        let builder = builder.http2_adaptive_window(http2_adaptive_window);
        assert_eq!(builder.http2_adaptive_window, Some(http2_adaptive_window));

        let http2_max_header_list_size = 8192;
        let builder = builder.http2_max_header_list_size(http2_max_header_list_size);
        assert_eq!(
            builder.http2_max_header_list_size,
            Some(http2_max_header_list_size)
        );

        let local_address = Some(IpAddr::from([127, 0, 0, 2]));
        let builder = builder.local_address(local_address);
        assert_eq!(builder.local_address, local_address);

        // Applying every option at once must not trip any of `build`'s `expect`s.
        let _ = builder.build([127, 0, 0, 1]);
    }

    #[test]
    fn debug_output() {
        let debug_output = format!("{:?}", template(TEMPLATE_URL));
        assert_eq!(
            debug_output,
            "EndpointTemplate { url: \"http://example.com:50051/foo\", .. }"
        );
    }
}