1use std::fmt::Debug;
2use std::net::SocketAddr;
3use std::net::ToSocketAddrs;
4#[cfg(unix)]
5use std::path::Path;
6
7use tokio::io;
8use tokio::io::AsyncRead;
9use tokio::io::AsyncWrite;
10use tokio::net::{TcpListener, TcpStream};
11
/// An endpoint the relay can listen on or connect to.
///
/// Constructed from a string via `TryFrom<String>`: TCP addresses are tried first and,
/// on Unix platforms, a filesystem path naming a Unix domain socket is the fallback.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Socket {
    /// Path of a Unix domain socket (only available on Unix targets).
    #[cfg(unix)]
    Unix(String),
    /// A resolved TCP socket address.
    Tcp(SocketAddr),
}
18
19enum SocketStream {
20 #[cfg(unix)]
21 Unix(tokio::net::UnixStream),
22 Tcp(TcpStream),
23}
24
25enum SocketListener {
26 #[cfg(unix)]
27 Unix(tokio::net::UnixListener),
28 Tcp(TcpListener),
29}
30
/// Error produced when a string cannot be turned into a [`Socket`].
///
/// Wraps a human-readable description of the failure.
pub struct RelayError(String);

impl Debug for RelayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RelayError")
            .field("error", &self.0)
            .finish()
    }
}

/// Displays just the wrapped message, so the error reads naturally in logs and `anyhow` chains.
impl std::fmt::Display for RelayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Lets `RelayError` be propagated with `?` into `anyhow::Result` and other `Box<dyn Error>` contexts.
impl std::error::Error for RelayError {}
40
41impl TryFrom<String> for Socket {
42 type Error = RelayError;
43
44 fn try_from(value: String) -> Result<Self, Self::Error> {
45 match value.to_socket_addrs() {
46 Ok(mut socket_addr) => Ok(Socket::Tcp(socket_addr.next().unwrap())),
47 Err(err) => {
48 #[cfg(unix)]
49 if cfg!(unix) {
50 let path = &Path::new(&value);
51 return match path.exists() || path.parent().unwrap().exists() {
52 true => Ok(Socket::Unix(value)),
53 false => Err(RelayError(format!(
54 "tcp failed with {}, unix socket will as both parent dir and current file does't exist",
55 err
56 ))),
57 };
58 }
59 Err(RelayError(format!("parsing failed with error {}", err)))
60 }
61 }
62 }
63}
64
65impl Socket {
66 async fn connect(&self) -> anyhow::Result<SocketStream> {
67 match self {
68 #[cfg(unix)]
69 Socket::Unix(path) => Ok(SocketStream::Unix(
70 tokio::net::UnixStream::connect(path).await?,
71 )),
72 Socket::Tcp(addr) => Ok(SocketStream::Tcp(TcpStream::connect(addr).await?)),
73 }
74 }
75
76 async fn accept(&self, listener: &SocketListener) -> anyhow::Result<SocketStream> {
77 match listener {
78 #[cfg(unix)]
79 SocketListener::Unix(unixlistener) => {
80 let (listener, addr) = unixlistener.accept().await?;
81 println!("recieved connection from {:?}", &addr);
82 Ok(SocketStream::Unix(listener))
83 }
84 SocketListener::Tcp(tcplistener) => {
85 let (listener, addr) = tcplistener.accept().await?;
86 println!("recieved connection from {:?}", &addr);
87 Ok(SocketStream::Tcp(listener))
88 }
89 }
90 }
91
92 async fn listen(&self) -> anyhow::Result<SocketListener> {
93 match self {
94 #[cfg(unix)]
95 Socket::Unix(path) => Ok(SocketListener::Unix(tokio::net::UnixListener::bind(path)?)),
96 Socket::Tcp(addr) => Ok(SocketListener::Tcp(TcpListener::bind(addr).await?)),
97 }
98 }
99
100 pub async fn run(self, remote: Self) -> anyhow::Result<()> {
101 let socket_listener = self.listen().await?;
102 loop {
103 let socket_stream = match self.accept(&socket_listener).await {
104 Ok(socket_stream) => socket_stream,
105 Err(accept_error) => {
106 println!("accpeting socket failed with error {}", accept_error);
107 continue;
108 }
109 };
110 let remote_stream = match remote.clone().connect().await {
111 Ok(remote_stream) => remote_stream,
112 Err(connect_error) => {
113 println!(
114 "connecting to remote socket failed with error {}",
115 connect_error
116 );
117 continue;
118 }
119 };
120 tokio::spawn(socket_stream.proxy(remote_stream));
121 }
122 }
123}
124
125pub enum StdOrSocket {
126 Socket(Socket),
127 Std,
128}
129
130impl StdOrSocket {
131 pub async fn run(self, remote: Socket) -> anyhow::Result<()> {
132 match self {
133 StdOrSocket::Socket(local) => Ok(local.run(remote).await?),
134 StdOrSocket::Std => {
135 let stdin = tokio::io::stdin();
136 let stdout = tokio::io::stdout();
137 let connect = remote.connect().await;
138 if connect.is_err() {
139 println!("unable to connecto remote host");
140 }
141 let sockstream = connect?;
142 match sockstream {
143 #[cfg(unix)]
144 SocketStream::Unix(unixstream) => proxy_std(stdin, stdout, unixstream).await,
145 SocketStream::Tcp(tcpstream) => proxy_std(stdin, stdout, tcpstream).await,
146 };
147 Ok(())
148 }
149 }
150 }
151}
152
153pub async fn proxy_std<T1, T2, T3>(mut read: T1, mut write: T2, other: T3)
154where
155 T1: AsyncRead + Unpin,
156 T2: AsyncWrite + Unpin,
157 T3: AsyncRead + Unpin + AsyncWrite,
158{
159 let (mut read_2, mut write_2) = io::split(other);
160 tokio::select! {
161 _=io::copy(&mut read, &mut write_2)=>{},
162 _=io::copy(&mut read_2, &mut write)=>{}
163 }
164 println!("closing connection");
165}
166
167pub async fn proxy<T1, T2>(s1: T1, s2: T2)
168where
169 T1: AsyncRead + AsyncWrite + Unpin,
170 T2: AsyncRead + AsyncWrite + Unpin,
171{
172 let (mut read_1, mut write_1) = io::split(s1);
173 let (mut read_2, mut write_2) = io::split(s2);
174 tokio::select! {
175 _=io::copy(&mut read_1, &mut write_2)=>{},
176 _=io::copy(&mut read_2, &mut write_1)=>{}
177 }
178 println!("closing connection");
179}
180
181impl SocketStream {
182 async fn proxy(self, socket2: SocketStream) {
183 match (self, socket2) {
184 #[cfg(unix)]
185 (SocketStream::Unix(unixstream), SocketStream::Tcp(tcpstream)) => {
186 proxy(unixstream, tcpstream).await;
187 }
188 #[cfg(unix)]
189 (SocketStream::Tcp(tcpstream), SocketStream::Unix(unixstream)) => {
190 proxy(tcpstream, unixstream).await;
191 }
192 #[cfg(unix)]
193 (SocketStream::Unix(unixstream), SocketStream::Unix(unixstream2)) => {
194 proxy(unixstream, unixstream2).await;
195 }
196 (SocketStream::Tcp(tcpstream), SocketStream::Tcp(tcpstream2)) => {
197 proxy(tcpstream, tcpstream2).await;
198 }
199 }
200 }
201}