socks5_impl/server/connection/
associate.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4 net::SocketAddr,
5 pin::Pin,
6 sync::atomic::{AtomicUsize, Ordering},
7 task::{Context, Poll},
8};
9use tokio::{
10 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
11 net::{TcpStream, ToSocketAddrs, UdpSocket},
12};
13
14#[derive(Debug)]
16pub struct UdpAssociate<S> {
17 stream: TcpStream,
18 _state: S,
19}
20
21impl<S: Default> UdpAssociate<S> {
22 #[inline]
23 pub(super) fn new(stream: TcpStream) -> Self {
24 Self {
25 stream,
26 _state: S::default(),
27 }
28 }
29
30 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
34 let resp = Response::new(reply, addr);
35 resp.write_to_async_stream(&mut self.stream).await?;
36 Ok(UdpAssociate::<Ready>::new(self.stream))
37 }
38
39 #[inline]
41 pub async fn shutdown(&mut self) -> std::io::Result<()> {
42 self.stream.shutdown().await
43 }
44
45 #[inline]
47 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
48 self.stream.local_addr()
49 }
50
51 #[inline]
53 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
54 self.stream.peer_addr()
55 }
56
57 #[inline]
61 pub fn nodelay(&self) -> std::io::Result<bool> {
62 self.stream.nodelay()
63 }
64
65 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
71 self.stream.set_nodelay(nodelay)
72 }
73
74 pub fn ttl(&self) -> std::io::Result<u32> {
78 self.stream.ttl()
79 }
80
81 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
85 self.stream.set_ttl(ttl)
86 }
87}
88
/// State marker: the UDP ASSOCIATE request has not been answered yet.
#[derive(Default, Debug)]
pub struct NeedReply;

/// State marker: the reply has been sent and the association is usable.
#[derive(Default, Debug)]
pub struct Ready;
94
95impl UdpAssociate<Ready> {
96 pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
101 loop {
102 match self.stream.read(&mut [0]).await {
103 Ok(0) => break Ok(()),
104 Ok(_) => {}
105 Err(err) => break Err(err),
106 }
107 }
108 }
109}
110
111impl std::ops::Deref for UdpAssociate<Ready> {
112 type Target = TcpStream;
113
114 #[inline]
115 fn deref(&self) -> &Self::Target {
116 &self.stream
117 }
118}
119
120impl std::ops::DerefMut for UdpAssociate<Ready> {
121 #[inline]
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 &mut self.stream
124 }
125}
126
127impl AsyncRead for UdpAssociate<Ready> {
128 #[inline]
129 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
130 Pin::new(&mut self.stream).poll_read(cx, buf)
131 }
132}
133
134impl AsyncWrite for UdpAssociate<Ready> {
135 #[inline]
136 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
137 Pin::new(&mut self.stream).poll_write(cx, buf)
138 }
139
140 #[inline]
141 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
142 Pin::new(&mut self.stream).poll_flush(cx)
143 }
144
145 #[inline]
146 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
147 Pin::new(&mut self.stream).poll_shutdown(cx)
148 }
149}
150
151impl<S> From<UdpAssociate<S>> for TcpStream {
152 #[inline]
153 fn from(conn: UdpAssociate<S>) -> Self {
154 conn.stream
155 }
156}
157
158#[derive(Debug)]
172pub struct AssociatedUdpSocket {
173 socket: UdpSocket,
174 buf_size: AtomicUsize,
175}
176
177impl AssociatedUdpSocket {
178 #[inline]
180 pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
181 self.socket.connect(addr).await
182 }
183
184 pub fn get_max_packet_size(&self) -> usize {
186 self.buf_size.load(Ordering::Relaxed)
187 }
188
189 pub fn set_max_packet_size(&self, size: usize) {
191 self.buf_size.store(size, Ordering::Release);
192 }
193
194 pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
200 loop {
201 let max_packet_size = self.buf_size.load(Ordering::Acquire);
202 let mut buf = vec![0; max_packet_size];
203 let len = self.socket.recv(&mut buf).await?;
204 buf.truncate(len);
205 let pkt = Bytes::from(buf);
206
207 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
208 let pkt = pkt.slice(header.len()..);
209 return Ok((pkt, header.frag, header.address));
210 }
211 }
212 }
213
214 pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
217 loop {
218 let max_packet_size = self.buf_size.load(Ordering::Acquire);
219 let mut buf = vec![0; max_packet_size];
220 let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
221 buf.truncate(len);
222 let pkt = Bytes::from(buf);
223
224 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
225 let pkt = pkt.slice(header.len()..);
226 return Ok((pkt, header.frag, header.address, src_addr));
227 }
228 }
229 }
230
231 pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
233 let header = UdpHeader::new(frag, from_addr);
234 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
235 header.write_to_buf(&mut buf);
236 buf.extend_from_slice(pkt.as_ref());
237
238 self.socket.send(&buf).await.map(|len| len - header.len())
239 }
240
241 pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
243 let header = UdpHeader::new(frag, from_addr);
244 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
245 header.write_to_buf(&mut buf);
246 buf.extend_from_slice(pkt.as_ref());
247
248 self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
249 }
250}
251
252impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
253 #[inline]
254 fn from(from: (UdpSocket, usize)) -> Self {
255 AssociatedUdpSocket {
256 socket: from.0,
257 buf_size: AtomicUsize::new(from.1),
258 }
259 }
260}
261
262impl From<AssociatedUdpSocket> for UdpSocket {
263 #[inline]
264 fn from(from: AssociatedUdpSocket) -> Self {
265 from.socket
266 }
267}
268
269impl AsRef<UdpSocket> for AssociatedUdpSocket {
270 #[inline]
271 fn as_ref(&self) -> &UdpSocket {
272 &self.socket
273 }
274}
275
276impl AsMut<UdpSocket> for AssociatedUdpSocket {
277 #[inline]
278 fn as_mut(&mut self) -> &mut UdpSocket {
279 &mut self.socket
280 }
281}