salvo_core/conn/
tcp.rs

//! `TcpListener` and its related implementations.
2use std::fmt::{self, Debug, Formatter};
3use std::io::{Error as IoError, Result as IoResult};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::vec;
7
8use futures_util::future::{BoxFuture, FutureExt};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::net::{TcpListener as TokioTcpListener, TcpStream, ToSocketAddrs};
11use tokio_util::sync::CancellationToken;
12
13use crate::conn::{Holding, HttpBuilder, StraightStream};
14use crate::fuse::{ArcFuseFactory, FuseEvent, FuseInfo, TransProto};
15use crate::http::Version;
16use crate::http::uri::Scheme;
17use crate::service::HyperHandler;
18
19use super::{Accepted, Acceptor, Coupler, DynStream, Listener};
20
21#[cfg(any(feature = "rustls", feature = "native-tls", feature = "openssl"))]
22use crate::conn::IntoConfigStream;
23
24#[cfg(feature = "rustls")]
25use crate::conn::rustls::RustlsListener;
26
27#[cfg(feature = "native-tls")]
28use crate::conn::native_tls::NativeTlsListener;
29
30#[cfg(feature = "openssl")]
31use crate::conn::openssl::OpensslListener;
32
33#[cfg(feature = "acme")]
34use crate::conn::acme::AcmeListener;
35
/// `TcpListener` is used to create a TCP connection listener.
///
/// It only records the address and socket options; the socket itself is created
/// when the listener is bound through the [`Listener`] trait.
pub struct TcpListener<T> {
    /// Address to bind to (the builder and `Listener` impls require `T: ToSocketAddrs`).
    local_addr: T,
    /// Optional `IP_TTL` value applied right after binding.
    ttl: Option<u32>,
    /// Optional listen backlog applied right after binding.
    #[cfg(feature = "socket2")]
    backlog: Option<u32>,
}
impl<T: Debug> Debug for TcpListener<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("TcpListener");
        s.field("local_addr", &self.local_addr);
        s.field("ttl", &self.ttl);
        // Show the backlog too when the field exists, so every configured option is visible.
        #[cfg(feature = "socket2")]
        s.field("backlog", &self.backlog);
        s.finish()
    }
}
51impl<T: ToSocketAddrs + Send + 'static> TcpListener<T> {
52    /// Bind to socket address.
53    #[cfg(not(feature = "socket2"))]
54    #[inline]
55    pub fn new(local_addr: T) -> Self {
56        #[cfg(not(feature = "socket2"))]
57        Self {
58            local_addr,
59            ttl: None,
60        }
61    }
62    /// Bind to socket address.
63    #[cfg(feature = "socket2")]
64    #[inline]
65    pub fn new(local_addr: T) -> Self {
66        Self {
67            local_addr,
68            ttl: None,
69            backlog: None,
70        }
71    }
72
73    cfg_feature! {
74        #![feature = "rustls"]
75
76        /// Creates a new `RustlsListener` from current `TcpListener`.
77        #[inline]
78        pub fn rustls<S, C, E>(self, config_stream: S) -> RustlsListener<S, C, Self, E>
79        where
80            S: IntoConfigStream<C> + Send + 'static,
81            C: TryInto<crate::conn::rustls::ServerConfig, Error = E> + Send + 'static,
82            E: std::error::Error + Send + 'static
83        {
84            RustlsListener::new(config_stream, self)
85        }
86    }
87
88    cfg_feature! {
89        #![feature = "native-tls"]
90
91        /// Creates a new `NativeTlsListener` from current `TcpListener`.
92        #[inline]
93        pub fn native_tls<S, C, E>(self, config_stream: S) -> NativeTlsListener<S, C, Self, E>
94        where
95            S: IntoConfigStream<C> + Send + 'static,
96            C: TryInto<crate::conn::native_tls::Identity, Error = E> + Send + 'static,
97            E: std::error::Error + Send + 'static
98        {
99            NativeTlsListener::new(config_stream, self)
100        }
101    }
102
103    cfg_feature! {
104        #![feature = "openssl"]
105
106        /// Creates a new `OpensslListener` from current `TcpListener`.
107        #[inline]
108        pub fn openssl<S, C, E>(self, config_stream: S) -> OpensslListener<S, C, Self, E>
109        where
110            S: IntoConfigStream<C> + Send + 'static,
111            C: TryInto<crate::conn::openssl::SslAcceptorBuilder, Error = E> + Send + 'static,
112            E: std::error::Error + Send + 'static
113        {
114            OpensslListener::new(config_stream, self)
115        }
116    }
117    cfg_feature! {
118        #![feature = "acme"]
119
120        /// Creates a new `AcmeListener` from current `TcpListener`.
121        #[inline]
122        pub fn acme(self) -> AcmeListener<Self>
123        {
124            AcmeListener::new(self)
125        }
126    }
127
128    /// Sets the value for the `IP_TTL` option on this socket.
129    ///
130    /// This value sets the time-to-live field that is used in every packet sent
131    /// from this socket.
132    #[must_use]
133    pub fn ttl(mut self, ttl: u32) -> Self {
134        self.ttl = Some(ttl);
135        self
136    }
137
138    cfg_feature! {
139        #![feature = "socket2"]
140        /// Set backlog capacity.
141        #[inline]
142    #[must_use]
143        pub fn backlog(mut self, backlog: u32) -> Self {
144            self.backlog = Some(backlog);
145            self
146        }
147    }
148}
149impl<T> Listener for TcpListener<T>
150where
151    T: ToSocketAddrs + Send + 'static,
152{
153    type Acceptor = TcpAcceptor;
154
155    async fn try_bind(self) -> crate::Result<Self::Acceptor> {
156        let inner = TokioTcpListener::bind(self.local_addr).await?;
157
158        #[cfg(feature = "socket2")]
159        if let Some(backlog) = self.backlog {
160            let socket = socket2::SockRef::from(&inner);
161            socket.listen(backlog as _)?;
162        }
163        if let Some(ttl) = self.ttl {
164            inner.set_ttl(ttl)?;
165        }
166
167        Ok(inner.try_into()?)
168    }
169}
170/// `TcpAcceptor` is used to accept a TCP connection.
171#[derive(Debug)]
172pub struct TcpAcceptor {
173    inner: TokioTcpListener,
174    holdings: Vec<Holding>,
175}
176
177impl TcpAcceptor {
178    /// Get the inner `TokioTcpListener`.
179    pub fn inner(&self) -> &TokioTcpListener {
180        &self.inner
181    }
182
183    /// Get the local address that this listener is bound to.
184    ///
185    /// This can be useful, for example, when binding to port 0 to figure out
186    /// which port was actually bound.
187    pub fn local_addr(&self) -> IoResult<SocketAddr> {
188        self.inner.local_addr()
189    }
190
191    /// Gets the value of the `IP_TTL` option for this socket.
192    pub fn ttl(&self) -> IoResult<u32> {
193        self.inner.ttl()
194    }
195
196    /// Sets the value for the `IP_TTL` option on this socket.
197    ///
198    /// This value sets the time-to-live field that is used in every packet sent
199    /// from this socket.
200    pub fn set_ttl(&self, ttl: u32) -> IoResult<()> {
201        self.inner.set_ttl(ttl)
202    }
203
204    /// Convert this `TcpAcceptor` into a boxed `DynTcpAcceptor`.
205    pub fn into_boxed(self) -> Box<dyn DynTcpAcceptor> {
206        Box::new(ToDynTcpAcceptor(self))
207    }
208}
209
210impl TryFrom<TokioTcpListener> for TcpAcceptor {
211    type Error = IoError;
212    fn try_from(inner: TokioTcpListener) -> Result<Self, Self::Error> {
213        let holdings = vec![Holding {
214            local_addr: inner.local_addr()?.into(),
215            #[cfg(not(feature = "http2-cleartext"))]
216            http_versions: vec![Version::HTTP_11],
217            #[cfg(feature = "http2-cleartext")]
218            http_versions: vec![Version::HTTP_11, Version::HTTP_2],
219            http_scheme: Scheme::HTTP,
220        }];
221
222        Ok(Self { inner, holdings })
223    }
224}
225
226impl Acceptor for TcpAcceptor {
227    type Coupler = TcpCoupler<Self::Stream>;
228    type Stream = StraightStream<TcpStream>;
229
230    #[inline]
231    fn holdings(&self) -> &[Holding] {
232        &self.holdings
233    }
234
235    #[inline]
236    async fn accept(
237        &mut self,
238        fuse_factory: Option<ArcFuseFactory>,
239    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
240        self.inner.accept().await.map(move |(conn, remote_addr)| {
241            let local_addr = self.holdings[0].local_addr.clone();
242            let fusewire = fuse_factory.map(|f| {
243                f.create(FuseInfo {
244                    trans_proto: TransProto::Tcp,
245                    remote_addr: remote_addr.into(),
246                    local_addr: local_addr.clone(),
247                })
248            });
249            Accepted {
250                coupler: TcpCoupler::new(),
251                stream: StraightStream::new(conn, fusewire.clone()),
252                fusewire,
253                remote_addr: remote_addr.into(),
254                local_addr,
255                http_scheme: Scheme::HTTP,
256            }
257        })
258    }
259}
260
261#[doc(hidden)]
262pub struct TcpCoupler<S>
263where
264    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
265{
266    _marker: std::marker::PhantomData<S>,
267}
268impl<S> TcpCoupler<S>
269where
270    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
271{
272    /// Create a new `TcpCoupler`.
273    #[must_use]
274    pub fn new() -> Self {
275        Self {
276            _marker: std::marker::PhantomData,
277        }
278    }
279}
280impl<S> Default for TcpCoupler<S>
281where
282    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
283{
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl<S> Coupler for TcpCoupler<S>
290where
291    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
292{
293    type Stream = S;
294
295    fn couple(
296        &self,
297        stream: Self::Stream,
298        handler: HyperHandler,
299        builder: Arc<HttpBuilder>,
300        graceful_stop_token: Option<CancellationToken>,
301    ) -> BoxFuture<'static, IoResult<()>> {
302        let fusewire = handler.fusewire.clone();
303        if let Some(fusewire) = &fusewire {
304            fusewire.event(FuseEvent::Alive);
305        }
306        async move {
307            builder
308                .serve_connection(stream, handler, fusewire, graceful_stop_token)
309                .await
310                .map_err(IoError::other)
311        }
312        .boxed()
313    }
314}
315impl<S> Debug for TcpCoupler<S>
316where
317    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
318{
319    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
320        f.debug_struct("TcpCoupler").finish()
321    }
322}
323
324/// Dynamic TCP acceptor trait.
325pub trait DynTcpAcceptor: Send {
326    /// Returns the holdings of the acceptor.
327    fn holdings(&self) -> &[Holding];
328
329    /// Accept a new connection.
330    fn accept(
331        &mut self,
332        fuse_factory: Option<ArcFuseFactory>,
333    ) -> BoxFuture<'_, IoResult<Accepted<TcpCoupler<DynStream>, DynStream>>>;
334}
335impl Acceptor for dyn DynTcpAcceptor {
336    type Coupler = TcpCoupler<DynStream>;
337    type Stream = DynStream;
338
339    #[inline]
340    fn holdings(&self) -> &[Holding] {
341        DynTcpAcceptor::holdings(self)
342    }
343
344    #[inline]
345    async fn accept(
346        &mut self,
347        fuse_factory: Option<ArcFuseFactory>,
348    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
349        DynTcpAcceptor::accept(self, fuse_factory).await
350    }
351}
352
353/// Convert an `Acceptor` into a boxed `DynTcpAcceptor`.
354pub struct ToDynTcpAcceptor<A>(pub A);
355impl<A> DynTcpAcceptor for ToDynTcpAcceptor<A>
356where
357    A: Acceptor + 'static,
358    A::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
359{
360    #[inline]
361    fn holdings(&self) -> &[Holding] {
362        self.0.holdings()
363    }
364
365    #[inline]
366    fn accept(
367        &mut self,
368        fuse_factory: Option<ArcFuseFactory>,
369    ) -> BoxFuture<'_, IoResult<Accepted<TcpCoupler<DynStream>, DynStream>>> {
370        async move {
371            let accepted = self.0.accept(fuse_factory).await?;
372            Ok(accepted.map_into(|_| TcpCoupler::new(), DynStream::new))
373        }
374        .boxed()
375    }
376}
377impl<A: Debug> Debug for ToDynTcpAcceptor<A> {
378    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
379        f.debug_struct("ToDynTcpAcceptor")
380            .field("inner", &self.0)
381            .finish()
382    }
383}
384
385/// Dynamic TCP acceptors.
386pub struct DynTcpAcceptors {
387    inners: Vec<Box<dyn DynTcpAcceptor>>,
388    holdings: Vec<Holding>,
389}
390impl DynTcpAcceptors {
391    /// Create a new `DynTcpAcceptors`.
392    #[must_use]
393    pub fn new(inners: Vec<Box<dyn DynTcpAcceptor>>) -> Self {
394        let holdings = inners
395            .iter()
396            .flat_map(|inner| inner.holdings())
397            .cloned()
398            .collect();
399        Self { inners, holdings }
400    }
401}
402impl DynTcpAcceptor for DynTcpAcceptors {
403    #[inline]
404    fn holdings(&self) -> &[Holding] {
405        &self.holdings
406    }
407
408    #[inline]
409    fn accept(
410        &mut self,
411        fuse_factory: Option<ArcFuseFactory>,
412    ) -> BoxFuture<'_, IoResult<Accepted<TcpCoupler<DynStream>, DynStream>>> {
413        async move {
414            let mut set = Vec::new();
415            for inner in &mut self.inners {
416                let fuse_factory = fuse_factory.clone();
417                set.push(async move { inner.accept(fuse_factory).await }.boxed());
418            }
419            futures_util::future::select_all(set.into_iter()).await.0
420        }
421        .boxed()
422    }
423}
424impl Acceptor for DynTcpAcceptors {
425    type Coupler = TcpCoupler<DynStream>;
426    type Stream = DynStream;
427
428    #[inline]
429    fn holdings(&self) -> &[Holding] {
430        DynTcpAcceptor::holdings(self)
431    }
432
433    #[inline]
434    async fn accept(
435        &mut self,
436        fuse_factory: Option<ArcFuseFactory>,
437    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
438        DynTcpAcceptor::accept(self, fuse_factory).await
439    }
440}
441impl Debug for DynTcpAcceptors {
442    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
443        f.debug_struct("DynTcpAcceptors")
444            .field("holdings", &self.holdings)
445            .finish()
446    }
447}
448
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    #[tokio::test]
    async fn test_tcp_listener() {
        // Port 0 lets the OS choose a free port, so the test cannot collide with
        // another process or a parallel test already using a fixed port.
        let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 0));
        let mut acceptor = TcpListener::new(addr).bind().await;
        // The holding records the address actually bound, including the chosen port.
        let addr = acceptor.holdings()[0]
            .local_addr
            .clone()
            .into_std()
            .unwrap();
        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_i32(150).await.unwrap();
        });

        let Accepted { mut stream, .. } = acceptor.accept(None).await.unwrap();
        assert_eq!(stream.read_i32().await.unwrap(), 150);
        // Join the client so a panic on its side fails the test instead of being lost.
        client.await.unwrap();
    }
}