1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::{debug, Error, Strategy};
4use std::io;
5use std::net::SocketAddr;
6use std::sync::Arc;
7
/// TCP connection pool backed by `net_pool`'s shared `BaseState`.
///
/// Connections are opened on demand by [`TcpPool::get`], which consults
/// `state` for the load-balancing strategy (`lb_strategy`) and for the
/// connection counters (`max_conn` / `cur_conn`).
pub struct Pool {
    // Shared pool bookkeeping; this field name is also passed to
    // `base_pool_impl!` below.
    state: net_pool::pool::BaseState,
}
15
16impl Pool {
17 pub fn new(strategy: Arc<dyn Strategy>) -> Self {
18 Pool {
19 state: net_pool::pool::BaseState::new(strategy),
20 }
21 }
22}
23
24impl Default for Pool {
25 fn default() -> Self {
26 Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
27 }
28}
29
30impl<L: Strategy + 'static> From<L> for Pool {
31 fn from(value: L) -> Self {
32 Self::new(Arc::new(value))
33 }
34}
35
// Implements `net_pool`'s own `Pool` trait for this type. The macro receives
// the name of the field that holds the `BaseState` (`state`); presumably it
// expands to trait methods that delegate to that field — confirm against
// `net_pool::macros::base_pool_impl`.
impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}
}
39
/// A pool that hands out [`TcpStream`] connections selected by a string key.
pub trait TcpPool {
    /// Asynchronously obtains a connection for `key`.
    ///
    /// The pool is taken as `Arc<Self>`; the in-file `Pool` implementation
    /// clones that `Arc` into the returned stream's release callback. The
    /// returned future is `Send`, so it can be spawned on a multi-threaded
    /// runtime.
    ///
    /// # Errors
    /// Failures are reported as `net_pool::Error`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
}
43
44impl TcpPool for Pool {
45 async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
46 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
48
49 let tcp = {
50 match self
51 .state
52 .lb_strategy
53 .get_backend(key)
54 .ok_or(Error::NoBackend)
55 {
56 Err(e) => Err(e),
57 Ok(bs) => create_tcp_stream(bs.get_address()).await,
58 }
59 };
60
61 if tcp.is_err() {
62 assert!(
63 self.state
64 .cur_conn
65 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
66 > 0
67 );
68 } else {
69 debug!(
70 "[tcp pool] [incr] current connection count: {}",
71 self.state
72 .cur_conn
73 .load(std::sync::atomic::Ordering::Relaxed)
74 );
75 }
76
77 let pool = self.clone();
78 tcp.map(|t| {
79 TcpStream::new(
80 move || {
81 assert!(
82 pool.state
83 .cur_conn
84 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
85 > 0
86 );
87 debug!(
88 "[tcp pool] [desc] current connection count: {}",
89 pool.state
90 .cur_conn
91 .load(std::sync::atomic::Ordering::Relaxed)
92 );
93 },
94 t,
95 )
96 })
97 }
98}
99
100async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
101 let a = to_socket_addrs(addrs).await?;
102 tokio::net::TcpStream::connect(&a[..])
103 .await
104 .map_err(|e| Error::from_other(e))
105}
106
107async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
108 async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
109 tokio::net::lookup_host(host)
110 .await
111 .map(|a| a.into_iter().collect())
112 }
113
114 match addr {
115 Address::Ori(ori) => inner(ori).await,
116 Address::Addr(addr) => inner(addr).await,
117 }
118}
119
120pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
121 TcpPool::get(pool.clone(), key).await
122}