tcp_pool/
pool.rs

1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::strategy::LbStrategy;
4use net_pool::{Error, trace};
5use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
/// TCP connection pool.
///
/// - The number of connections handed out by the pool is limited by `max_conn`.
/// - The keepalive setting has no effect for this pool.
/// - Connections are not reused: every `get` opens a new connection, and the connection is
///   closed once the returned `TcpStream` is no longer referenced.
pub struct Pool {
    /// Pool state: the load-balancing strategy (`lb_strategy`) and the connection
    /// limit/counter pair (`max_conn` / `cur_conn`).
    state: net_pool::pool::BaseState,
}
16
17impl Pool {
18    pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
19        Pool {
20            state: net_pool::pool::BaseState::new(strategy),
21        }
22    }
23}
24
25impl Default for Pool {
26    fn default() -> Self {
27        Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
28    }
29}
30
// Generic `net_pool` pool interface. The method bodies are generated by `base_pool_impl!`,
// which is given the `state` field; presumably it forwards to `BaseState` — the macro body
// is not visible here, confirm in `net_pool::macros`.
impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}
}
34
/// A pool that can hand out TCP connections.
pub trait TcpPool {
    /// Returns a connected [`TcpStream`] for the backend selected by `key`.
    ///
    /// The receiver is `Arc<Self>` so that an implementation can move a handle to the pool
    /// into the returned stream. The returned future is `Send`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
}
38
39impl TcpPool for Pool {
40    async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
41        // 预分配数量
42        net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
43
44        let tcp = {
45            match self
46                .state
47                .lb_strategy
48                .get_backend(key)
49                .ok_or(Error::NoBackend)
50            {
51                Err(e) => Err(e),
52                Ok(bs) => create_tcp_stream(bs.get_address()).await,
53            }
54        };
55
56        if tcp.is_err() {
57            assert!(
58                self.state
59                    .cur_conn
60                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
61                    > 0
62            );
63        } else {
64            trace!(
65                "[tcp pool] [incr] current connection count: {}",
66                self.state
67                    .cur_conn
68                    .load(std::sync::atomic::Ordering::Relaxed)
69            );
70        }
71
72        let pool = self.clone();
73        tcp.map(|t| {
74            TcpStream::new(
75                move || {
76                    assert!(
77                        pool.state
78                            .cur_conn
79                            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
80                            > 0
81                    );
82                    trace!(
83                        "[tcp pool] [desc] current connection count: {}",
84                        pool.state
85                            .cur_conn
86                            .load(std::sync::atomic::Ordering::Relaxed)
87                    );
88                },
89                t,
90            )
91        })
92    }
93}
94
95async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
96    let a = to_socket_addrs(addrs).await?;
97    tokio::net::TcpStream::connect(&a[..])
98        .await
99        .map_err(|e| Error::from_other(e))
100}
101
102async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
103    async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
104        tokio::net::lookup_host(host)
105            .await
106            .map(|a| a.into_iter().collect())
107    }
108
109    match addr {
110        Address::Ori(ori) => inner(ori).await,
111        Address::Addr(addr) => inner(addr).await,
112    }
113}
114
115pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
116    TcpPool::get(pool.clone(), key).await
117}