1mod re_export;
2
3pub use re_export::*;
4
5use async_stream::stream;
6use futures_util::{Stream, StreamExt, TryStream, TryStreamExt};
7use std::{
8 error::Error as StdError,
9 fmt::Debug,
10 future::ready,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio_native_tls::{TlsAcceptor, TlsStream};
16
/// Boxed, type-erased error used throughout this module.
///
/// The trait object is `Send + Sync`, and, as with any `Box<dyn Trait>` in a
/// type alias, implicitly `'static`.
pub type Error = Box<dyn StdError + Sync + Send>;
18
19pub fn incoming<S>(
20 mut incoming: S,
21 acceptor: TlsAcceptor,
22) -> impl Stream<Item = Result<TlsStreamWrapper<S::Ok>, Error>>
23where
24 S: TryStream + Unpin,
25 S::Ok: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
26 S::Error: StdError + Send + Sync + 'static,
27{
28 stream! {
29 while let Some(stream) = incoming.try_next().await.transpose() {
30 yield {
31 let acceptor = &acceptor;
32 move || async move {Ok(TlsStreamWrapper(acceptor.accept(stream?).await?))}
33 }().await;
34 }
35 }
36 .filter(|tls_stream| {
37 let ret = if let Err(_error) = tls_stream {
38 #[cfg(feature = "tracing")]
39 tracing::error!("Got error on incoming: `{_error}`.");
40 false
41 } else {
42 true
43 };
44
45 ready(ret)
46 })
47}
48
49#[derive(Debug)]
50pub struct TlsStreamWrapper<S>(TlsStream<S>);
51
/// Supplies the peer `SocketAddr` of a TLS-over-TCP connection as axum
/// connect info.
///
/// `connect_info` has no way to report failure, so if the peer address cannot
/// be read from the socket, the unspecified address `0.0.0.0:0` is returned.
#[cfg(feature = "axum")]
impl axum::extract::connect_info::Connected<&TlsStreamWrapper<tokio::net::TcpStream>>
    for std::net::SocketAddr
{
    fn connect_info(target: &TlsStreamWrapper<tokio::net::TcpStream>) -> Self {
        use std::net::{IpAddr, Ipv4Addr, SocketAddr};

        // Unwrap the three layers of TLS wrappers to reach the raw TCP socket.
        let socket: &tokio::net::TcpStream = target.0.get_ref().get_ref().get_ref();

        match socket.peer_addr() {
            Ok(peer) => peer,
            Err(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
        }
    }
}
68
/// Reports the connect info of the transport underneath the TLS layer, so that
/// tonic sees the same `ConnectInfo` it would for the plain transport `S`.
#[cfg(feature = "tonic")]
impl<S> tonic::transport::server::Connected for TlsStreamWrapper<S>
where
    S: tonic::transport::server::Connected + AsyncRead + AsyncWrite + Unpin,
{
    type ConnectInfo = S::ConnectInfo;

    fn connect_info(&self) -> Self::ConnectInfo {
        // Unwrap the TLS wrappers down to the transport stream.
        let transport: &S = self.0.get_ref().get_ref().get_ref();
        tonic::transport::server::Connected::connect_info(transport)
    }
}
80
81impl<S> AsyncRead for TlsStreamWrapper<S>
82where
83 S: AsyncRead + AsyncWrite + Unpin,
84{
85 fn poll_read(
86 mut self: Pin<&mut Self>,
87 cx: &mut Context<'_>,
88 buf: &mut ReadBuf<'_>,
89 ) -> Poll<std::io::Result<()>> {
90 Pin::new(&mut self.0).poll_read(cx, buf)
91 }
92}
93
94impl<S> AsyncWrite for TlsStreamWrapper<S>
95where
96 S: AsyncRead + AsyncWrite + Unpin,
97{
98 fn poll_write(
99 mut self: Pin<&mut Self>,
100 cx: &mut Context<'_>,
101 buf: &[u8],
102 ) -> Poll<Result<usize, std::io::Error>> {
103 Pin::new(&mut self.0).poll_write(cx, buf)
104 }
105
106 fn poll_flush(
107 mut self: Pin<&mut Self>,
108 cx: &mut Context<'_>,
109 ) -> Poll<Result<(), std::io::Error>> {
110 Pin::new(&mut self.0).poll_flush(cx)
111 }
112
113 fn poll_shutdown(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 ) -> Poll<Result<(), std::io::Error>> {
117 Pin::new(&mut self.0).poll_shutdown(cx)
118 }
119}