1use async_trait::async_trait;
4use futures::{AsyncRead, AsyncWrite, StreamExt as _, stream};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
6use std::net;
7use std::task::Poll;
8use std::{pin::Pin, task::Context};
9use tor_general_addr::unix;
10use tracing::instrument;
11
12use crate::{NetStreamListener, NetStreamProvider, StreamOps};
13use tor_general_addr::general;
14
15pub use general::{AddrParseError, SocketAddr};
16
17trait ReadAndWrite: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
21impl<T> ReadAndWrite for T where T: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
22
23pub struct Stream(Pin<Box<dyn ReadAndWrite>>);
25impl AsyncRead for Stream {
26 fn poll_read(
27 mut self: Pin<&mut Self>,
28 cx: &mut Context<'_>,
29 buf: &mut [u8],
30 ) -> Poll<IoResult<usize>> {
31 self.0.as_mut().poll_read(cx, buf)
32 }
33}
34impl AsyncWrite for Stream {
35 fn poll_write(
36 mut self: Pin<&mut Self>,
37 cx: &mut Context<'_>,
38 buf: &[u8],
39 ) -> Poll<IoResult<usize>> {
40 self.0.as_mut().poll_write(cx, buf)
41 }
42
43 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
44 self.0.as_mut().poll_flush(cx)
45 }
46
47 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
48 self.0.as_mut().poll_close(cx)
49 }
50}
51
52impl StreamOps for Stream {
53 fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
54 self.0.set_tcp_notsent_lowat(notsent_lowat)
55 }
56
57 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
58 self.0.new_handle()
59 }
60}
61
62type StreamItem = IoResult<(Stream, general::SocketAddr)>;
64
65pub struct IncomingStreams(Pin<Box<dyn stream::Stream<Item = StreamItem> + Send + Sync>>);
67
68impl stream::Stream for IncomingStreams {
69 type Item = IoResult<(Stream, general::SocketAddr)>;
70
71 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72 self.0.as_mut().poll_next(cx)
73 }
74}
75
76pub struct Listener {
78 streams: IncomingStreams,
80 local_addr: general::SocketAddr,
82}
83
84impl NetStreamListener<general::SocketAddr> for Listener {
85 type Stream = Stream;
86 type Incoming = IncomingStreams;
87
88 fn incoming(self) -> IncomingStreams {
89 self.streams
90 }
91
92 fn local_addr(&self) -> IoResult<general::SocketAddr> {
93 Ok(self.local_addr.clone())
94 }
95}
96
97async fn abstract_listener_on<ADDR, P>(
100 provider: &P,
101 address: &ADDR,
102 options: &P::ListenOptions,
103) -> IoResult<Listener>
104where
105 P: NetStreamProvider<ADDR>,
106 general::SocketAddr: From<ADDR>,
107{
108 let lis = provider.listen(address, options).await?;
109 let local_addr = general::SocketAddr::from(lis.local_addr()?);
110 let streams = lis.incoming().map(|result| {
111 result.map(|(socket, addr)| (Stream(Box::pin(socket)), general::SocketAddr::from(addr)))
112 });
113 let streams = IncomingStreams(Box::pin(streams));
114 Ok(Listener {
115 streams,
116 local_addr,
117 })
118}
119
120#[async_trait]
121impl<T> NetStreamProvider<general::SocketAddr> for T
122where
123 T: NetStreamProvider<net::SocketAddr> + NetStreamProvider<unix::SocketAddr>,
124{
125 type Stream = Stream;
126 type Listener = Listener;
127 type ConnectOptions = ();
130 type ListenOptions = ();
133
134 #[instrument(skip_all, level = "trace")]
135 async fn connect(
136 &self,
137 addr: &general::SocketAddr,
138 (): &Self::ConnectOptions,
139 ) -> IoResult<Stream> {
140 use general::SocketAddr as G;
141 match addr {
142 G::Inet(a) => {
143 let options = Default::default();
144 Ok(Stream(Box::pin(self.connect(a, &options).await?)))
145 }
146 G::Unix(a) => {
147 let options = Default::default();
148 Ok(Stream(Box::pin(self.connect(a, &options).await?)))
149 }
150 other => Err(IoError::new(
151 IoErrorKind::InvalidInput,
152 UnsupportedAddress(other.clone()),
153 )),
154 }
155 }
156 async fn listen(
157 &self,
158 addr: &general::SocketAddr,
159 (): &Self::ListenOptions,
160 ) -> IoResult<Listener> {
161 use general::SocketAddr as G;
162 match addr {
163 G::Inet(a) => abstract_listener_on(self, a, &Default::default()).await,
164 G::Unix(a) => abstract_listener_on(self, a, &Default::default()).await,
165 other => Err(IoError::new(
166 IoErrorKind::InvalidInput,
167 UnsupportedAddress(other.clone()),
168 )),
169 }
170 }
171}
172
173#[derive(Clone, Debug, thiserror::Error)]
175#[error("Socket address {0:?} is not supported by tor-rtcompat")]
176pub struct UnsupportedAddress(general::SocketAddr);