1use std::error::Error as StdError;
3use std::fmt::{self, Debug, Formatter};
4use std::io::{Error as IoError, Result as IoResult};
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use futures_util::stream::{BoxStream, Stream, StreamExt};
11use futures_util::task::noop_waker_ref;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::server::TlsStream;
14
15use crate::conn::tcp::{DynTcpAcceptor, TcpCoupler, ToDynTcpAcceptor};
16use crate::conn::{Accepted, Acceptor, HandshakeStream, Holding, IntoConfigStream, Listener};
17use crate::fuse::ArcFuseFactory;
18use crate::http::uri::Scheme;
19
20use super::ServerConfig;
21
/// A [`Listener`] that wraps another listener and serves TLS on top of it using rustls.
///
/// Binding it yields a [`RustlsAcceptor`] whose TLS configuration comes from
/// `config_stream`.
pub struct RustlsListener<S, C, T, E> {
    /// Source of TLS configurations; turned into a stream when the listener is bound.
    config_stream: S,
    /// The listener whose connections are upgraded to TLS.
    inner: T,
    _phantom: PhantomData<(C, E)>,
}

impl<S, C, T, E> Debug for RustlsListener<S, C, T, E> {
    /// Formats as `RustlsListener { .. }`.
    ///
    /// The generic fields carry no `Debug` bound, so they are elided. The output marks
    /// them as elided instead of pretending the struct is empty.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("RustlsListener").finish_non_exhaustive()
    }
}
34
35impl<S, C, T, E> RustlsListener<S, C, T, E>
36where
37 S: IntoConfigStream<C> + Send + 'static,
38 C: TryInto<ServerConfig, Error = E> + Send + 'static,
39 T: Listener + Send,
40 E: StdError + Send,
41{
42 #[inline]
44 pub fn new(config_stream: S, inner: T) -> Self {
45 Self {
46 config_stream,
47 inner,
48 _phantom: PhantomData,
49 }
50 }
51}
52
53impl<S, C, T, E> Listener for RustlsListener<S, C, T, E>
54where
55 S: IntoConfigStream<C> + Send + 'static,
56 C: TryInto<ServerConfig, Error = E> + Send + 'static,
57 T: Listener + Send + 'static,
58 T::Acceptor: Send + 'static,
59 <T::Acceptor as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60 E: StdError + Send + 'static,
61{
62 type Acceptor = RustlsAcceptor<BoxStream<'static, C>, C, T::Acceptor, E>;
63
64 async fn try_bind(self) -> crate::Result<Self::Acceptor> {
65 Ok(RustlsAcceptor::new(
66 self.config_stream.into_stream().boxed(),
67 self.inner.try_bind().await?,
68 ))
69 }
70}
71
72pub struct RustlsAcceptor<S, C, T, E> {
74 config_stream: S,
75 inner: T,
76 holdings: Vec<Holding>,
77 tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
78 _phantom: PhantomData<(C, E)>,
79}
80
81impl<S, C, T, E> Debug for RustlsAcceptor<S, C, T, E> {
82 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
83 f.debug_struct("RustlsAcceptor").finish()
84 }
85}
86
87impl<S, C, T, E> RustlsAcceptor<S, C, T, E>
88where
89 S: Stream<Item = C> + Unpin + Send + 'static,
90 C: TryInto<ServerConfig, Error = E> + Send + 'static,
91 T: Acceptor + Send + 'static,
92 T::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
93 E: StdError + Send + 'static,
94{
95 pub fn new(config_stream: S, inner: T) -> Self {
97 let holdings = inner
98 .holdings()
99 .iter()
100 .map(|h| {
101 #[allow(unused_mut)]
102 let mut versions = h.http_versions.clone();
103 #[cfg(feature = "http1")]
104 if !versions.contains(&crate::http::Version::HTTP_11) {
105 versions.push(crate::http::Version::HTTP_11);
106 }
107 #[cfg(feature = "http2")]
108 if !versions.contains(&crate::http::Version::HTTP_2) {
109 versions.push(crate::http::Version::HTTP_2);
110 }
111 Holding {
112 local_addr: h.local_addr.clone(),
113 http_versions: versions,
114 http_scheme: Scheme::HTTPS,
115 }
116 })
117 .collect();
118 Self {
119 config_stream,
120 inner,
121 holdings,
122 tls_acceptor: None,
123 _phantom: PhantomData,
124 }
125 }
126
127 pub fn inner(&self) -> &T {
129 &self.inner
130 }
131
132 pub fn into_boxed(self) -> Box<dyn DynTcpAcceptor> {
134 Box::new(ToDynTcpAcceptor(self))
135 }
136}
137
138impl<S, C, T, E> Acceptor for RustlsAcceptor<S, C, T, E>
139where
140 S: Stream<Item = C> + Send + Unpin + 'static,
141 C: TryInto<ServerConfig, Error = E> + Send + 'static,
142 T: Acceptor + Send + 'static,
143 <T as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
144 E: StdError + Send,
145{
146 type Coupler = TcpCoupler<Self::Stream>;
147 type Stream = HandshakeStream<TlsStream<T::Stream>>;
148
149 fn holdings(&self) -> &[Holding] {
150 &self.holdings
151 }
152
153 async fn accept(
154 &mut self,
155 fuse_factory: Option<ArcFuseFactory>,
156 ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
157 let config = {
158 let mut config = None;
159 while let Poll::Ready(Some(item)) = Pin::new(&mut self.config_stream)
160 .poll_next(&mut Context::from_waker(noop_waker_ref()))
161 {
162 config = Some(item);
163 }
164 config
165 };
166 if let Some(config) = config {
167 let config: ServerConfig = config
168 .try_into()
169 .map_err(|e| IoError::other(e.to_string()))?;
170 let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
171 if self.tls_acceptor.is_some() {
172 tracing::info!("tls config changed.");
173 } else {
174 tracing::info!("tls config loaded.");
175 }
176 self.tls_acceptor = Some(tls_acceptor);
177 }
178 let Some(tls_acceptor) = &self.tls_acceptor else {
179 return Err(IoError::other("rustls: invalid tls config."));
180 };
181
182 let Accepted {
183 coupler: _,
184 stream,
185 fusewire,
186 local_addr,
187 remote_addr,
188 ..
189 } = self.inner.accept(fuse_factory).await?;
190 Ok(Accepted {
191 coupler: TcpCoupler::new(),
192 stream: HandshakeStream::new(tls_acceptor.accept(stream), fusewire.clone()),
193 fusewire,
194 local_addr,
195 remote_addr,
196 http_scheme: Scheme::HTTPS,
197 })
198 }
199}