socks5_impl/server/connection/
connect.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response};
2use std::{
3 io::IoSlice,
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::{
9 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
10 net::{
11 TcpStream,
12 tcp::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf},
13 },
14};
15
/// A server-side connection wrapper around a [`TcpStream`].
///
/// The type parameter `S` is a state marker (for example [`NeedReply`] or
/// [`Ready`]) that selects which methods are available: the reply can only be
/// sent from the `NeedReply` state, while stream I/O (`AsyncRead`, `AsyncWrite`,
/// `Deref<Target = TcpStream>`, `split`, `into_split`) is only exposed in the
/// `Ready` state.
#[derive(Debug)]
pub struct Connect<S> {
    // The underlying TCP socket; all I/O is delegated to it.
    stream: TcpStream,
    // State marker value, held by value; none of the methods shown here read it.
    _state: S,
}
24
25impl<S: Default> Connect<S> {
26 #[inline]
27 pub(super) fn new(stream: TcpStream) -> Self {
28 Self {
29 stream,
30 _state: S::default(),
31 }
32 }
33
34 #[inline]
36 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
37 self.stream.local_addr()
38 }
39
40 #[inline]
42 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
43 self.stream.peer_addr()
44 }
45
46 #[inline]
48 pub async fn shutdown(&mut self) -> std::io::Result<()> {
49 self.stream.shutdown().await
50 }
51}
52
/// Typestate marker for a [`Connect`] that must still send its reply.
///
/// In this state only [`Connect::reply`] and the state-independent accessors
/// are available; reading or writing the stream directly is not possible.
#[derive(Debug, Default)]
pub struct NeedReply;
55
/// Typestate marker for a [`Connect`] that can be used as a plain stream.
///
/// [`Connect::reply`] returns a connection in this state after writing the
/// response. `Connect<Ready>` implements `AsyncRead`, `AsyncWrite` and
/// `Deref<Target = TcpStream>`, and offers `split` / `into_split`.
#[derive(Debug, Default)]
pub struct Ready;
58
59impl Connect<NeedReply> {
60 #[inline]
62 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<Connect<Ready>> {
63 let resp = Response::new(reply, addr);
64 resp.write_to_async_stream(&mut self.stream).await?;
65 Ok(Connect::<Ready>::new(self.stream))
66 }
67}
68
69impl Connect<Ready> {
70 #[inline]
72 pub fn split(&mut self) -> (ReadHalf, WriteHalf) {
73 self.stream.split()
74 }
75
76 #[inline]
78 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
79 self.stream.into_split()
80 }
81}
82
83impl std::ops::Deref for Connect<Ready> {
84 type Target = TcpStream;
85
86 #[inline]
87 fn deref(&self) -> &Self::Target {
88 &self.stream
89 }
90}
91
92impl std::ops::DerefMut for Connect<Ready> {
93 #[inline]
94 fn deref_mut(&mut self) -> &mut Self::Target {
95 &mut self.stream
96 }
97}
98
99impl AsyncRead for Connect<Ready> {
100 #[inline]
101 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
102 Pin::new(&mut self.stream).poll_read(cx, buf)
103 }
104}
105
106impl AsyncWrite for Connect<Ready> {
107 #[inline]
108 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
109 Pin::new(&mut self.stream).poll_write(cx, buf)
110 }
111
112 #[inline]
113 fn poll_write_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<std::io::Result<usize>> {
114 Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
115 }
116
117 #[inline]
118 fn is_write_vectored(&self) -> bool {
119 self.stream.is_write_vectored()
120 }
121
122 #[inline]
123 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
124 Pin::new(&mut self.stream).poll_flush(cx)
125 }
126
127 #[inline]
128 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
129 Pin::new(&mut self.stream).poll_shutdown(cx)
130 }
131}
132
133impl<S> From<Connect<S>> for TcpStream {
134 #[inline]
135 fn from(conn: Connect<S>) -> Self {
136 conn.stream
137 }
138}