1use std::net::SocketAddr;
2
3use bytes::BytesMut;
4
5use crate::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::v5::{Request, Response, method::Method};
7
/// A SOCKS5 connection wrapper around a transport `T`, together with the
/// peer and local socket addresses supplied when it was built via `Stream::with`.
///
/// `Debug` is derived so the wrapper can be logged whenever `T: Debug`.
#[derive(Debug)]
pub struct Stream<T> {
    /// Underlying transport that handshake bytes are read from and written to.
    inner: T,
    /// Address reported for the peer end of the connection.
    peer_addr: SocketAddr,
    /// Address reported for the local end of the connection.
    local_addr: SocketAddr,
}
13
14impl<T> Stream<T> {
15 #[inline]
16 pub fn version(&self) -> u8 {
17 0x05
18 }
19
20 #[inline]
21 pub fn peer_addr(&self) -> SocketAddr {
22 self.peer_addr
23 }
24
25 #[inline]
26 pub fn local_addr(&self) -> SocketAddr {
27 self.local_addr
28 }
29
30 #[inline]
31 pub fn with(inner: T, peer_addr: SocketAddr, local_addr: SocketAddr) -> Self {
32 Self {
33 inner,
34 peer_addr,
35 local_addr,
36 }
37 }
38}
39
40impl<T> Stream<T>
42where
43 T: AsyncRead + AsyncWrite + Unpin,
44{
45 #[inline]
55 pub async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
56 let mut buffer = [0u8; 2];
57 self.read_exact(&mut buffer).await?;
58
59 let method_num = buffer[1];
60 if method_num == 1 {
61 let method = self.read_u8().await?;
62 return Ok(vec![Method::from_u8(method)]);
63 }
64
65 let mut methods = vec![0u8; method_num as usize];
66 self.read_exact(&mut methods).await?;
67
68 let result = methods.into_iter().map(Method::from_u8).collect();
69
70 Ok(result)
71 }
72
73 #[inline]
82 pub async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
83 let bytes = [self.version(), method.as_u8()];
84 self.write(&bytes).await
85 }
86
87 #[inline]
97 pub async fn read_request(&mut self) -> io::Result<Request> {
98 let _version = self.read_u8().await?;
99 Request::from_async_read(self).await
100 }
101
102 #[inline]
112 pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
113 let bytes = prepend_u8(resp.to_bytes(), self.version());
114 self.write(&bytes).await
115 }
116
117 #[inline]
118 pub async fn write_response_unspecified(&mut self) -> io::Result<usize> {
119 use crate::v5::Address;
120 self.write_response(&Response::Success(Address::unspecified()))
121 .await
122 }
123
124 #[inline]
125 pub async fn write_response_unsupported(&mut self) -> io::Result<usize> {
126 self.write_response(&Response::CommandNotSupported).await
127 }
128}
129
130#[inline]
131fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
132 bytes.reserve(1);
133
134 unsafe {
135 let ptr = bytes.as_mut_ptr();
136 std::ptr::copy(ptr, ptr.add(1), bytes.len());
137 std::ptr::write(ptr, value);
138 let new_len = bytes.len() + 1;
139 bytes.set_len(new_len);
140 }
141
142 bytes
143}
144
145mod async_impl {
146 use super::Stream;
147
148 use std::io;
149 use std::pin::Pin;
150 use std::task::{Context, Poll};
151 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
152
153 impl<T: AsyncRead + Unpin> AsyncRead for Stream<T> {
154 fn poll_read(
155 mut self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &mut ReadBuf<'_>,
158 ) -> Poll<io::Result<()>> {
159 Pin::new(&mut self.inner).poll_read(cx, buf)
160 }
161 }
162
163 impl<T: AsyncWrite + Unpin> AsyncWrite for Stream<T> {
164 fn poll_write(
165 mut self: Pin<&mut Self>,
166 cx: &mut Context<'_>,
167 buf: &[u8],
168 ) -> Poll<io::Result<usize>> {
169 Pin::new(&mut self.inner).poll_write(cx, buf)
170 }
171
172 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 Pin::new(&mut self.inner).poll_flush(cx)
174 }
175
176 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
177 Pin::new(&mut self.inner).poll_shutdown(cx)
178 }
179 }
180}