1use net_relay::{Builder, Error};
2use std::net::SocketAddr;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use std::time::Duration;
7use tcp_pool::net_pool::{Pool, debug, info, instrument_debug_span, tokio_spawn, warn2};
8
9pub struct Relay<F, S, P = tcp_pool::Pool>
11where
12 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
13 S: Future<Output = ()>,
14 P: tcp_pool::TcpPool + Pool,
15{
16 parts: net_relay::builder::Parts<P, F>,
17 pending: Option<Pin<Box<dyn Future<Output = Result<(), net_relay::Error>> + Send + 'static>>>,
18}
19
20impl<F, S, P> Relay<F, S, P>
21where
22 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
23 S: Future<Output = ()>,
24 P: tcp_pool::TcpPool + Pool,
25{
26 pub fn build<B: FnOnce(Builder<P, F>) -> Builder<P, F>>(b: B) -> Result<Self, Error> {
27 let builder = Builder::new();
28 let parts = b(builder).build()?;
29 Ok(Relay {
30 parts,
31 pending: None,
32 })
33 }
34
35 pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
36 &self.parts.bind_addrs
37 }
38
39 pub fn relay_fn(&self) -> Arc<F> {
40 self.parts.relay_fn.as_ref().unwrap().clone()
41 }
42
43 pub fn pool(&self) -> Arc<P> {
44 self.parts.pools[0].clone()
45 }
46
47 pub fn set_max_conn(&self, max: Option<usize>) {
49 self.pool().set_max_conn(max)
50 }
51
52 pub fn set_keepalive(&self, duration: Option<Duration>) {
54 self.pool().set_keepalive(duration)
55 }
56}
57
58impl<F, S, P> net_relay::Relay for Relay<F, S, P>
59where
60 F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S + Send + Sync + 'static,
61 S: Future<Output = ()> + Send + 'static,
62 P: tcp_pool::TcpPool + Pool + Send + 'static,
63{
64 fn poll_run(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
65 if self.pending.is_none() {
66 let tuple = (self.bind_addrs().clone(), self.pool(), self.relay_fn());
67
68 self.pending = Some(Box::pin(async move {
69 let listener = tokio::net::TcpListener::bind(tuple.0.as_slice()).await?;
70
71 info!(
72 "[Tcp Relay] listen on: {:?}",
73 listener.local_addr().unwrap()
74 );
75
76 loop {
77 match listener.accept().await {
78 Ok((client, addr)) => {
79 let tuple = (tuple.1.clone(), tuple.2.clone());
80 tokio_spawn! {
81 instrument_debug_span! {
82 async move {
83 debug!("[Tcp Relay] connection accepted");
84 let res = tuple.1(tuple.0, client, addr).await;
85 debug!("[Tcp Relay] connection closed");
86 res
87 },
88 "new_tcp_stream",
89 address=addr.to_string()
90 }
91 };
92 }
93 Err(_e) => {
94 warn2!("[Tcp Relay] accept from listen, error occurred: {:?}", _e);
95 }
96 }
97 }
98 }));
99 }
100
101 self.pending.as_mut().unwrap().as_mut().poll(cx)
102 }
103}
104
105pub async fn default_relay_fn<P: Pool + tcp_pool::TcpPool + Send>(
106 pool: Arc<P>,
107 mut client: tokio::net::TcpStream,
108 addr: SocketAddr,
109) {
110 let id = tcp_pool::net_pool::utils::socketaddr_to_hash_code(&addr);
111 let mut proxy = match pool.clone().get(&id.to_string()).await {
112 Err(_e) => {
113 warn2!(
114 "[Tcp Relay] get tcp stream from pool, error occurred: {:?}",
115 _e
116 );
117 return;
118 }
119 Ok(t) => t,
120 };
121
122 let proxy_mut: &mut tokio::net::TcpStream = &mut proxy;
123 let _cnt = tokio::io::copy_bidirectional(proxy_mut, &mut client).await;
124
125 debug!("[Tcp Relay] data exchange byte count: {:?}", _cnt);
126}