tcp_pool/
pool.rs

1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::strategy::LbStrategy;
4use net_pool::{Error, trace};
5use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
/// TCP connection pool.
///
/// The number of connections handed out by the pool is limited by `max_conn`.
/// The keepalive setting has no effect on this pool.
/// Connections are not reused: once the TCP stream returned by `get` is no longer
/// referenced, the connection is closed.
pub struct Pool {
    // Shared pool bookkeeping provided by `net_pool`; this file reads its `max_conn`,
    // `cur_conn` and `lb_strategy` fields.
    state: net_pool::pool::BaseState,
}
16
17impl Pool {
18    pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
19        Pool {
20            state: net_pool::pool::BaseState::new(strategy),
21        }
22    }
23}
24
25impl Default for Pool {
26    fn default() -> Self {
27        Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
28    }
29}
30
31impl<L: LbStrategy + 'static> From<L> for Pool {
32    fn from(value: L) -> Self {
33        Self::new(Arc::new(value))
34    }
35}
36
// Implements the `net_pool::pool::Pool` trait for this pool. The trait items come from the
// `base_pool_impl!` macro, which is handed the name of the field that stores the base state.
// NOTE(review): presumably the macro forwards every trait method to `self.state` — confirm
// against the macro definition in `net_pool::macros`.
impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}
}
40
/// Interface for pools that can hand out a [`TcpStream`] for a given key.
pub trait TcpPool {
    /// Asynchronously produces a [`TcpStream`] for `key`, or an [`Error`] on failure.
    ///
    /// The receiver is an `Arc<Self>`, so the pool must be shared through an `Arc` to call
    /// this method. The returned future is required to be `Send`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
}
44
45impl TcpPool for Pool {
46    async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
47        // 预分配数量
48        net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
49
50        let tcp = {
51            match self
52                .state
53                .lb_strategy
54                .get_backend(key)
55                .ok_or(Error::NoBackend)
56            {
57                Err(e) => Err(e),
58                Ok(bs) => create_tcp_stream(bs.get_address()).await,
59            }
60        };
61
62        if tcp.is_err() {
63            assert!(
64                self.state
65                    .cur_conn
66                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
67                    > 0
68            );
69        } else {
70            trace!(
71                "[tcp pool] [incr] current connection count: {}",
72                self.state
73                    .cur_conn
74                    .load(std::sync::atomic::Ordering::Relaxed)
75            );
76        }
77
78        let pool = self.clone();
79        tcp.map(|t| {
80            TcpStream::new(
81                move || {
82                    assert!(
83                        pool.state
84                            .cur_conn
85                            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
86                            > 0
87                    );
88                    trace!(
89                        "[tcp pool] [desc] current connection count: {}",
90                        pool.state
91                            .cur_conn
92                            .load(std::sync::atomic::Ordering::Relaxed)
93                    );
94                },
95                t,
96            )
97        })
98    }
99}
100
101async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
102    let a = to_socket_addrs(addrs).await?;
103    tokio::net::TcpStream::connect(&a[..])
104        .await
105        .map_err(|e| Error::from_other(e))
106}
107
108async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
109    async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
110        tokio::net::lookup_host(host)
111            .await
112            .map(|a| a.into_iter().collect())
113    }
114
115    match addr {
116        Address::Ori(ori) => inner(ori).await,
117        Address::Addr(addr) => inner(addr).await,
118    }
119}
120
121pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
122    TcpPool::get(pool.clone(), key).await
123}