1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::strategy::LbStrategy;
4use net_pool::{Error, trace};
5use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9pub struct Pool {
14 state: net_pool::pool::BaseState,
15}
16
17impl Pool {
18 pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
19 Pool {
20 state: net_pool::pool::BaseState::new(strategy),
21 }
22 }
23}
24
25impl Default for Pool {
26 fn default() -> Self {
27 Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
28 }
29}
30
31impl<L: LbStrategy + 'static> From<L> for Pool {
32 fn from(value: L) -> Self {
33 Self::new(Arc::new(value))
34 }
35}
36
37impl net_pool::pool::Pool for Pool {
38 net_pool::macros::base_pool_impl! {state}
39}
40
41pub trait TcpPool {
42 fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
43}
44
45impl TcpPool for Pool {
46 async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
47 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
49
50 let tcp = {
51 match self
52 .state
53 .lb_strategy
54 .get_backend(key)
55 .ok_or(Error::NoBackend)
56 {
57 Err(e) => Err(e),
58 Ok(bs) => create_tcp_stream(bs.get_address()).await,
59 }
60 };
61
62 if tcp.is_err() {
63 assert!(
64 self.state
65 .cur_conn
66 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
67 > 0
68 );
69 } else {
70 trace!(
71 "[tcp pool] [incr] current connection count: {}",
72 self.state
73 .cur_conn
74 .load(std::sync::atomic::Ordering::Relaxed)
75 );
76 }
77
78 let pool = self.clone();
79 tcp.map(|t| {
80 TcpStream::new(
81 move || {
82 assert!(
83 pool.state
84 .cur_conn
85 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
86 > 0
87 );
88 trace!(
89 "[tcp pool] [desc] current connection count: {}",
90 pool.state
91 .cur_conn
92 .load(std::sync::atomic::Ordering::Relaxed)
93 );
94 },
95 t,
96 )
97 })
98 }
99}
100
101async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
102 let a = to_socket_addrs(addrs).await?;
103 tokio::net::TcpStream::connect(&a[..])
104 .await
105 .map_err(|e| Error::from_other(e))
106}
107
108async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
109 async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
110 tokio::net::lookup_host(host)
111 .await
112 .map(|a| a.into_iter().collect())
113 }
114
115 match addr {
116 Address::Ori(ori) => inner(ori).await,
117 Address::Addr(addr) => inner(addr).await,
118 }
119}
120
121pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
122 TcpPool::get(pool.clone(), key).await
123}