1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::{debug, instrument_current_span, Strategy, tokio_spawn};
6use std::collections::HashMap;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::{Arc, Mutex};
10use tokio::net::UdpSocket;
11
/// UDP connection pool: hands out cached per-client proxy sockets,
/// choosing the backend for each new flow via the configured `Strategy`.
pub struct Pool {
    // Shared pool bookkeeping (load-balancing strategy, max/current
    // connection counters, keepalive) provided by `net_pool`.
    state: net_pool::pool::BaseState,
    // Live senders keyed by an ID derived from the client address and
    // tagged with the backend's hash (see `create_tx`). Entries are
    // evicted lazily when found closed, or by the per-connection task.
    free_conn_map: Mutex<HashMap<ID, UdpTx>>,
}
22
23impl Pool {
24 pub fn new(strategy: Arc<dyn Strategy>) -> Self {
25 let p = Pool {
26 state: net_pool::pool::BaseState::new(strategy),
27 free_conn_map: Mutex::new(HashMap::new()),
28 };
29 <Pool as net_pool::pool::Pool>::set_keepalive(
30 &p,
31 Some(std::time::Duration::from_secs(60 * 5)),
32 );
33 p
34 }
35}
36
37impl Default for Pool {
38 fn default() -> Self {
39 Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
41 }
42}
43
44impl<L: Strategy + 'static> From<L> for Pool {
45 fn from(value: L) -> Self {
46 Self::new(Arc::new(value))
47 }
48}
49
50impl net_pool::pool::Pool for Pool {
51 net_pool::macros::base_pool_impl! {state}
52
53 fn remove_backend(&self, addr: &Address) -> bool {
54 if self.state.lb_strategy.remove_backend(addr) {
55 self.clear_bs_tx(addr);
57 true
58 } else {
59 false
60 }
61 }
62}
63
64impl Pool {
65 fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
66 let id = a.into();
67 let mut guard = self.free_conn_map.lock().unwrap();
68
69 let tx = guard.get(&id)?;
71 if tx.is_closed() {
72 guard.remove(&id);
73 assert!(
74 self.state
75 .cur_conn
76 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
77 > 0
78 );
79 debug!(
80 "[udp pool] [desc] current socket count: {}",
81 self.state
82 .cur_conn
83 .load(std::sync::atomic::Ordering::Relaxed)
84 );
85 None
86 } else {
87 Some(tx.clone())
88 }
89 }
90
91 fn clear_bs_tx(&self, a: &Address) {
93 let mut guard = self.free_conn_map.lock().unwrap();
94 guard.retain(|k, _| {
95 if k.get_bid() != a.hash_code() {
96 true
97 } else {
98 assert!(
99 self.state
100 .cur_conn
101 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
102 > 0
103 );
104 false
105 }
106 });
107
108 debug!(
109 "[udp pool] [desc] current socket count: {}",
110 self.state
111 .cur_conn
112 .load(std::sync::atomic::Ordering::Relaxed)
113 );
114 }
115
116 fn add_tx(&self, id: ID, tx: UdpTx) {
117 let mut guard = self.free_conn_map.lock().unwrap();
118 guard.insert(id, tx);
119 }
120
121 fn remove_tx(&self, id: ID) {
122 let mut guard = self.free_conn_map.lock().unwrap();
123 if guard.remove(&id).is_some() {
124 assert!(
125 self.state
126 .cur_conn
127 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
128 > 0
129 );
130 debug!(
131 "[udp pool] [desc] current socket count: {}",
132 self.state
133 .cur_conn
134 .load(std::sync::atomic::Ordering::Relaxed)
135 );
136 }
137 }
138
139 async fn create_tx(
140 self: Arc<Self>,
141 a: SocketAddr,
142 client: Option<Arc<UdpSocket>>,
143 ) -> Result<UdpTx, Error> {
144 let mut id = ID::new(&a);
145 let bs = self
146 .state
147 .lb_strategy
148 .get_backend(&id.to_string())
149 .ok_or(Error::NoBackend)?;
150 id.set_bid(bs.hash_code());
151
152 let proxy = create_udp_socket(&bs).await?;
153 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);
154
155 let (tx, conn) = io(a, client, proxy, ka);
156
157 let pool = self.clone();
159 tokio_spawn! {
160 instrument_current_span! {
161 async move {
162 let _res = conn.await;
163 debug!("[udp pool] udp socket close: {:?}", _res);
164 pool.remove_tx(id);
166 }
167 }
168 };
169
170 self.add_tx(id, tx.clone());
174 Ok(tx)
175 }
176}
177
/// Abstraction over a UDP socket pool: produce a `Sender` for a given
/// client address, creating backing resources on demand.
pub trait UdpPool {
    /// Get (or create) a sender for the client at `a`.
    ///
    /// `send` is an optional pre-existing client-side socket that is
    /// forwarded to the connection — presumably used for replies back
    /// to the client; confirm against `connection::Connection`.
    fn get(
        self: Arc<Self>,
        a: SocketAddr,
        send: Option<Arc<UdpSocket>>,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
185
186impl UdpPool for Pool {
187 async fn get(
188 self: Arc<Self>,
189 a: SocketAddr,
190 send: Option<Arc<UdpSocket>>,
191 ) -> Result<Sender, Error> {
192 let tx = match self.get_tx(&a) {
193 Some(s) => Ok(s),
194 None => {
195 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
197 self.clone().create_tx(a, send).await.map(|s| {
198 debug!(
199 "[udp pool] [incr] current socket count: {}",
200 self.state
201 .cur_conn
202 .load(std::sync::atomic::Ordering::Relaxed)
203 );
204 s
205 })
206 }
207 };
208
209 if tx.is_err() {
210 assert!(
211 self.state
212 .cur_conn
213 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
214 > 0
215 );
216 }
217
218 tx.map(|tx| Sender::new(tx))
219 }
220}
221
222pub async fn get<P: UdpPool>(
223 pool: Arc<P>,
224 a: SocketAddr,
225 send: Option<Arc<UdpSocket>>,
226) -> Result<Sender, Error> {
227 UdpPool::get(pool, a, send).await
228}
229
230fn io(
231 client_addr: SocketAddr,
232 client: Option<Arc<UdpSocket>>,
233 proxy: UdpSocket,
234 keepalive: Option<std::time::Duration>,
235) -> (UdpTx, connection::Connection) {
236 let (tx, rx) = tokio::sync::mpsc::channel(20);
237 (
238 tx,
239 connection::Connection::new(client_addr, client, proxy, rx, keepalive),
240 )
241}
242
243async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
244 let recv = UdpSocket::bind("0.0.0.0:0").await?;
245
246 let a = to_socket_addrs(bs.get_address()).await?;
247 recv.connect(&a[..]).await?;
248
249 Ok(recv)
250}
251
252async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
253 match addr {
254 Address::Ori(ori) => tokio::net::lookup_host(ori)
255 .await
256 .map(|a| a.into_iter().collect()),
257 Address::Addr(addr) => Ok(vec![addr.clone()]),
258 }
259}