1use net_pool::{debug, info, instrument_debug_span, tokio_spawn, warn2};
2use net_relay::{Builder, Error};
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9pub struct Relay<F, S, P = tcp_pool::Pool>
11where
12 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
13 S: Future<Output = ()>,
14 P: tcp_pool::TcpPool + net_pool::Pool,
15{
16 parts: net_relay::builder::Parts<P, F>,
17 pending: Option<Pin<Box<dyn Future<Output = Result<(), net_relay::Error>> + Send + 'static>>>,
18}
19
20impl<F, S, P> Relay<F, S, P>
21where
22 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
23 S: Future<Output = ()>,
24 P: tcp_pool::TcpPool + net_pool::Pool,
25{
26 pub fn build<B: FnOnce(Builder<P, F>) -> Builder<P, F>>(b: B) -> Result<Self, Error> {
27 let builder = Builder::new();
28 let parts = b(builder).build()?;
29 Ok(Relay {
30 parts,
31 pending: None,
32 })
33 }
34
35 pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
36 &self.parts.bind_addrs
37 }
38
39 pub fn relay_fn(&self) -> Arc<F> {
40 self.parts.relay_fn.as_ref().unwrap().clone()
41 }
42
43 pub fn pool(&self) -> Arc<P> {
44 self.parts.pools[0].clone()
45 }
46
47 pub fn set_max_conn(&self, max: Option<usize>) {
49 self.pool().set_max_conn(max)
50 }
51
52 pub fn set_keepalive(&self, duration: Option<Duration>) {
54 self.pool().set_keepalive(duration)
55 }
56
57 pub fn add_backend(&self, addr: net_pool::backend::Address) {
59 self.pool().add_backend(addr)
60 }
61
62 pub fn remove_backend(&self, addr: &net_pool::backend::Address) -> bool {
64 self.pool().remove_backend(addr)
65 }
66}
67
68impl<F, S, P> net_relay::Relay for Relay<F, S, P>
69where
70 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S + Send + Sync + 'static,
71 S: Future<Output = ()> + Send + 'static,
72 P: tcp_pool::TcpPool + net_pool::Pool + Send + 'static,
73{
74 fn poll_run(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
75 if self.pending.is_none() {
76 let tuple = (self.bind_addrs().clone(), self.pool(), self.relay_fn());
77
78 self.pending = Some(Box::pin(async move {
79 let listener = tokio::net::TcpListener::bind(tuple.0.as_slice()).await?;
80
81 info!(
82 "[Tcp Relay] listen on: {:?}",
83 listener.local_addr().unwrap()
84 );
85
86 loop {
87 match listener.accept().await {
88 Ok((client, addr)) => {
89 let tuple = (tuple.1.clone(), tuple.2.clone());
90 tokio_spawn! {
91 instrument_debug_span! {
92 async move {
93 debug!("[Tcp Relay] connection accepted");
94 let res = tuple.1(tuple.0, client, addr).await;
95 debug!("[Tcp Relay] connection closed");
96 res
97 },
98 "new_tcp_stream",
99 address=addr.to_string()
100 }
101 };
102 }
103 Err(_e) => {
104 warn2!("[Tcp Relay] accept from listen, error occurred: {:?}", _e);
105 }
106 }
107 }
108 }));
109 }
110
111 self.pending.as_mut().unwrap().as_mut().poll(cx)
112 }
113}
114
115pub async fn default_relay_fn<P: net_pool::Pool + tcp_pool::TcpPool + Send>(
116 pool: Arc<P>,
117 mut client: tokio::net::TcpStream,
118 addr: SocketAddr,
119) {
120 let id = net_pool::utils::socketaddr_to_hash_code(&addr);
121 let mut proxy = match pool.clone().get(&id.to_string()).await {
122 Err(_e) => {
123 warn2!(
124 "[Tcp Relay] get tcp stream from pool, error occurred: {:?}",
125 _e
126 );
127 return;
128 }
129 Ok(t) => t,
130 };
131
132 let proxy_mut: &mut tokio::net::TcpStream = &mut proxy;
133 let _cnt = tokio::io::copy_bidirectional(proxy_mut, &mut client).await;
134
135 debug!("[Tcp Relay] data exchange byte count: {:?}", _cnt);
136}