Skip to main content

trillium_tokio/
client.rs

1use crate::{TokioRuntime, TokioTransport};
2use async_compat::Compat;
3use std::{
4    io::{Error, ErrorKind, Result},
5    net::SocketAddr,
6    time::Duration,
7};
8use tokio::net::TcpStream;
9use trillium_server_common::{Connector, Destination, Transport, url::Url};
10
/// Configuration for the plaintext TCP [`Connector`].
///
/// Every option defaults to `None`, which leaves the corresponding socket option untouched
/// (the operating system's default applies); only options set to `Some(_)` are applied to
/// each newly opened connection.
#[derive(Default, Debug, Clone, Copy)]
pub struct ClientConfig {
    /// Sets `TCP_NODELAY`: `Some(true)` disables
    /// [Nagle's algorithm](https://en.wikipedia.org/wiki/Nagle%27s_algorithm).
    /// See [`TcpStream::set_nodelay`] for more info.
    pub nodelay: Option<bool>,

    /// IP time-to-live (`IP_TTL`) for packets sent on the connection.
    /// See [`TcpStream::set_ttl`] for more info.
    pub ttl: Option<u32>,

    /// Sets `SO_LINGER`, which controls what happens on close while unsent data remains.
    ///
    /// `Some(None)` disables lingering and `Some(Some(duration))` lingers for up to `duration`.
    /// See [`TcpStream::set_linger`] for more info and caveats.
    pub linger: Option<Option<Duration>>,
}
25
26impl ClientConfig {
27    /// constructs a default ClientConfig
28    pub const fn new() -> Self {
29        Self {
30            nodelay: None,
31            ttl: None,
32            linger: None,
33        }
34    }
35
36    /// chainable setter to set default nodelay
37    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
38        self.nodelay = Some(nodelay);
39        self
40    }
41
42    /// chainable setter for ip ttl
43    pub const fn with_ttl(mut self, ttl: u32) -> Self {
44        self.ttl = Some(ttl);
45        self
46    }
47
48    /// chainable setter for linger
49    pub const fn with_linger(mut self, linger: Option<Duration>) -> Self {
50        self.linger = Some(linger);
51        self
52    }
53}
54
55impl Connector for ClientConfig {
56    type Runtime = TokioRuntime;
57    type Transport = TokioTransport<Compat<TcpStream>>;
58    type Udp = crate::TokioUdpSocket;
59
60    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
61        self.connect_to(Destination::from_url(url)?).await
62    }
63
64    async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
65        if destination.secure() {
66            return Err(Error::new(
67                ErrorKind::InvalidInput,
68                "this connector does not support TLS",
69            ));
70        }
71
72        let addrs = destination.addrs();
73        let mut tcp = if addrs.is_empty() {
74            let host = destination.host().ok_or_else(|| {
75                Error::new(
76                    ErrorKind::InvalidInput,
77                    "destination has neither host nor addresses",
78                )
79            })?;
80            Self::Transport::connect((host, destination.port())).await?
81        } else {
82            Self::Transport::connect(addrs).await?
83        };
84
85        if let Some(nodelay) = self.nodelay {
86            tcp.set_nodelay(nodelay)?;
87        }
88
89        if let Some(ttl) = self.ttl {
90            tcp.set_ip_ttl(ttl)?;
91        }
92
93        if let Some(dur) = self.linger {
94            tcp.set_linger(dur)?;
95        }
96
97        Ok(tcp)
98    }
99
100    fn runtime(&self) -> Self::Runtime {
101        TokioRuntime::default()
102    }
103
104    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
105        tokio::net::lookup_host((host, port))
106            .await
107            .map(Iterator::collect)
108    }
109}
110
/// A [`Connector`] that dials a fixed Unix domain socket path.
///
/// Every connection opens a fresh `UnixStream` to `path`, so a single `UnixClientConfig` is safe to
/// share across a pooled [`Client`](https://docs.trillium.rs/trillium_client/struct.Client.html) making
/// concurrent requests. The request URL's host and port are ignored — the socket path is the
/// address — though the URL still supplies request metadata such as the `Host` header.
///
/// This handles only the single-socket case. To route different requests to different upstreams (a
/// mix of TCP and Unix sockets, or several Unix sockets), compose connectors behind a routing
/// [`Connector`] that dispatches on the [`Destination`].
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct UnixClientConfig {
    /// Filesystem path of the Unix domain socket that every connection is opened against.
    path: std::path::PathBuf,
}
126
127#[cfg(unix)]
128impl UnixClientConfig {
129    /// Construct a `UnixClientConfig` that dials the provided Unix domain socket path.
130    pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
131        Self { path: path.into() }
132    }
133
134    async fn dial(&self) -> Result<TokioTransport<Compat<tokio::net::UnixStream>>> {
135        tokio::net::UnixStream::connect(&self.path)
136            .await
137            .map(|stream| Compat::new(stream).into())
138    }
139}
140
141#[cfg(unix)]
142impl Connector for UnixClientConfig {
143    type Runtime = TokioRuntime;
144    type Transport = TokioTransport<Compat<tokio::net::UnixStream>>;
145    type Udp = ();
146
147    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
148        if url.scheme() == "https" {
149            return Err(Error::new(
150                ErrorKind::InvalidInput,
151                "this connector does not support TLS",
152            ));
153        }
154        self.dial().await
155    }
156
157    async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
158        if destination.secure() {
159            return Err(Error::new(
160                ErrorKind::InvalidInput,
161                "this connector does not support TLS",
162            ));
163        }
164        self.dial().await
165    }
166
167    fn runtime(&self) -> Self::Runtime {
168        TokioRuntime::default()
169    }
170
171    async fn resolve(&self, _host: &str, _port: u16) -> Result<Vec<SocketAddr>> {
172        Ok(vec![])
173    }
174}