socks5_impl/server/connection/
associate.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4 net::SocketAddr,
5 pin::Pin,
6 sync::atomic::{AtomicUsize, Ordering},
7 task::{Context, Poll},
8 time::Duration,
9};
10use tokio::{
11 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
12 net::{TcpStream, ToSocketAddrs, UdpSocket},
13};
14
15#[derive(Debug)]
17pub struct UdpAssociate<S> {
18 stream: TcpStream,
19 _state: S,
20}
21
22impl<S: Default> UdpAssociate<S> {
23 #[inline]
24 pub(super) fn new(stream: TcpStream) -> Self {
25 Self {
26 stream,
27 _state: S::default(),
28 }
29 }
30
31 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
35 let resp = Response::new(reply, addr);
36 resp.write_to_async_stream(&mut self.stream).await?;
37 Ok(UdpAssociate::<Ready>::new(self.stream))
38 }
39
40 #[inline]
42 pub async fn shutdown(&mut self) -> std::io::Result<()> {
43 self.stream.shutdown().await
44 }
45
46 #[inline]
48 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
49 self.stream.local_addr()
50 }
51
52 #[inline]
54 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
55 self.stream.peer_addr()
56 }
57
58 #[inline]
62 pub fn linger(&self) -> std::io::Result<Option<Duration>> {
63 self.stream.linger()
64 }
65
66 #[inline]
74 pub fn set_linger(&self, dur: Option<Duration>) -> std::io::Result<()> {
75 self.stream.set_linger(dur)
76 }
77
78 #[inline]
82 pub fn nodelay(&self) -> std::io::Result<bool> {
83 self.stream.nodelay()
84 }
85
86 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
92 self.stream.set_nodelay(nodelay)
93 }
94
95 pub fn ttl(&self) -> std::io::Result<u32> {
99 self.stream.ttl()
100 }
101
102 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
106 self.stream.set_ttl(ttl)
107 }
108}
109
/// Type-state marker: the UDP ASSOCIATE request has not been answered yet.
#[derive(Default, Debug)]
pub struct NeedReply;
112
/// Type-state marker: the reply has been sent and the association is usable.
#[derive(Default, Debug)]
pub struct Ready;
115
116impl UdpAssociate<Ready> {
117 pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
122 loop {
123 match self.stream.read(&mut [0]).await {
124 Ok(0) => break Ok(()),
125 Ok(_) => {}
126 Err(err) => break Err(err),
127 }
128 }
129 }
130}
131
132impl std::ops::Deref for UdpAssociate<Ready> {
133 type Target = TcpStream;
134
135 #[inline]
136 fn deref(&self) -> &Self::Target {
137 &self.stream
138 }
139}
140
141impl std::ops::DerefMut for UdpAssociate<Ready> {
142 #[inline]
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 &mut self.stream
145 }
146}
147
148impl AsyncRead for UdpAssociate<Ready> {
149 #[inline]
150 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
151 Pin::new(&mut self.stream).poll_read(cx, buf)
152 }
153}
154
155impl AsyncWrite for UdpAssociate<Ready> {
156 #[inline]
157 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
158 Pin::new(&mut self.stream).poll_write(cx, buf)
159 }
160
161 #[inline]
162 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
163 Pin::new(&mut self.stream).poll_flush(cx)
164 }
165
166 #[inline]
167 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
168 Pin::new(&mut self.stream).poll_shutdown(cx)
169 }
170}
171
172impl<S> From<UdpAssociate<S>> for TcpStream {
173 #[inline]
174 fn from(conn: UdpAssociate<S>) -> Self {
175 conn.stream
176 }
177}
178
179#[derive(Debug)]
193pub struct AssociatedUdpSocket {
194 socket: UdpSocket,
195 buf_size: AtomicUsize,
196}
197
198impl AssociatedUdpSocket {
199 #[inline]
201 pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
202 self.socket.connect(addr).await
203 }
204
205 pub fn get_max_packet_size(&self) -> usize {
207 self.buf_size.load(Ordering::Relaxed)
208 }
209
210 pub fn set_max_packet_size(&self, size: usize) {
212 self.buf_size.store(size, Ordering::Release);
213 }
214
215 pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
221 loop {
222 let max_packet_size = self.buf_size.load(Ordering::Acquire);
223 let mut buf = vec![0; max_packet_size];
224 let len = self.socket.recv(&mut buf).await?;
225 buf.truncate(len);
226 let pkt = Bytes::from(buf);
227
228 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
229 let pkt = pkt.slice(header.len()..);
230 return Ok((pkt, header.frag, header.address));
231 }
232 }
233 }
234
235 pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
238 loop {
239 let max_packet_size = self.buf_size.load(Ordering::Acquire);
240 let mut buf = vec![0; max_packet_size];
241 let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
242 buf.truncate(len);
243 let pkt = Bytes::from(buf);
244
245 if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
246 let pkt = pkt.slice(header.len()..);
247 return Ok((pkt, header.frag, header.address, src_addr));
248 }
249 }
250 }
251
252 pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
254 let header = UdpHeader::new(frag, from_addr);
255 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
256 header.write_to_buf(&mut buf);
257 buf.extend_from_slice(pkt.as_ref());
258
259 self.socket.send(&buf).await.map(|len| len - header.len())
260 }
261
262 pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
264 let header = UdpHeader::new(frag, from_addr);
265 let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
266 header.write_to_buf(&mut buf);
267 buf.extend_from_slice(pkt.as_ref());
268
269 self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
270 }
271}
272
273impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
274 #[inline]
275 fn from(from: (UdpSocket, usize)) -> Self {
276 AssociatedUdpSocket {
277 socket: from.0,
278 buf_size: AtomicUsize::new(from.1),
279 }
280 }
281}
282
283impl From<AssociatedUdpSocket> for UdpSocket {
284 #[inline]
285 fn from(from: AssociatedUdpSocket) -> Self {
286 from.socket
287 }
288}
289
290impl AsRef<UdpSocket> for AssociatedUdpSocket {
291 #[inline]
292 fn as_ref(&self) -> &UdpSocket {
293 &self.socket
294 }
295}
296
297impl AsMut<UdpSocket> for AssociatedUdpSocket {
298 #[inline]
299 fn as_mut(&mut self) -> &mut UdpSocket {
300 &mut self.socket
301 }
302}