1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::strategy::LbStrategy;
6use net_pool::{debug, instrument_current_span, tokio_spawn, trace};
7use std::collections::HashMap;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use tokio::net::UdpSocket;
12
/// UDP connection pool: hands out channel senders to per-client proxy
/// connections, reusing idle ones and creating new ones on demand.
pub struct Pool {
    // Shared pool bookkeeping from the net_pool framework (load-balancing
    // strategy, current/max connection counters, keepalive setting).
    state: net_pool::pool::BaseState,
    // Idle senders keyed by an ID derived from the client SocketAddr (see
    // `get_tx`); entries whose channel has closed are evicted lazily.
    free_conn_map: Mutex<HashMap<ID, UdpTx>>,
}
23
24impl Pool {
25 pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
26 let p = Pool {
27 state: net_pool::pool::BaseState::new(strategy),
28 free_conn_map: Mutex::new(HashMap::new()),
29 };
30 <Pool as net_pool::pool::Pool>::set_keepalive(
31 &p,
32 Some(std::time::Duration::from_secs(60 * 5)),
33 );
34 p
35 }
36}
37
38impl Default for Pool {
39 fn default() -> Self {
40 Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
42 }
43}
44
45impl<L: LbStrategy + 'static> From<L> for Pool {
46 fn from(value: L) -> Self {
47 Self::new(Arc::new(value))
48 }
49}
50
51impl net_pool::pool::Pool for Pool {
52 net_pool::macros::base_pool_impl! {state}
53
54 fn remove_backend(&self, addr: &Address) -> bool {
55 if self.state.lb_strategy.remove_backend(addr) {
56 self.clear_bs_tx(addr);
58 true
59 } else {
60 false
61 }
62 }
63}
64
65impl Pool {
66 fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
67 let id = a.into();
68 let mut guard = self.free_conn_map.lock().unwrap();
69
70 let tx = guard.get(&id)?;
72 if tx.is_closed() {
73 guard.remove(&id);
74 assert!(
75 self.state
76 .cur_conn
77 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
78 > 0
79 );
80 trace!(
81 "[udp pool] [desc] current socket count: {}",
82 self.state
83 .cur_conn
84 .load(std::sync::atomic::Ordering::Relaxed)
85 );
86 None
87 } else {
88 Some(tx.clone())
89 }
90 }
91
92 fn clear_bs_tx(&self, a: &Address) {
94 let mut guard = self.free_conn_map.lock().unwrap();
95 guard.retain(|k, _| {
96 if k.get_bid() != a.hash_code() {
97 true
98 } else {
99 assert!(
100 self.state
101 .cur_conn
102 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
103 > 0
104 );
105 false
106 }
107 });
108
109 trace!(
110 "[udp pool] [desc] current socket count: {}",
111 self.state
112 .cur_conn
113 .load(std::sync::atomic::Ordering::Relaxed)
114 );
115 }
116
117 fn add_tx(&self, id: ID, tx: UdpTx) {
118 let mut guard = self.free_conn_map.lock().unwrap();
119 guard.insert(id, tx);
120 }
121
122 fn remove_tx(&self, id: ID) {
123 let mut guard = self.free_conn_map.lock().unwrap();
124 if guard.remove(&id).is_some() {
125 assert!(
126 self.state
127 .cur_conn
128 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
129 > 0
130 );
131 trace!(
132 "[udp pool] [desc] current socket count: {}",
133 self.state
134 .cur_conn
135 .load(std::sync::atomic::Ordering::Relaxed)
136 );
137 }
138 }
139
140 async fn create_tx(
141 self: Arc<Self>,
142 a: SocketAddr,
143 client: Option<Arc<UdpSocket>>,
144 ) -> Result<UdpTx, Error> {
145 let mut id = ID::new(&a);
146 let bs = self
147 .state
148 .lb_strategy
149 .get_backend(&id.to_string())
150 .ok_or(Error::NoBackend)?;
151 id.set_bid(bs.hash_code());
152
153 let proxy = create_udp_socket(&bs).await?;
154 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);
155
156 let (tx, conn) = io(a, client, proxy, ka);
157
158 let pool = self.clone();
160 tokio_spawn! {
161 instrument_current_span! {
162 async move {
163 let _res = conn.await;
164 debug!("[udp pool] udp socket close: {:?}", _res);
165 pool.remove_tx(id);
167 }
168 }
169 };
170
171 self.add_tx(id, tx.clone());
172 Ok(tx)
173 }
174}
175
/// Entry point for obtaining a [`Sender`] bound to a client address.
pub trait UdpPool {
    /// Returns a sender for client address `a`.
    ///
    /// `send` is an optional shared UDP socket handed through to the
    /// connection (presumably the client-facing socket used for replies —
    /// verify against `connection::Connection::new`).
    fn get(
        self: Arc<Self>,
        a: SocketAddr,
        send: Option<Arc<UdpSocket>>,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
183
184impl UdpPool for Pool {
185 async fn get(
186 self: Arc<Self>,
187 a: SocketAddr,
188 send: Option<Arc<UdpSocket>>,
189 ) -> Result<Sender, Error> {
190 let tx = match self.get_tx(&a) {
191 Some(s) => Ok(s),
192 None => {
193 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
195 self.clone().create_tx(a, send).await.map(|s| {
196 trace!(
197 "[udp pool] [incr] current socket count: {}",
198 self.state
199 .cur_conn
200 .load(std::sync::atomic::Ordering::Relaxed)
201 );
202 s
203 })
204 }
205 };
206
207 if tx.is_err() {
208 assert!(
209 self.state
210 .cur_conn
211 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
212 > 0
213 );
214 }
215
216 tx.map(|tx| Sender::new(tx))
217 }
218}
219
220pub async fn get<P: UdpPool>(
221 pool: Arc<P>,
222 a: SocketAddr,
223 send: Option<Arc<UdpSocket>>,
224) -> Result<Sender, Error> {
225 UdpPool::get(pool, a, send).await
226}
227
228fn io(
229 client_addr: SocketAddr,
230 client: Option<Arc<UdpSocket>>,
231 proxy: UdpSocket,
232 keepalive: Option<std::time::Duration>,
233) -> (UdpTx, connection::Connection) {
234 let (tx, rx) = tokio::sync::mpsc::channel(20);
235 (
236 tx,
237 connection::Connection::new(client_addr, client, proxy, rx, keepalive),
238 )
239}
240
241async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
242 let recv = UdpSocket::bind("0.0.0.0:0").await?;
243
244 let a = to_socket_addrs(bs.get_address()).await?;
245 recv.connect(&a[..]).await?;
246
247 Ok(recv)
248}
249
250async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
251 async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
252 tokio::net::lookup_host(host)
253 .await
254 .map(|a| a.into_iter().collect())
255 }
256
257 match addr {
258 Address::Ori(ori) => inner(ori).await,
259 Address::Addr(addr) => inner(addr).await,
260 }
261}