tcp_relay/
relay.rs

1use net_pool::{debug, info, instrument_debug_span, tokio_spawn, warn2};
2use net_relay::{Builder, Error};
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9/// tcp relay
10pub struct Relay<F, S, P = tcp_pool::Pool>
11where
12    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
13    S: Future<Output = ()>,
14    P: tcp_pool::TcpPool + net_pool::Pool,
15{
16    parts: net_relay::builder::Parts<P, F>,
17    pending: Option<Pin<Box<dyn Future<Output = Result<(), net_relay::Error>> + Send + 'static>>>,
18}
19
20impl<F, S, P> Relay<F, S, P>
21where
22    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
23    S: Future<Output = ()>,
24    P: tcp_pool::TcpPool + net_pool::Pool,
25{
26    pub fn build<B: FnOnce(Builder<P, F>) -> Builder<P, F>>(b: B) -> Result<Self, Error> {
27        let builder = Builder::new();
28        let parts = b(builder).build()?;
29        Ok(Relay {
30            parts,
31            pending: None,
32        })
33    }
34
35    pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
36        &self.parts.bind_addrs
37    }
38
39    pub fn relay_fn(&self) -> Arc<F> {
40        self.parts.relay_fn.as_ref().unwrap().clone()
41    }
42
43    pub fn pool(&self) -> Arc<P> {
44        self.parts.pools[0].clone()
45    }
46
47    /// 设置最大连接数
48    pub fn set_max_conn(&self, max: Option<usize>) {
49        self.pool().set_max_conn(max)
50    }
51
52    /// 设置空闲连接保留时长
53    pub fn set_keepalive(&self, duration: Option<Duration>) {
54        self.pool().set_keepalive(duration)
55    }
56
57    /// 添加一个后端地址
58    pub fn add_backend(&self, addr: net_pool::backend::Address) {
59        self.pool().add_backend(addr)
60    }
61
62    /// 移除一个后端地址
63    pub fn remove_backend(&self, addr: &net_pool::backend::Address) -> bool {
64        self.pool().remove_backend(addr)
65    }
66}
67
68impl<F, S, P> net_relay::Relay for Relay<F, S, P>
69where
70    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S + Send + Sync + 'static,
71    S: Future<Output = ()> + Send + 'static,
72    P: tcp_pool::TcpPool + net_pool::Pool + Send + 'static,
73{
74    fn poll_run(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
75        if self.pending.is_none() {
76            let tuple = (self.bind_addrs().clone(), self.pool(), self.relay_fn());
77
78            self.pending = Some(Box::pin(async move {
79                let listener = tokio::net::TcpListener::bind(tuple.0.as_slice()).await?;
80
81                info!(
82                    "[Tcp Relay] listen on: {:?}",
83                    listener.local_addr().unwrap()
84                );
85
86                loop {
87                    match listener.accept().await {
88                        Ok((client, addr)) => {
89                            let tuple = (tuple.1.clone(), tuple.2.clone());
90                            tokio_spawn! {
91                                instrument_debug_span! {
92                                    async move {
93                                        debug!("[Tcp Relay] connection accepted");
94                                        let res = tuple.1(tuple.0, client, addr).await;
95                                        debug!("[Tcp Relay] connection closed");
96                                        res
97                                    },
98                                    "new_tcp_stream",
99                                    address=addr.to_string()
100                                }
101                            };
102                        }
103                        Err(_e) => {
104                            warn2!("[Tcp Relay] accept from listen, error occurred: {:?}", _e);
105                        }
106                    }
107                }
108            }));
109        }
110
111        self.pending.as_mut().unwrap().as_mut().poll(cx)
112    }
113}
114
115pub async fn default_relay_fn<P: net_pool::Pool + tcp_pool::TcpPool + Send>(
116    pool: Arc<P>,
117    mut client: tokio::net::TcpStream,
118    addr: SocketAddr,
119) {
120    let id = net_pool::utils::socketaddr_to_hash_code(&addr);
121    let mut proxy = match pool.clone().get(&id.to_string()).await {
122        Err(_e) => {
123            warn2!(
124                "[Tcp Relay] get tcp stream from pool, error occurred: {:?}",
125                _e
126            );
127            return;
128        }
129        Ok(t) => t,
130    };
131
132    let proxy_mut: &mut tokio::net::TcpStream = &mut proxy;
133    let _cnt = tokio::io::copy_bidirectional(proxy_mut, &mut client).await;
134
135    debug!("[Tcp Relay] data exchange byte count: {:?}", _cnt);
136}