1use core::fmt;
2use http::HeaderValue;
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::{net::IpAddr, str::FromStr, time::Duration};
6#[cfg(feature = "tls")]
7use tonic::transport::ClientTlsConfig;
8use tonic::transport::{Endpoint, Uri};
9use url::Host;
10use url::Url;
11
12#[derive(Clone)]
17pub struct EndpointTemplate {
18 url: Url,
19 origin: Option<Uri>,
20 user_agent: Option<HeaderValue>,
21 timeout: Option<Duration>,
22 concurrency_limit: Option<usize>,
23 rate_limit: Option<(u64, Duration)>,
24 #[cfg(feature = "tls")]
25 tls_config: Option<ClientTlsConfig>,
26 buffer_size: Option<usize>,
27 init_stream_window_size: Option<u32>,
28 init_connection_window_size: Option<u32>,
29 tcp_keepalive: Option<Duration>,
30 tcp_nodelay: Option<bool>,
31 http2_keep_alive_interval: Option<Duration>,
32 http2_keep_alive_timeout: Option<Duration>,
33 http2_keep_alive_while_idle: Option<bool>,
34 http2_max_header_list_size: Option<u32>,
35 connect_timeout: Option<Duration>,
36 http2_adaptive_window: Option<bool>,
37 local_address: Option<IpAddr>,
38 }
40
41impl EndpointTemplate {
42 pub fn new(url: impl TryInto<Url>) -> Result<Self, EndpointTemplateError> {
54 let url: Url = url.try_into().map_err(|_| EndpointTemplateError::NotAUrl)?;
55
56 match url.host() {
58 Some(host) => match host {
59 Host::Domain(_) => {}
60 _ => return Err(EndpointTemplateError::AlreadyIpAddress),
61 },
62 None => return Err(EndpointTemplateError::HostMissing),
63 }
64
65 if url.cannot_be_a_base() {
67 return Err(EndpointTemplateError::Inconvertible);
71 }
72
73 if Uri::from_str(url.as_str()).is_err() {
75 return Err(EndpointTemplateError::Inconvertible);
79 }
80
81 Ok(Self {
82 url,
83 origin: None,
84 user_agent: None,
85 timeout: None,
86 #[cfg(feature = "tls")]
87 tls_config: None,
88 concurrency_limit: None,
89 rate_limit: None,
90 buffer_size: None,
91 init_stream_window_size: None,
92 init_connection_window_size: None,
93 tcp_keepalive: None,
94 tcp_nodelay: None,
95 http2_keep_alive_interval: None,
96 http2_keep_alive_timeout: None,
97 http2_keep_alive_while_idle: None,
98 http2_max_header_list_size: None,
99 connect_timeout: None,
100 http2_adaptive_window: None,
101 local_address: None,
102 })
103 }
104
105 #[allow(clippy::missing_panics_doc)]
111 pub fn build(&self, ip_address: impl Into<IpAddr>) -> Endpoint {
112 let mut endpoint = Endpoint::from(self.build_uri(ip_address.into()));
113
114 if let Some(origin) = self.origin.clone() {
115 endpoint = endpoint.origin(origin);
116 }
117
118 if let Some(user_agent) = self.user_agent.clone() {
119 endpoint = endpoint
120 .user_agent(user_agent)
121 .expect("already checked in the setter");
122 }
123
124 if let Some(timeout) = self.timeout {
125 endpoint = endpoint.timeout(timeout);
126 }
127
128 #[cfg(feature = "tls")]
129 if let Some(tls_config) = self.tls_config.clone() {
130 endpoint = endpoint
131 .tls_config(tls_config)
132 .expect("already checked in the setter");
133 }
134
135 if let Some(connect_timeout) = self.connect_timeout {
136 endpoint = endpoint.connect_timeout(connect_timeout);
137 }
138
139 endpoint = endpoint.tcp_keepalive(self.tcp_keepalive);
140
141 if let Some(limit) = self.concurrency_limit {
142 endpoint = endpoint.concurrency_limit(limit);
143 }
144
145 if let Some((limit, duration)) = self.rate_limit {
146 endpoint = endpoint.rate_limit(limit, duration);
147 }
148
149 if let Some(sz) = self.init_stream_window_size {
150 endpoint = endpoint.initial_stream_window_size(sz);
151 }
152
153 if let Some(sz) = self.init_connection_window_size {
154 endpoint = endpoint.initial_connection_window_size(sz);
155 }
156
157 endpoint = endpoint.buffer_size(self.buffer_size);
158
159 if let Some(tcp_nodelay) = self.tcp_nodelay {
160 endpoint = endpoint.tcp_nodelay(tcp_nodelay);
161 }
162
163 if let Some(interval) = self.http2_keep_alive_interval {
164 endpoint = endpoint.http2_keep_alive_interval(interval);
165 }
166
167 if let Some(duration) = self.http2_keep_alive_timeout {
168 endpoint = endpoint.keep_alive_timeout(duration);
169 }
170
171 if let Some(enabled) = self.http2_keep_alive_while_idle {
172 endpoint = endpoint.keep_alive_while_idle(enabled);
173 }
174
175 if let Some(enabled) = self.http2_adaptive_window {
176 endpoint = endpoint.http2_adaptive_window(enabled);
177 }
178
179 if let Some(size) = self.http2_max_header_list_size {
180 endpoint = endpoint.http2_max_header_list_size(size);
181 }
182
183 endpoint = endpoint.local_address(self.local_address);
184
185 endpoint
186 }
187
188 #[allow(clippy::missing_panics_doc)]
190 pub fn domain(&self) -> &str {
191 self.url
192 .domain()
193 .expect("already checked in the constructor")
194 }
195
196 fn build_uri(&self, ip_addr: IpAddr) -> Uri {
197 let mut url = self.url.clone();
200 url.set_ip_host(ip_addr)
201 .expect("already checked in the constructor by trying cannot_be_a_base");
202 Uri::from_str(url.as_str()).expect("starting from Url, this should always be a valid Uri")
203 }
204
205 pub fn user_agent(
213 self,
214 user_agent: impl TryInto<HeaderValue>,
215 ) -> Result<Self, EndpointTemplateError> {
216 user_agent
217 .try_into()
218 .map(|ua| Self {
219 user_agent: Some(ua),
220 ..self
221 })
222 .map_err(|_| EndpointTemplateError::InvalidUserAgent)
223 }
224
225 #[must_use]
227 pub fn origin(self, origin: Uri) -> Self {
228 Self {
229 origin: Some(origin),
230 ..self
231 }
232 }
233
234 #[must_use]
236 pub fn timeout(self, dur: Duration) -> Self {
237 Self {
238 timeout: Some(dur),
239 ..self
240 }
241 }
242
243 #[must_use]
245 pub fn connect_timeout(self, dur: Duration) -> Self {
246 Self {
247 connect_timeout: Some(dur),
248 ..self
249 }
250 }
251
252 #[must_use]
254 pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
255 Self {
256 tcp_keepalive,
257 ..self
258 }
259 }
260
261 #[must_use]
263 pub fn concurrency_limit(self, limit: usize) -> Self {
264 Self {
265 concurrency_limit: Some(limit),
266 ..self
267 }
268 }
269
270 #[must_use]
272 pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
273 Self {
274 rate_limit: Some((limit, duration)),
275 ..self
276 }
277 }
278
279 #[must_use]
281 pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
282 Self {
283 init_stream_window_size: sz.into(),
284 ..self
285 }
286 }
287
288 #[must_use]
290 pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
291 Self {
292 init_connection_window_size: sz.into(),
293 ..self
294 }
295 }
296
297 #[must_use]
299 pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
300 Self {
301 buffer_size: sz.into(),
302 ..self
303 }
304 }
305
306 #[cfg(feature = "tls")]
314 pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, EndpointTemplateError> {
315 let endpoint = self.build(std::net::Ipv4Addr::new(127, 0, 0, 1));
317 let _ = endpoint
318 .tls_config(tls_config.clone())
319 .map_err(|_| EndpointTemplateError::InvalidTlsConfig)?;
320
321 Ok(Self {
322 tls_config: Some(tls_config),
323 ..self
324 })
325 }
326
327 #[must_use]
329 pub fn tcp_nodelay(self, enabled: bool) -> Self {
330 Self {
331 tcp_nodelay: Some(enabled),
332 ..self
333 }
334 }
335
336 #[must_use]
338 pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
339 Self {
340 http2_keep_alive_interval: Some(interval),
341 ..self
342 }
343 }
344
345 #[must_use]
347 pub fn keep_alive_timeout(self, duration: Duration) -> Self {
348 Self {
349 http2_keep_alive_timeout: Some(duration),
350 ..self
351 }
352 }
353
354 #[must_use]
356 pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
357 Self {
358 http2_keep_alive_while_idle: Some(enabled),
359 ..self
360 }
361 }
362
363 #[must_use]
365 pub fn http2_adaptive_window(self, enabled: bool) -> Self {
366 Self {
367 http2_adaptive_window: Some(enabled),
368 ..self
369 }
370 }
371
372 #[must_use]
374 pub fn http2_max_header_list_size(self, size: u32) -> Self {
375 Self {
376 http2_max_header_list_size: Some(size),
377 ..self
378 }
379 }
380
381 #[must_use]
383 pub fn local_address(self, ip: Option<IpAddr>) -> Self {
384 Self {
385 local_address: ip,
386 ..self
387 }
388 }
389}
390
391impl Debug for EndpointTemplate {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 f.debug_struct("EndpointTemplate")
394 .field("url", &self.url.as_str())
395 .finish_non_exhaustive()
396 }
397}
398
/// Errors produced while creating or configuring an [`EndpointTemplate`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so variant order matters.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum EndpointTemplateError {
    /// The input could not be converted into a URL.
    NotAUrl,

    /// The URL has no host.
    HostMissing,

    /// The URL host is already an IP address instead of a domain name.
    AlreadyIpAddress,

    /// The URL cannot be a base URL, or cannot be parsed as a URI.
    Inconvertible,

    /// The user agent could not be converted into a header value.
    InvalidUserAgent,

    /// The TLS configuration was rejected when applied to an endpoint.
    #[cfg(feature = "tls")]
    InvalidTlsConfig,
}
437
438impl TryFrom<Url> for EndpointTemplate {
439 type Error = EndpointTemplateError;
440
441 fn try_from(url: Url) -> Result<Self, Self::Error> {
442 Self::new(url)
443 }
444}
445
446#[cfg_attr(coverage_nightly, coverage(off))]
447impl Display for EndpointTemplateError {
448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 match self {
450 EndpointTemplateError::NotAUrl => write!(f, "not a valid URL"),
451 EndpointTemplateError::HostMissing => write!(f, "host missing"),
452 EndpointTemplateError::AlreadyIpAddress => write!(f, "already an IP address"),
453 EndpointTemplateError::Inconvertible => write!(f, "inconvertible URL"),
454 EndpointTemplateError::InvalidUserAgent => write!(f, "invalid user agent"),
455 #[cfg(feature = "tls")]
456 EndpointTemplateError::InvalidTlsConfig => write!(f, "invalid TLS config"),
457 }
458 }
459}
460
461impl Error for EndpointTemplateError {}
462
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Tests for template construction, IP substitution, setters and `Debug` output.

    use std::{net::IpAddr, str::FromStr};

    use http::Uri;
    use url::Url;

    use super::*;

    #[test]
    fn can_substitute_domain_for_ipv4_address() {
        let builder =
            EndpointTemplate::new(Url::parse("http://example.com:50051/foo").unwrap()).unwrap();

        let endpoint = builder.build("203.0.113.6".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://203.0.113.6:50051/foo").unwrap()
        );
    }

    #[test]
    fn can_substitute_domain_for_ipv6_address() {
        let builder =
            EndpointTemplate::new(Url::parse("http://example.com:50051/foo").unwrap()).unwrap();

        let endpoint = builder.build("2001:db8::".parse::<IpAddr>().unwrap());
        assert_eq!(
            *endpoint.uri(),
            Uri::from_str("http://[2001:db8::]:50051/foo").unwrap()
        );
    }

    #[test]
    fn accepts_str_input() {
        let builder = EndpointTemplate::new("http://example.com:50051/foo").unwrap();
        assert_eq!(builder.domain(), "example.com");
    }

    #[test]
    fn rejects_unparsable_input() {
        let result = EndpointTemplate::new("not a url");
        assert_eq!(result.unwrap_err(), EndpointTemplateError::NotAUrl);
    }

    #[rstest::rstest]
    #[case("http://127.0.0.1:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("http://[::1]:50051", EndpointTemplateError::AlreadyIpAddress)]
    #[case("mailto:admin@example.com", EndpointTemplateError::HostMissing)]
    fn builder_error(#[case] input: &str, #[case] expected: EndpointTemplateError) {
        let result = EndpointTemplate::new(Url::parse(input).unwrap());
        // `unwrap_err` already fails the test if the constructor unexpectedly succeeds.
        assert_eq!(result.unwrap_err(), expected);
    }

    #[rstest::rstest]
    #[case("http://example.com:50051/foo", Ok("example.com"))]
    #[case("http://127.0.0.1:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("http://[::1]:50051", Err(EndpointTemplateError::AlreadyIpAddress))]
    #[case("mailto:admin@example.com", Err(EndpointTemplateError::HostMissing))]
    fn from_trait(#[case] url: &str, #[case] expected: Result<&str, EndpointTemplateError>) {
        let url = Url::parse(url).unwrap();
        let result = EndpointTemplate::try_from(url);
        let domain = result.as_ref().map(EndpointTemplate::domain);
        assert_eq!(domain, expected.as_deref());
    }

    #[test]
    fn setters() {
        let url = Url::parse("http://example.com:50051/foo").unwrap();
        let builder = EndpointTemplate::new(url).unwrap();

        let origin = Uri::from_str("http://example.net:50001").unwrap();
        let builder = builder.origin(origin.clone());
        assert_eq!(builder.origin, Some(origin));

        let user_agent = HeaderValue::from_str("my-user-agent").unwrap();
        let builder = builder.user_agent(user_agent.clone()).unwrap();
        assert_eq!(builder.user_agent, Some(user_agent));

        let duration = Duration::from_secs(10);
        let builder = builder.timeout(duration);
        assert_eq!(builder.timeout, Some(duration));

        let connect_timeout = Duration::from_secs(5);
        let builder = builder.connect_timeout(connect_timeout);
        assert_eq!(builder.connect_timeout, Some(connect_timeout));

        let tcp_keepalive = Some(Duration::from_secs(30));
        let builder = builder.tcp_keepalive(tcp_keepalive);
        assert_eq!(builder.tcp_keepalive, tcp_keepalive);

        let concurrency_limit = 10;
        let builder = builder.concurrency_limit(concurrency_limit);
        assert_eq!(builder.concurrency_limit, Some(concurrency_limit));

        let rate_limit = (100, Duration::from_secs(1));
        let builder = builder.rate_limit(rate_limit.0, rate_limit.1);
        assert_eq!(builder.rate_limit, Some(rate_limit));

        let init_stream_window_size = Some(64);
        let builder = builder.initial_stream_window_size(init_stream_window_size);
        assert_eq!(builder.init_stream_window_size, init_stream_window_size);

        let init_connection_window_size = Some(128);
        let builder = builder.initial_connection_window_size(init_connection_window_size);
        assert_eq!(
            builder.init_connection_window_size,
            init_connection_window_size
        );

        let buffer_size = Some(1024);
        let builder = builder.buffer_size(buffer_size);
        assert_eq!(builder.buffer_size, buffer_size);

        let tcp_nodelay = true;
        let builder = builder.tcp_nodelay(tcp_nodelay);
        assert_eq!(builder.tcp_nodelay, Some(tcp_nodelay));

        let http2_keep_alive_interval = Duration::from_secs(30);
        let builder = builder.http2_keep_alive_interval(http2_keep_alive_interval);
        assert_eq!(
            builder.http2_keep_alive_interval,
            Some(http2_keep_alive_interval)
        );

        let keep_alive_timeout = Duration::from_secs(60);
        let builder = builder.keep_alive_timeout(keep_alive_timeout);
        assert_eq!(builder.http2_keep_alive_timeout, Some(keep_alive_timeout));

        let keep_alive_while_idle = true;
        let builder = builder.keep_alive_while_idle(keep_alive_while_idle);
        assert_eq!(
            builder.http2_keep_alive_while_idle,
            Some(keep_alive_while_idle)
        );

        let http2_adaptive_window = true;
        let builder = builder.http2_adaptive_window(http2_adaptive_window);
        assert_eq!(builder.http2_adaptive_window, Some(http2_adaptive_window));

        let http2_max_header_list_size = 8192;
        let builder = builder.http2_max_header_list_size(http2_max_header_list_size);
        assert_eq!(
            builder.http2_max_header_list_size,
            Some(http2_max_header_list_size)
        );

        let local_address = Some(IpAddr::from([127, 0, 0, 2]));
        let builder = builder.local_address(local_address);
        assert_eq!(builder.local_address, local_address);

        // Building with every option set must not trip any of the `expect`s.
        let _ = builder.build([127, 0, 0, 1]);
    }

    #[test]
    fn debug_output() {
        let url = Url::parse("http://example.com:50051/foo").unwrap();
        let builder = EndpointTemplate::new(url).unwrap();

        let debug_output = format!("{builder:?}");
        assert_eq!(
            debug_output,
            "EndpointTemplate { url: \"http://example.com:50051/foo\", .. }"
        );
    }
}