socks5_impl/server/connection/connect.rs

1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response};
2use std::{
3    io::IoSlice,
4    net::SocketAddr,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::{
9    io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
10    net::{
11        TcpStream,
12        tcp::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf},
13    },
14};
15
16/// Socks5 connection type `Connect`
17///
18/// This connection can be used as a regular async TCP stream after replying the client.
19#[derive(Debug)]
20pub struct Connect<S> {
21    stream: TcpStream,
22    _state: S,
23}
24
25impl<S: Default> Connect<S> {
26    #[inline]
27    pub(super) fn new(stream: TcpStream) -> Self {
28        Self {
29            stream,
30            _state: S::default(),
31        }
32    }
33
34    /// Returns the local address that this stream is bound to.
35    #[inline]
36    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
37        self.stream.local_addr()
38    }
39
40    /// Returns the remote address that this stream is connected to.
41    #[inline]
42    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
43        self.stream.peer_addr()
44    }
45
46    /// Shutdown the TCP stream.
47    #[inline]
48    pub async fn shutdown(&mut self) -> std::io::Result<()> {
49        self.stream.shutdown().await
50    }
51}
52
/// Typestate marker: the client has not been sent a SOCKS5 reply yet.
#[derive(Default, Debug)]
pub struct NeedReply;
55
/// Typestate marker: the client has been replied to and the stream is usable.
#[derive(Default, Debug)]
pub struct Ready;
58
59impl Connect<NeedReply> {
60    /// Reply to the client.
61    #[inline]
62    pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<Connect<Ready>> {
63        let resp = Response::new(reply, addr);
64        resp.write_to_async_stream(&mut self.stream).await?;
65        Ok(Connect::<Ready>::new(self.stream))
66    }
67}
68
69impl Connect<Ready> {
70    /// Returns the read/write half of the stream.
71    #[inline]
72    pub fn split(&mut self) -> (ReadHalf, WriteHalf) {
73        self.stream.split()
74    }
75
76    /// Returns the owned read/write half of the stream.
77    #[inline]
78    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
79        self.stream.into_split()
80    }
81}
82
83impl std::ops::Deref for Connect<Ready> {
84    type Target = TcpStream;
85
86    #[inline]
87    fn deref(&self) -> &Self::Target {
88        &self.stream
89    }
90}
91
92impl std::ops::DerefMut for Connect<Ready> {
93    #[inline]
94    fn deref_mut(&mut self) -> &mut Self::Target {
95        &mut self.stream
96    }
97}
98
99impl AsyncRead for Connect<Ready> {
100    #[inline]
101    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
102        Pin::new(&mut self.stream).poll_read(cx, buf)
103    }
104}
105
106impl AsyncWrite for Connect<Ready> {
107    #[inline]
108    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
109        Pin::new(&mut self.stream).poll_write(cx, buf)
110    }
111
112    #[inline]
113    fn poll_write_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<std::io::Result<usize>> {
114        Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
115    }
116
117    #[inline]
118    fn is_write_vectored(&self) -> bool {
119        self.stream.is_write_vectored()
120    }
121
122    #[inline]
123    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
124        Pin::new(&mut self.stream).poll_flush(cx)
125    }
126
127    #[inline]
128    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
129        Pin::new(&mut self.stream).poll_shutdown(cx)
130    }
131}
132
133impl<S> From<Connect<S>> for TcpStream {
134    #[inline]
135    fn from(conn: Connect<S>) -> Self {
136        conn.stream
137    }
138}