1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::strategy::LbStrategy;
4use net_pool::{Error, trace};
5use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9pub struct Pool {
14 state: net_pool::pool::BaseState,
15}
16
17impl Pool {
18 pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
19 Pool {
20 state: net_pool::pool::BaseState::new(strategy),
21 }
22 }
23}
24
25impl Default for Pool {
26 fn default() -> Self {
27 Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
28 }
29}
30
// The `net_pool::pool::Pool` trait items are generated by the
// `base_pool_impl!` macro, which is given the name of the field holding the
// shared state (`state`). See that macro for exactly which methods it emits.
impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}
}
34
/// A pool that can hand out [`TcpStream`]s.
pub trait TcpPool {
    /// Obtains a [`TcpStream`] for `key`.
    ///
    /// What `key` means is implementation-defined. `Pool`'s implementation
    /// passes it to its load-balancing strategy to choose a backend.
    ///
    /// `self` is taken as an `Arc` so an implementation can keep the pool
    /// alive inside the returned stream (`Pool` does this to release its
    /// connection slot later). The returned future is `Send`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
}
38
39impl TcpPool for Pool {
40 async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
41 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
43
44 let tcp = {
45 match self
46 .state
47 .lb_strategy
48 .get_backend(key)
49 .ok_or(Error::NoBackend)
50 {
51 Err(e) => Err(e),
52 Ok(bs) => create_tcp_stream(bs.get_address()).await,
53 }
54 };
55
56 if tcp.is_err() {
57 assert!(
58 self.state
59 .cur_conn
60 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
61 > 0
62 );
63 } else {
64 trace!(
65 "[tcp pool] [incr] current connection count: {}",
66 self.state
67 .cur_conn
68 .load(std::sync::atomic::Ordering::Relaxed)
69 );
70 }
71
72 let pool = self.clone();
73 tcp.map(|t| {
74 TcpStream::new(
75 move || {
76 assert!(
77 pool.state
78 .cur_conn
79 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
80 > 0
81 );
82 trace!(
83 "[tcp pool] [desc] current connection count: {}",
84 pool.state
85 .cur_conn
86 .load(std::sync::atomic::Ordering::Relaxed)
87 );
88 },
89 t,
90 )
91 })
92 }
93}
94
95async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
96 let a = to_socket_addrs(addrs).await?;
97 tokio::net::TcpStream::connect(&a[..])
98 .await
99 .map_err(|e| Error::from_other(e))
100}
101
102async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
103 async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
104 tokio::net::lookup_host(host)
105 .await
106 .map(|a| a.into_iter().collect())
107 }
108
109 match addr {
110 Address::Ori(ori) => inner(ori).await,
111 Address::Addr(addr) => inner(addr).await,
112 }
113}
114
115pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
116 TcpPool::get(pool.clone(), key).await
117}