sillad_native_tls/
lib.rs

1use std::pin::Pin;
2
3use async_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
4use async_trait::async_trait;
5use futures_lite::{AsyncRead, AsyncWrite};
6
7use sillad::{dialer::Dialer, listener::Listener, Pipe};
8
9/// TlsPipe wraps a TLS stream to implement the Pipe trait.
10pub struct TlsPipe<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> {
11    inner: TlsStream<T>,
12    remote_addr: Option<String>,
13}
14
15impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPipe<T> {
16    fn poll_read(
17        self: Pin<&mut Self>,
18        cx: &mut std::task::Context<'_>,
19        buf: &mut [u8],
20    ) -> std::task::Poll<std::io::Result<usize>> {
21        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
22    }
23}
24
25impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPipe<T> {
26    fn poll_write(
27        self: Pin<&mut Self>,
28        cx: &mut std::task::Context<'_>,
29        buf: &[u8],
30    ) -> std::task::Poll<std::io::Result<usize>> {
31        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
32    }
33
34    fn poll_flush(
35        self: Pin<&mut Self>,
36        cx: &mut std::task::Context<'_>,
37    ) -> std::task::Poll<std::io::Result<()>> {
38        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
39    }
40
41    fn poll_close(
42        self: Pin<&mut Self>,
43        cx: &mut std::task::Context<'_>,
44    ) -> std::task::Poll<std::io::Result<()>> {
45        Pin::new(&mut self.get_mut().inner).poll_close(cx)
46    }
47}
48
49impl<T: AsyncRead + AsyncWrite + Unpin + Send> Pipe for TlsPipe<T> {
50    fn protocol(&self) -> &str {
51        "tls"
52    }
53
54    fn remote_addr(&self) -> Option<&str> {
55        self.remote_addr.as_deref()
56    }
57}
58
59/// TlsDialer wraps a Dialer to establish a TLS connection.
60pub struct TlsDialer<D: Dialer> {
61    inner: D,
62    connector: TlsConnector,
63    domain: String,
64}
65
66impl<D: Dialer> TlsDialer<D> {
67    pub fn new(inner: D, connector: TlsConnector, domain: String) -> Self {
68        Self {
69            inner,
70            connector,
71            domain,
72        }
73    }
74}
75
76#[async_trait]
77impl<D: Dialer> Dialer for TlsDialer<D>
78where
79    D::P: AsyncRead + AsyncWrite + Unpin + Send,
80{
81    type P = TlsPipe<D::P>;
82
83    async fn dial(&self) -> std::io::Result<Self::P> {
84        let stream = self.inner.dial().await?;
85        let remote_addr = stream.remote_addr().map(|s| s.to_string());
86        let tls_stream = self
87            .connector
88            .connect(&self.domain, stream)
89            .await
90            .inspect_err(|e| {
91                tracing::warn!(
92                    err = display(e),
93                    addr = debug(&remote_addr),
94                    "TLS connection failed"
95                )
96            })
97            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?;
98        tracing::warn!(addr = debug(&remote_addr), "TLS connection SUCCESS");
99        Ok(TlsPipe {
100            inner: tls_stream,
101            remote_addr,
102        })
103    }
104}
105
106/// TlsListener wraps a Listener to accept TLS connections .
107pub struct TlsListener<L: Listener> {
108    // Channel that will yield successful TLS connections.
109    incoming: tachyonix::Receiver<TlsPipe<L::P>>,
110    // Keep the background task alive (cancels on drop).
111    _accept_task: async_task::Task<()>,
112}
113
114impl<L: Listener> TlsListener<L>
115where
116    L::P: AsyncRead + AsyncWrite + Unpin + Send + 'static,
117{
118    pub fn new(mut inner: L, acceptor: TlsAcceptor) -> Self {
119        // Create a channel to send successfully negotiated TLS connections.
120        let (tx, rx) = tachyonix::channel(1);
121
122        let acceptor_clone = acceptor.clone();
123        let accept_task = smolscale::spawn(async move {
124            loop {
125                // Pull the next raw connection from the underlying listener.
126                let raw_conn = match inner.accept().await {
127                    Ok(conn) => conn,
128                    Err(err) => {
129                        // Underlying listener failure: log and break out of the loop.
130                        eprintln!("Underlying listener error: {:?}", err);
131                        break;
132                    }
133                };
134
135                // For each raw connection, spawn a task to perform the TLS handshake.
136                let tx2 = tx.clone();
137                let acceptor2 = acceptor_clone.clone();
138                let remote_addr = raw_conn.remote_addr().map(|s| s.to_string());
139                smolscale::spawn(async move {
140                    match acceptor2.accept(raw_conn).await {
141                        Ok(tls_stream) => {
142                            let pipe = TlsPipe {
143                                inner: tls_stream,
144                                remote_addr,
145                            };
146                            let _ = tx2.send(pipe).await;
147                        }
148                        Err(e) => {
149                            // Handshake failure: log but do not send an error.
150                            eprintln!("TLS handshake error (ignored): {:?}", e);
151                        }
152                    }
153                })
154                .detach();
155            }
156        });
157
158        TlsListener {
159            incoming: rx,
160            _accept_task: accept_task,
161        }
162    }
163}
164
165#[async_trait]
166impl<L: Listener> Listener for TlsListener<L>
167where
168    L::P: AsyncRead + AsyncWrite + Unpin + Send + 'static,
169{
170    type P = TlsPipe<L::P>;
171
172    async fn accept(&mut self) -> std::io::Result<Self::P> {
173        // If a TLS connection is available, return it.
174        // If the channel is closed (due to an underlying listener failure), return an error.
175        match self.incoming.recv().await {
176            Ok(pipe) => Ok(pipe),
177            Err(_) => Err(std::io::Error::new(
178                std::io::ErrorKind::Other,
179                "Underlying listener failure",
180            )),
181        }
182    }
183}