salvo_core/conn/acme/
listener.rs

1use std::fmt::{self, Debug, Formatter};
2use std::io::Result as IoResult;
3use std::path::PathBuf;
4use std::sync::{Arc, Weak};
5use std::time::Duration;
6
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_rustls::TlsAcceptor;
9use tokio_rustls::rustls::crypto::ring::sign::any_ecdsa_type;
10use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use tokio_rustls::rustls::server::ServerConfig;
12use tokio_rustls::rustls::sign::CertifiedKey;
13use tokio_rustls::server::TlsStream;
14
15use crate::Router;
16use crate::conn::HandshakeStream;
17use crate::conn::tcp::{DynTcpAcceptor, TcpCoupler, ToDynTcpAcceptor};
18use crate::conn::{Accepted, Acceptor, Holding, Listener};
19use crate::fuse::ArcFuseFactory;
20use crate::http::Version;
21use crate::http::uri::Scheme;
22
23use super::config::{AcmeConfig, AcmeConfigBuilder};
24use super::resolver::{ACME_TLS_ALPN_NAME, ResolveServerCert};
25use super::{AcmeCache, AcmeClient, ChallengeType, Http01Handler, WELL_KNOWN_PATH};
26
27cfg_feature! {
28    #![feature = "quinn"]
29    use crate::conn::quinn::QuinnAcceptor;
30    use crate::conn::joined::JoinedAcceptor;
31    use crate::conn::quinn::QuinnListener;
32    use futures_util::stream::BoxStream;
33}
34/// A wrapper around an underlying listener which implements the ACME.
35pub struct AcmeListener<T> {
36    inner: T,
37    config_builder: AcmeConfigBuilder,
38    check_duration: Duration,
39}
40
41impl<T> Debug for AcmeListener<T>
42where
43    T: Debug,
44{
45    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
46        f.debug_struct("AcmeListener")
47            .field("inner", &self.inner)
48            .field("config_builder", &self.config_builder)
49            .field("check_duration", &self.check_duration)
50            .finish()
51    }
52}
53
54impl<T> AcmeListener<T> {
55    /// Create `AcmeListener`
56    #[inline]
57    #[must_use]
58    pub fn new(inner: T) -> Self{
59        Self {
60            inner,
61            config_builder: AcmeConfig::builder(),
62            check_duration: Duration::from_secs(10 * 60),
63        }
64    }
65
66    /// Sets the directory.
67    ///
68    /// Defaults to lets encrypt.
69    #[inline]
70    #[must_use]
71    pub fn get_directory(self, name: impl Into<String>, url: impl Into<String>) -> Self {
72        Self {
73            config_builder: self.config_builder.directory(name, url),
74            ..self
75        }
76    }
77
78    /// Sets domains.
79    #[inline]
80    #[must_use]
81    pub fn domains(self, domains: impl Into<Vec<String>>) -> Self {
82        Self {
83            config_builder: self.config_builder.domains(domains),
84            ..self
85        }
86    }
87    /// Add a domain.
88    #[inline]
89    #[must_use]
90    pub fn add_domain(self, domain: impl Into<String>) -> Self {
91        Self {
92            config_builder: self.config_builder.add_domain(domain),
93            ..self
94        }
95    }
96
97    /// Add contact emails for the ACME account.
98    #[inline]
99    #[must_use]
100    pub fn contacts(self, contacts: impl Into<Vec<String>>) -> Self {
101        Self {
102            config_builder: self.config_builder.contacts(contacts.into()),
103            ..self
104        }
105    }
106    /// Add a contact email for the ACME account.
107    #[inline]
108    #[must_use]
109    pub fn add_contact(self, contact: impl Into<String>) -> Self {
110        Self {
111            config_builder: self.config_builder.add_contact(contact.into()),
112            ..self
113        }
114    }
115
116    /// Create an handler for HTTP-01 challenge
117    #[must_use]
118    pub fn http01_challenge(self, router: &mut Router) -> Self {
119        let config_builder = self.config_builder.http01_challenge();
120        if let Some(keys_for_http01) = &config_builder.keys_for_http01 {
121            let handler = Http01Handler {
122                keys: keys_for_http01.clone(),
123            };
124            router.routers.insert(
125                0,
126                Router::with_path(format!("{WELL_KNOWN_PATH}/{{token}}")).goal(handler),
127            );
128        } else {
129            panic!("`HTTP-01` challenge's key should not be none");
130        }
131        Self {
132            config_builder,
133            ..self
134        }
135    }
136    /// Create an handler for HTTP-01 challenge
137    #[inline]
138    #[must_use]
139    pub fn tls_alpn01_challenge(self) -> Self {
140        Self {
141            config_builder: self.config_builder.tls_alpn01_challenge(),
142            ..self
143        }
144    }
145
146    /// Sets the cache path for caching certificates.
147    ///
148    /// This is not a necessary option. If you do not configure the cache path,
149    /// the obtained certificate will be stored in memory and will need to be
150    /// obtained again when the server is restarted next time.
151    #[inline]
152    #[must_use]
153    pub fn cache_path(self, path: impl Into<PathBuf>) -> Self {
154        Self {
155            config_builder: self.config_builder.cache_path(path),
156            ..self
157        }
158    }
159    cfg_feature! {
160        #![feature = "quinn"]
161        /// Enable Http3 using quinn.
162        pub fn quinn<A>(self, local_addr: A) -> AcmeQuinnListener<T, A>
163        where
164            A: std::net::ToSocketAddrs + Send,
165        {
166            AcmeQuinnListener::new(self, local_addr)
167        }
168    }
169    async fn build_server_config(
170        acme_config: &AcmeConfig,
171    ) -> crate::Result<(ServerConfig, Arc<ResolveServerCert>)> {
172        let mut cached_key = None;
173        let mut cached_certs = None;
174        if let Some(cache_path) = &acme_config.cache_path {
175            let key_data = cache_path
176                .read_key(&acme_config.directory_name, &acme_config.domains)
177                .await?;
178            if let Some(key_data) = key_data {
179                tracing::debug!("load private key from cache");
180                if let Some(key) =
181                    rustls_pemfile::pkcs8_private_keys(&mut key_data.as_slice()).next()
182                {
183                    match key {
184                        Ok(key) => {
185                            cached_key = Some(key);
186                        }
187                        Err(e) => {
188                            tracing::warn!(error = ?e, "parse cached private key failed");
189                        }
190                    }
191                } else {
192                    tracing::warn!("parse cached private key failed");
193                }
194            }
195            let cert_data = cache_path
196                .read_cert(&acme_config.directory_name, &acme_config.domains)
197                .await?;
198            if let Some(cert_data) = cert_data {
199                tracing::debug!("load certificate from cache");
200                let certs = rustls_pemfile::certs(&mut cert_data.as_slice())
201                    .filter_map(|i| i.ok())
202                    .collect::<Vec<_>>();
203                if !certs.is_empty() {
204                    cached_certs = Some(certs);
205                } else {
206                    tracing::warn!("parse cached tls certificates failed")
207                };
208            }
209        };
210
211        let cert_resolver = Arc::new(ResolveServerCert::default());
212        if let (Some(cached_certs), Some(cached_key)) = (cached_certs, cached_key) {
213            let certs = cached_certs
214                .into_iter()
215                .collect::<Vec<CertificateDer<'static>>>();
216            tracing::debug!("using cached tls certificates");
217            *cert_resolver.cert.write() = Some(Arc::new(CertifiedKey::new(
218                certs,
219                any_ecdsa_type(&PrivateKeyDer::Pkcs8(cached_key))
220                    .expect("parse private key failed"),
221            )));
222        }
223
224        let mut server_config = ServerConfig::builder()
225            .with_no_client_auth()
226            .with_cert_resolver(cert_resolver.clone());
227
228        server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
229
230        if acme_config.challenge_type == ChallengeType::TlsAlpn01 {
231            server_config
232                .alpn_protocols
233                .push(ACME_TLS_ALPN_NAME.to_vec());
234        }
235        Ok((server_config, cert_resolver))
236    }
237}
238
239impl<T> Listener for AcmeListener<T>
240where
241    T: Listener + Send + 'static,
242    T::Acceptor: Send + 'static,
243    <T::Acceptor as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
244{
245    type Acceptor = AcmeAcceptor<T::Acceptor>;
246
247    async fn try_bind(self) -> crate::Result<Self::Acceptor> {
248        let Self {
249            inner,
250            config_builder,
251            check_duration,
252            ..
253        } = self;
254
255        let acme_config = config_builder.build()?;
256        let (server_config, cert_resolver) = Self::build_server_config(&acme_config).await?;
257        let server_config = Arc::new(server_config);
258        let tls_acceptor = TlsAcceptor::from(server_config.clone());
259        let inner = inner.try_bind().await?;
260        let acceptor = AcmeAcceptor::new(
261            acme_config,
262            server_config,
263            cert_resolver,
264            inner,
265            tls_acceptor,
266            check_duration,
267        )
268        .await?;
269        Ok(acceptor)
270    }
271}
272
273cfg_feature! {
274    #![feature = "quinn"]
275    /// A wrapper around an underlying listener which implements the ACME and Quinn.
276    pub struct AcmeQuinnListener<T, A> {
277        acme: AcmeListener<T>,
278        local_addr: A,
279    }
280
281    impl<T:Debug, A:Debug> Debug for AcmeQuinnListener<T, A> {
282        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
283            f.debug_struct("AcmeQuinnListener")
284                .field("acme", &self.acme)
285                .field("local_addr", &self.local_addr)
286                .finish()
287        }
288    }
289
290    impl <T, A> AcmeQuinnListener<T, A>
291    where
292        A: std::net::ToSocketAddrs + Send,
293    {
294        /// Create `AcmeQuinnListener`.
295        pub fn new(acme: AcmeListener<T>, local_addr: A) -> Self {
296            Self { acme, local_addr }
297        }
298    }
299
300    impl<T, A> Listener for AcmeQuinnListener<T, A>
301    where
302        T: Listener + Send + 'static,
303        T::Acceptor: Send + Unpin + 'static,
304        <T::Acceptor as Acceptor>::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
305        A: std::net::ToSocketAddrs + Send + 'static,
306    {
307        type Acceptor = JoinedAcceptor<AcmeAcceptor<T::Acceptor>, QuinnAcceptor<BoxStream<'static, crate::conn::quinn::ServerConfig>, crate::conn::quinn::ServerConfig, std::convert::Infallible>>;
308
309        async fn try_bind(self) -> crate::Result<Self::Acceptor>{
310            let Self { acme, local_addr } = self;
311            let a = acme.try_bind().await?;
312
313            let mut crypto = a.server_config.as_ref().clone();
314            crypto.alpn_protocols = vec![b"h3-29".to_vec(), b"h3-28".to_vec(), b"h3-27".to_vec(), b"h3".to_vec()];
315            let crypto = quinn::crypto::rustls::QuicServerConfig::try_from(crypto).map_err(crate::Error::other)?;
316            let config = crate::conn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
317            let b = QuinnListener::new(config, local_addr).try_bind().await?;
318            Ok(JoinedAcceptor::new(a, b))
319        }
320    }
321}
322
323/// Acceptor for ACME.
324pub struct AcmeAcceptor<T> {
325    config: Arc<AcmeConfig>,
326    server_config: Arc<ServerConfig>,
327    inner: T,
328    holdings: Vec<Holding>,
329    tls_acceptor: tokio_rustls::TlsAcceptor,
330}
331impl<T> Debug for AcmeAcceptor<T> {
332    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
333        f.debug_struct("AcmeAcceptor")
334            .field("config", &self.config)
335            .field("server_config", &self.server_config)
336            .field("holdings", &self.holdings)
337            .finish()
338    }
339}
340
341impl<T> AcmeAcceptor<T>
342where
343    T: Acceptor + Send + 'static,
344    T::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
345{
346    pub(crate) async fn new(
347        config: impl Into<Arc<AcmeConfig>> + Send,
348        server_config: impl Into<Arc<ServerConfig>> + Send,
349        cert_resolver: Arc<ResolveServerCert>,
350        inner: T,
351        tls_acceptor: TlsAcceptor,
352        check_duration: Duration,
353    ) -> crate::Result<Self>
354    where
355        T: Send,
356    {
357        let holdings = inner
358            .holdings()
359            .iter()
360            .map(|h| {
361                let mut versions = h.http_versions.clone();
362                #[cfg(feature = "http1")]
363                if !versions.contains(&Version::HTTP_11) {
364                    versions.push(Version::HTTP_11);
365                }
366                #[cfg(feature = "http2")]
367                if !versions.contains(&Version::HTTP_2) {
368                    versions.push(Version::HTTP_2);
369                }
370                Holding {
371                    local_addr: h.local_addr.clone(),
372                    http_versions: versions,
373                    http_scheme: Scheme::HTTPS,
374                }
375            })
376            .collect();
377
378        let acceptor = Self {
379            config: config.into(),
380            server_config: server_config.into(),
381            inner,
382            holdings,
383            tls_acceptor,
384        };
385        let config = acceptor.config.clone();
386        let weak_cert_resolver = Arc::downgrade(&cert_resolver);
387        let mut client = AcmeClient::new(
388            &config.directory_url,
389            config.key_pair.clone(),
390            config.contacts.clone(),
391        )
392        .await?;
393        tokio::spawn(async move {
394            while let Some(cert_resolver) = Weak::upgrade(&weak_cert_resolver) {
395                if cert_resolver.will_expired(config.before_expired) {
396                    if let Err(e) =
397                        super::issuer::issue_cert(&mut client, &config, &cert_resolver).await
398                    {
399                        tracing::error!(error = ?e, "issue certificate failed");
400                    }
401                }
402                tokio::time::sleep(check_duration).await;
403            }
404        });
405        Ok(acceptor)
406    }
407
408    /// Returns the config of this acceptor.
409    pub fn server_config(&self) -> Arc<ServerConfig> {
410        self.server_config.clone()
411    }
412
413    /// Convert this `AcmeAcceptor` into a boxed `DynTcpAcceptor`.
414    pub fn into_boxed(self) -> Box<dyn DynTcpAcceptor> {
415        Box::new(ToDynTcpAcceptor(self))
416    }
417}
418
419impl<T> Acceptor for AcmeAcceptor<T>
420where
421    T: Acceptor + Send + 'static,
422    <T as Acceptor>::Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static,
423{
424    type Coupler = TcpCoupler<Self::Stream>;
425    type Stream = HandshakeStream<TlsStream<T::Stream>>;
426
427    #[inline]
428    fn holdings(&self) -> &[Holding] {
429        &self.holdings
430    }
431
432    #[inline]
433    async fn accept(
434        &mut self,
435        fuse_factory: Option<ArcFuseFactory>,
436    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
437        let Accepted {
438            coupler: _,
439            stream,
440            fusewire,
441            local_addr,
442            remote_addr,
443            ..
444        } = self.inner.accept(fuse_factory).await?;
445        Ok(Accepted {
446            coupler: TcpCoupler::new(),
447            stream: HandshakeStream::new(self.tls_acceptor.accept(stream), fusewire.clone()),
448            fusewire,
449            local_addr,
450            remote_addr,
451            http_scheme: Scheme::HTTPS,
452        })
453    }
454}