salvo_core/conn/rustls/
listener.rs

1//! rustls module
2use std::error::Error as StdError;
3use std::fmt::{self, Debug, Formatter};
4use std::io::{Error as IoError, Result as IoResult};
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use futures_util::stream::{BoxStream, Stream, StreamExt};
11use futures_util::task::noop_waker_ref;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::server::TlsStream;
14
15use crate::conn::tcp::{DynTcpAcceptor, TcpCoupler, ToDynTcpAcceptor};
16use crate::conn::{Accepted, Acceptor, HandshakeStream, Holding, IntoConfigStream, Listener};
17use crate::fuse::ArcFuseFactory;
18use crate::http::uri::Scheme;
19
20use super::ServerConfig;
21
/// TLS-enabling adapter around another [`Listener`], backed by rustls.
///
/// Pairs an inner listener `T` with a source `S` of configuration items `C`.
/// Binding it (via its `Listener` implementation) yields a [`RustlsAcceptor`].
pub struct RustlsListener<S, C, T, E> {
    /// Source of TLS configuration items; turned into a stream when binding.
    config_stream: S,
    /// The wrapped listener that performs the underlying bind.
    inner: T,
    /// Anchors the otherwise unused `C` (config item) and `E` (conversion
    /// error) type parameters to this type.
    _phantom: PhantomData<(C, E)>,
}
28
29impl<S, C, T, E> Debug for RustlsListener<S, C, T, E> {
30    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
31        f.debug_struct("RustlsListener").finish()
32    }
33}
34
35impl<S, C, T, E> RustlsListener<S, C, T, E>
36where
37    S: IntoConfigStream<C> + Send + 'static,
38    C: TryInto<ServerConfig, Error = E> + Send + 'static,
39    T: Listener + Send,
40    E: StdError + Send,
41{
42    /// Create a new `RustlsListener`.
43    #[inline]
44    pub fn new(config_stream: S, inner: T) -> Self {
45        Self {
46            config_stream,
47            inner,
48            _phantom: PhantomData,
49        }
50    }
51}
52
53impl<S, C, T, E> Listener for RustlsListener<S, C, T, E>
54where
55    S: IntoConfigStream<C> + Send + 'static,
56    C: TryInto<ServerConfig, Error = E> + Send + 'static,
57    T: Listener + Send + 'static,
58    T::Acceptor: Send + 'static,
59    <T::Acceptor as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60    E: StdError + Send + 'static,
61{
62    type Acceptor = RustlsAcceptor<BoxStream<'static, C>, C, T::Acceptor, E>;
63
64    async fn try_bind(self) -> crate::Result<Self::Acceptor> {
65        Ok(RustlsAcceptor::new(
66            self.config_stream.into_stream().boxed(),
67            self.inner.try_bind().await?,
68        ))
69    }
70}
71
72/// A wrapper of `Acceptor` with rustls.
73pub struct RustlsAcceptor<S, C, T, E> {
74    config_stream: S,
75    inner: T,
76    holdings: Vec<Holding>,
77    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
78    _phantom: PhantomData<(C, E)>,
79}
80
81impl<S, C, T, E> Debug for RustlsAcceptor<S, C, T, E> {
82    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
83        f.debug_struct("RustlsAcceptor").finish()
84    }
85}
86
87impl<S, C, T, E> RustlsAcceptor<S, C, T, E>
88where
89    S: Stream<Item = C> + Unpin + Send + 'static,
90    C: TryInto<ServerConfig, Error = E> + Send + 'static,
91    T: Acceptor + Send + 'static,
92    T::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
93    E: StdError + Send + 'static,
94{
95    /// Create a new `RustlsAcceptor`.
96    pub fn new(config_stream: S, inner: T) -> Self {
97        let holdings = inner
98            .holdings()
99            .iter()
100            .map(|h| {
101                #[allow(unused_mut)]
102                let mut versions = h.http_versions.clone();
103                #[cfg(feature = "http1")]
104                if !versions.contains(&crate::http::Version::HTTP_11) {
105                    versions.push(crate::http::Version::HTTP_11);
106                }
107                #[cfg(feature = "http2")]
108                if !versions.contains(&crate::http::Version::HTTP_2) {
109                    versions.push(crate::http::Version::HTTP_2);
110                }
111                Holding {
112                    local_addr: h.local_addr.clone(),
113                    http_versions: versions,
114                    http_scheme: Scheme::HTTPS,
115                }
116            })
117            .collect();
118        Self {
119            config_stream,
120            inner,
121            holdings,
122            tls_acceptor: None,
123            _phantom: PhantomData,
124        }
125    }
126
127    /// Get the inner `Acceptor`.
128    pub fn inner(&self) -> &T {
129        &self.inner
130    }
131
132    /// Convert this `RustlsAcceptor` into a boxed `DynTcpAcceptor`.
133    pub fn into_boxed(self) -> Box<dyn DynTcpAcceptor> {
134        Box::new(ToDynTcpAcceptor(self))
135    }
136}
137
138impl<S, C, T, E> Acceptor for RustlsAcceptor<S, C, T, E>
139where
140    S: Stream<Item = C> + Send + Unpin + 'static,
141    C: TryInto<ServerConfig, Error = E> + Send + 'static,
142    T: Acceptor + Send + 'static,
143    <T as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
144    E: StdError + Send,
145{
146    type Coupler = TcpCoupler<Self::Stream>;
147    type Stream = HandshakeStream<TlsStream<T::Stream>>;
148
149    fn holdings(&self) -> &[Holding] {
150        &self.holdings
151    }
152
153    async fn accept(
154        &mut self,
155        fuse_factory: Option<ArcFuseFactory>,
156    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
157        let config = {
158            let mut config = None;
159            while let Poll::Ready(Some(item)) = Pin::new(&mut self.config_stream)
160                .poll_next(&mut Context::from_waker(noop_waker_ref()))
161            {
162                config = Some(item);
163            }
164            config
165        };
166        if let Some(config) = config {
167            let config: ServerConfig = config
168                .try_into()
169                .map_err(|e| IoError::other(e.to_string()))?;
170            let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
171            if self.tls_acceptor.is_some() {
172                tracing::info!("tls config changed.");
173            } else {
174                tracing::info!("tls config loaded.");
175            }
176            self.tls_acceptor = Some(tls_acceptor);
177        }
178        let Some(tls_acceptor) = &self.tls_acceptor else {
179            return Err(IoError::other("rustls: invalid tls config."));
180        };
181
182        let Accepted {
183            coupler: _,
184            stream,
185            fusewire,
186            local_addr,
187            remote_addr,
188            ..
189        } = self.inner.accept(fuse_factory).await?;
190        Ok(Accepted {
191            coupler: TcpCoupler::new(),
192            stream: HandshakeStream::new(tls_acceptor.accept(stream), fusewire.clone()),
193            fusewire,
194            local_addr,
195            remote_addr,
196            http_scheme: Scheme::HTTPS,
197        })
198    }
199}