tcp_clone/
lib.rs

1use async_std::io;
2use async_std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
3use async_std::prelude::*;
4use async_std::sync::channel;
5use async_std::sync::Arc;
6use async_std::task;
7
8type Receiver = async_std::sync::Receiver<Arc<Vec<u8>>>;
9type Sender = async_std::sync::Sender<Arc<Vec<u8>>>;
10
/// Socket addresses shared (behind an `Arc`) by every proxied connection.
struct Addresses {
    /// Server that each accepted client connection is forwarded to.
    target_addr: SocketAddr,
    /// Observers that receive a copy of the bytes read from the client.
    tx_observer_addrs: Vec<SocketAddr>,
    /// Observers that receive a copy of the bytes read from the target.
    rx_observer_addrs: Vec<SocketAddr>,
}

impl Addresses {
    /// Bundles the target address and both observer lists into one value.
    fn new(
        target_addr: SocketAddr,
        tx_observer_addrs: Vec<SocketAddr>,
        rx_observer_addrs: Vec<SocketAddr>,
    ) -> Self {
        Self {
            tx_observer_addrs,
            rx_observer_addrs,
            target_addr,
        }
    }
}
30
31struct Broadcaster {
32    txs: Vec<Sender>,
33}
34
35impl Broadcaster {
36    fn with_capacity(n: usize) -> Broadcaster {
37        Broadcaster {
38            txs: Vec::with_capacity(n + 1),
39        }
40    }
41
42    fn new_receiver(&mut self) -> Receiver {
43        let (sender, receiver) = channel(1024);
44        self.txs.push(sender);
45        receiver
46    }
47
48    fn write(&mut self, data: Vec<u8>) {
49        let data = Arc::new(data);
50        for tx in self.txs.iter() {
51            let tx = tx.clone();
52            let data = data.clone();
53            task::spawn(async move {
54                tx.send(data.clone()).await;
55            });
56        }
57    }
58}
59
60pub async fn run(
61    listen_addr: SocketAddr,
62    target_addr: SocketAddr,
63    tx_observer_addrs: Vec<SocketAddr>,
64    rx_observer_addrs: Vec<SocketAddr>,
65) -> io::Result<()> {
66    let addrs = Arc::new(Addresses::new(
67        target_addr,
68        tx_observer_addrs,
69        rx_observer_addrs,
70    ));
71    let listener = TcpListener::bind(listen_addr).await?;
72    let mut incoming = listener.incoming();
73    while let Some(client_stream) = incoming.next().await {
74        if let Ok(client_stream) = client_stream {
75            let addrs = addrs.clone();
76            task::spawn(async move {
77                handle_client(client_stream, addrs).await;
78            });
79        }
80    }
81    Ok(())
82}
83
84async fn handle_client(client_stream: TcpStream, addrs: Arc<Addresses>) {
85    if let Ok(target_stream) = TcpStream::connect(addrs.target_addr).await {
86        let mut client_tx_broadcaster = spawn_observers_write_loop(&addrs.tx_observer_addrs);
87        let mut client_rx_broadcaster = spawn_observers_write_loop(&addrs.rx_observer_addrs);
88        let target_receiver = client_tx_broadcaster.new_receiver();
89        let client_receiver = client_rx_broadcaster.new_receiver();
90        spawn_read_write_loop(target_stream, target_receiver, client_rx_broadcaster);
91        spawn_read_write_loop(client_stream, client_receiver, client_tx_broadcaster);
92    }
93}
94
95fn spawn_observers_write_loop(addrs: &[SocketAddr]) -> Broadcaster {
96    let mut broadcaster = Broadcaster::with_capacity(addrs.len() + 1);
97    for addr in addrs.iter() {
98        let addr = *addr;
99        let receiver = broadcaster.new_receiver();
100        task::spawn(async move {
101            if let Ok(stream) = TcpStream::connect(addr).await {
102                let _ = write_loop(&stream, receiver).await;
103            }
104        });
105    }
106    broadcaster
107}
108
109fn spawn_read_write_loop(stream: TcpStream, rx: Receiver, broadcaster: Broadcaster) {
110    let stream = Arc::new(stream);
111    let (reader, writer) = (stream.clone(), stream);
112    task::spawn(async move {
113        let reader = &*reader;
114        let _ = read_loop(reader, broadcaster).await;
115        let _ = reader.shutdown(Shutdown::Read);
116    });
117    task::spawn(async move {
118        let writer = &*writer;
119        let _ = write_loop(&writer, rx).await;
120        let _ = writer.shutdown(Shutdown::Write);
121    });
122}
123
124async fn write_loop(mut stream: &TcpStream, rx: Receiver) -> io::Result<()> {
125    while let Some(data) = rx.recv().await {
126        stream.write_all(&data).await?;
127    }
128    Ok(())
129}
130
131async fn read_loop(mut stream: &TcpStream, mut broadcaster: Broadcaster) -> io::Result<()> {
132    let mut buf = [0; 65535];
133    loop {
134        let n = stream.read(&mut buf).await?;
135        if n == 0 {
136            break;
137        }
138        broadcaster.write(buf[..n].to_vec());
139    }
140    Ok(())
141}