socks5_impl/server/connection/
associate.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4 net::SocketAddr,
5 pin::Pin,
6 sync::atomic::{AtomicUsize, Ordering},
7 task::{Context, Poll},
8 time::Duration,
9};
10use tokio::{
11 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
12 net::{TcpStream, ToSocketAddrs, UdpSocket},
13};
14
15#[derive(Debug)]
17pub struct UdpAssociate<S> {
18 stream: TcpStream,
19 _state: S,
20}
21
22impl<S: Default> UdpAssociate<S> {
23 #[inline]
24 pub(super) fn new(stream: TcpStream) -> Self {
25 Self {
26 stream,
27 _state: S::default(),
28 }
29 }
30
31 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
35 let resp = Response::new(reply, addr);
36 resp.write_to_async_stream(&mut self.stream).await?;
37 Ok(UdpAssociate::<Ready>::new(self.stream))
38 }
39
40 #[inline]
42 pub async fn shutdown(&mut self) -> std::io::Result<()> {
43 self.stream.shutdown().await
44 }
45
46 #[inline]
48 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
49 self.stream.local_addr()
50 }
51
52 #[inline]
54 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
55 self.stream.peer_addr()
56 }
57
58 #[inline]
62 pub fn nodelay(&self) -> std::io::Result<bool> {
63 self.stream.nodelay()
64 }
65
66 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
72 self.stream.set_nodelay(nodelay)
73 }
74
75 pub fn ttl(&self) -> std::io::Result<u32> {
79 self.stream.ttl()
80 }
81
82 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
86 self.stream.set_ttl(ttl)
87 }
88}
89
/// State marker for a `UdpAssociate` whose SOCKS5 reply has not been sent yet.
///
/// Zero-sized; it only exists at the type level, so it is freely copyable and
/// comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NeedReply;
92
/// State marker for a `UdpAssociate` produced by `reply`.
///
/// Zero-sized; it only exists at the type level, so it is freely copyable and
/// comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ready;
95
96impl UdpAssociate<Ready> {
97 pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
102 loop {
103 match self.stream.read(&mut [0]).await {
104 Ok(0) => break Ok(()),
105 Ok(_) => {}
106 Err(err) => break Err(err),
107 }
108 }
109 }
110}
111
112impl std::ops::Deref for UdpAssociate<Ready> {
113 type Target = TcpStream;
114
115 #[inline]
116 fn deref(&self) -> &Self::Target {
117 &self.stream
118 }
119}
120
121impl std::ops::DerefMut for UdpAssociate<Ready> {
122 #[inline]
123 fn deref_mut(&mut self) -> &mut Self::Target {
124 &mut self.stream
125 }
126}
127
128impl AsyncRead for UdpAssociate<Ready> {
129 #[inline]
130 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
131 Pin::new(&mut self.stream).poll_read(cx, buf)
132 }
133}
134
135impl AsyncWrite for UdpAssociate<Ready> {
136 #[inline]
137 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
138 Pin::new(&mut self.stream).poll_write(cx, buf)
139 }
140
141 #[inline]
142 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
143 Pin::new(&mut self.stream).poll_flush(cx)
144 }
145
146 #[inline]
147 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
148 Pin::new(&mut self.stream).poll_shutdown(cx)
149 }
150}
151
152impl<S> From<UdpAssociate<S>> for TcpStream {
153 #[inline]
154 fn from(conn: UdpAssociate<S>) -> Self {
155 conn.stream
156 }
157}
158
159#[derive(Debug)]
173pub struct AssociatedUdpSocket {
174 socket: UdpSocket,
175 buf_size: AtomicUsize,
176}
177
178impl AssociatedUdpSocket {
179 #[inline]
181 pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
182 self.socket.connect(addr).await
183 }
184
185 pub fn get_max_packet_size(&self) -> usize {
187 self.buf_size.load(Ordering::Relaxed)
188 }
189
190 pub fn set_max_packet_size(&self, size: usize) {
192 self.buf_size.store(size, Ordering::Release);
193 }
194
195 pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
201 loop {
202 let max_packet_size = self.buf_size.load(Ordering::Acquire);
203 let mut buf = vec![0; max_packet_size];
204 let len = self.socket.recv(&mut buf).await?;
205 buf.truncate(len);
206 let pkt = Bytes::from(buf);
207
208 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
209 let pkt = pkt.slice(header.len()..);
210 return Ok((pkt, header.frag, header.address));
211 }
212 }
213 }
214
215 pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
218 loop {
219 let max_packet_size = self.buf_size.load(Ordering::Acquire);
220 let mut buf = vec![0; max_packet_size];
221 let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
222 buf.truncate(len);
223 let pkt = Bytes::from(buf);
224
225 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
226 let pkt = pkt.slice(header.len()..);
227 return Ok((pkt, header.frag, header.address, src_addr));
228 }
229 }
230 }
231
232 pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
234 let header = UdpHeader::new(frag, from_addr);
235 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
236 header.write_to_buf(&mut buf);
237 buf.extend_from_slice(pkt.as_ref());
238
239 self.socket.send(&buf).await.map(|len| len - header.len())
240 }
241
242 pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
244 let header = UdpHeader::new(frag, from_addr);
245 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
246 header.write_to_buf(&mut buf);
247 buf.extend_from_slice(pkt.as_ref());
248
249 self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
250 }
251}
252
253impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
254 #[inline]
255 fn from(from: (UdpSocket, usize)) -> Self {
256 AssociatedUdpSocket {
257 socket: from.0,
258 buf_size: AtomicUsize::new(from.1),
259 }
260 }
261}
262
263impl From<AssociatedUdpSocket> for UdpSocket {
264 #[inline]
265 fn from(from: AssociatedUdpSocket) -> Self {
266 from.socket
267 }
268}
269
270impl AsRef<UdpSocket> for AssociatedUdpSocket {
271 #[inline]
272 fn as_ref(&self) -> &UdpSocket {
273 &self.socket
274 }
275}
276
277impl AsMut<UdpSocket> for AssociatedUdpSocket {
278 #[inline]
279 fn as_mut(&mut self) -> &mut UdpSocket {
280 &mut self.socket
281 }
282}