tcp_relay_rust/
lib.rs

1use std::fmt::Debug;
2use std::net::SocketAddr;
3use std::net::ToSocketAddrs;
4#[cfg(unix)]
5use std::path::Path;
6
7use tokio::io;
8use tokio::io::AsyncRead;
9use tokio::io::AsyncWrite;
10use tokio::net::{TcpListener, TcpStream};
11
/// An endpoint the relay can listen on or connect to.
///
/// Usually built from user input via `Socket::try_from(String)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Socket {
    /// A Unix domain socket identified by its filesystem path (Unix targets only).
    #[cfg(unix)]
    Unix(String),
    /// A resolved TCP socket address.
    Tcp(SocketAddr),
}
18
19enum SocketStream {
20    #[cfg(unix)]
21    Unix(tokio::net::UnixStream),
22    Tcp(TcpStream),
23}
24
25enum SocketListener {
26    #[cfg(unix)]
27    Unix(tokio::net::UnixListener),
28    Tcp(TcpListener),
29}
30
/// Error returned when a string cannot be turned into a `Socket`.
///
/// The wrapped `String` is a human-readable description of the failure.
pub struct RelayError(String);

impl Debug for RelayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RelayError")
            .field("error", &self.0)
            .finish()
    }
}

/// Displays just the failure description, so the error reads naturally when
/// printed with `{}` or reported through an error chain.
impl std::fmt::Display for RelayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Makes `RelayError` a standard error so `?` can convert it into
/// `anyhow::Error` / `Box<dyn Error>`.
impl std::error::Error for RelayError {}
40
41impl TryFrom<String> for Socket {
42    type Error = RelayError;
43
44    fn try_from(value: String) -> Result<Self, Self::Error> {
45        match value.to_socket_addrs() {
46            Ok(mut socket_addr) => Ok(Socket::Tcp(socket_addr.next().unwrap())),
47            Err(err) => {
48                #[cfg(unix)]
49                if cfg!(unix) {
50                    let path = &Path::new(&value);
51                    return match path.exists() || path.parent().unwrap().exists() {
52                        true => Ok(Socket::Unix(value)),
53                        false => Err(RelayError(format!(
54                            "tcp failed with {}, unix socket will as both parent dir and current file does't exist",
55                            err
56                        ))),
57                    };
58                }
59                Err(RelayError(format!("parsing failed with error {}", err)))
60            }
61        }
62    }
63}
64
65impl Socket {
66    async fn connect(&self) -> anyhow::Result<SocketStream> {
67        match self {
68            #[cfg(unix)]
69            Socket::Unix(path) => Ok(SocketStream::Unix(
70                tokio::net::UnixStream::connect(path).await?,
71            )),
72            Socket::Tcp(addr) => Ok(SocketStream::Tcp(TcpStream::connect(addr).await?)),
73        }
74    }
75
76    async fn accept(&self, listener: &SocketListener) -> anyhow::Result<SocketStream> {
77        match listener {
78            #[cfg(unix)]
79            SocketListener::Unix(unixlistener) => {
80                let (listener, addr) = unixlistener.accept().await?;
81                println!("recieved connection from {:?}", &addr);
82                Ok(SocketStream::Unix(listener))
83            }
84            SocketListener::Tcp(tcplistener) => {
85                let (listener, addr) = tcplistener.accept().await?;
86                println!("recieved connection from {:?}", &addr);
87                Ok(SocketStream::Tcp(listener))
88            }
89        }
90    }
91
92    async fn listen(&self) -> anyhow::Result<SocketListener> {
93        match self {
94            #[cfg(unix)]
95            Socket::Unix(path) => Ok(SocketListener::Unix(tokio::net::UnixListener::bind(path)?)),
96            Socket::Tcp(addr) => Ok(SocketListener::Tcp(TcpListener::bind(addr).await?)),
97        }
98    }
99
100    pub async fn run(self, remote: Self) -> anyhow::Result<()> {
101        let socket_listener = self.listen().await?;
102        loop {
103            let socket_stream = match self.accept(&socket_listener).await {
104                Ok(socket_stream) => socket_stream,
105                Err(accept_error) => {
106                    println!("accpeting socket failed with error {}", accept_error);
107                    continue;
108                }
109            };
110            let remote_stream = match remote.clone().connect().await {
111                Ok(remote_stream) => remote_stream,
112                Err(connect_error) => {
113                    println!(
114                        "connecting to remote socket failed with error {}",
115                        connect_error
116                    );
117                    continue;
118                }
119            };
120            tokio::spawn(socket_stream.proxy(remote_stream));
121        }
122    }
123}
124
125pub enum StdOrSocket {
126    Socket(Socket),
127    Std,
128}
129
130impl StdOrSocket {
131    pub async fn run(self, remote: Socket) -> anyhow::Result<()> {
132        match self {
133            StdOrSocket::Socket(local) => Ok(local.run(remote).await?),
134            StdOrSocket::Std => {
135                let stdin = tokio::io::stdin();
136                let stdout = tokio::io::stdout();
137                let connect = remote.connect().await;
138                if connect.is_err() {
139                    println!("unable to connecto remote host");
140                }
141                let sockstream = connect?;
142                match sockstream {
143                    #[cfg(unix)]
144                    SocketStream::Unix(unixstream) => proxy_std(stdin, stdout, unixstream).await,
145                    SocketStream::Tcp(tcpstream) => proxy_std(stdin, stdout, tcpstream).await,
146                };
147                Ok(())
148            }
149        }
150    }
151}
152
153pub async fn proxy_std<T1, T2, T3>(mut read: T1, mut write: T2, other: T3)
154where
155    T1: AsyncRead + Unpin,
156    T2: AsyncWrite + Unpin,
157    T3: AsyncRead + Unpin + AsyncWrite,
158{
159    let (mut read_2, mut write_2) = io::split(other);
160    tokio::select! {
161        _=io::copy(&mut read, &mut write_2)=>{},
162        _=io::copy(&mut read_2, &mut write)=>{}
163    }
164    println!("closing connection");
165}
166
167pub async fn proxy<T1, T2>(s1: T1, s2: T2)
168where
169    T1: AsyncRead + AsyncWrite + Unpin,
170    T2: AsyncRead + AsyncWrite + Unpin,
171{
172    let (mut read_1, mut write_1) = io::split(s1);
173    let (mut read_2, mut write_2) = io::split(s2);
174    tokio::select! {
175        _=io::copy(&mut read_1, &mut write_2)=>{},
176        _=io::copy(&mut read_2, &mut write_1)=>{}
177    }
178    println!("closing connection");
179}
180
181impl SocketStream {
182    async fn proxy(self, socket2: SocketStream) {
183        match (self, socket2) {
184            #[cfg(unix)]
185            (SocketStream::Unix(unixstream), SocketStream::Tcp(tcpstream)) => {
186                proxy(unixstream, tcpstream).await;
187            }
188            #[cfg(unix)]
189            (SocketStream::Tcp(tcpstream), SocketStream::Unix(unixstream)) => {
190                proxy(tcpstream, unixstream).await;
191            }
192            #[cfg(unix)]
193            (SocketStream::Unix(unixstream), SocketStream::Unix(unixstream2)) => {
194                proxy(unixstream, unixstream2).await;
195            }
196            (SocketStream::Tcp(tcpstream), SocketStream::Tcp(tcpstream2)) => {
197                proxy(tcpstream, tcpstream2).await;
198            }
199        }
200    }
201}