Skip to main content

tibba_cache/
pool.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{
16    ClusterBuildSnafu, ClusterConnectSnafu, Error, RedisSnafu, SingleBuildSnafu,
17    SingleConnectSnafu, new_redis_config,
18};
19use deadpool_redis::cluster::Hook as ClusterHook;
20use deadpool_redis::{Hook, HookError, Metrics, PoolConfig, Timeouts};
21use redis::aio::ConnectionLike;
22use redis::{Arg, Cmd, Pipeline, RedisFuture, Value};
23use snafu::ResultExt;
24use std::borrow::Cow;
25use std::sync::Arc;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::time::Duration;
28use tibba_config::Config;
29use tracing::info;
30
31use super::LOG_TARGET;
32
33type Result<T> = std::result::Result<T, Error>;
34
35/// `pre_recycle` 的返回类型,兼容单节点和集群 Hook。
36/// 两种 manager 的 `HookError` 均解析为同一具体类型。
37type HookResult = std::result::Result<(), HookError>;
38
/// Timing and outcome of a single Redis command (or pipeline), delivered to the
/// statistics callback registered on the client.
///
/// Derives `Clone` so callback implementations can keep or forward a copy.
#[derive(Debug, Default, Clone)]
pub struct RedisCmdStat {
    /// Command name (first argument of the command, `"unknown"` if it cannot be
    /// decoded), or `"pipeline"` for a batched request.
    pub cmd: String,
    /// Time from the first poll of the request future until it completed.
    pub elapsed: Duration,
    /// Error message if the request failed; `None` on success.
    pub error: Option<String>,
}
45
/// Point-in-time snapshot of pool status plus connection lifecycle counters.
///
/// Derives `Clone` so the snapshot can be stored or passed on by callers.
#[derive(Debug, Default, Clone)]
pub struct RedisStat {
    /// Configured maximum number of connections in the pool.
    pub pool_max_size: usize,
    /// Current number of connections managed by the pool.
    pub pool_size: usize,
    /// Number of idle connections currently available in the pool.
    pub pool_available: usize,
    /// Number of callers currently waiting for a connection.
    pub pool_waiting: usize,
    /// Total physical connections created since the client was built.
    pub conn_created: u64,
    /// Total times a pooled connection passed recycling and was reused.
    pub conn_recycled: u64,
    /// Connections discarded because they exceeded the idle timeout.
    pub conn_idle_timeout_dropped: u64,
    /// Connections discarded because they exceeded the maximum connection age.
    pub conn_max_age_dropped: u64,
}
59
60pub type RedisCmdStatCallback = dyn Fn(RedisCmdStat) + Send + Sync;
61
62/// Redis 连接池枚举,支持单节点和集群两种模式。
63#[derive(Clone)]
64enum RedisPool {
65    /// 单节点 Redis 连接池
66    Single(deadpool_redis::Pool),
67    /// Redis 集群连接池
68    Cluster(deadpool_redis::cluster::Pool),
69}
70
71#[derive(Clone)]
72pub struct RedisClient {
73    pool: RedisPool,
74    stat_callback: Option<&'static RedisCmdStatCallback>,
75    hook_stat: HookStat,
76}
77
78pub struct RedisClientConn {
79    conn: Box<dyn ConnectionLike + Send + Sync>,
80    stat_callback: Option<&'static RedisCmdStatCallback>,
81}
82
83impl RedisClient {
84    /// 从连接池获取一个连接,单节点与集群模式均适用。
85    #[inline]
86    pub async fn conn(&self) -> Result<RedisClientConn> {
87        let conn: Box<dyn ConnectionLike + Send + Sync> = match &self.pool {
88            RedisPool::Single(p) => Box::new(p.get().await.context(SingleConnectSnafu)?),
89            RedisPool::Cluster(p) => Box::new(p.get().await.context(ClusterConnectSnafu)?),
90        };
91
92        Ok(RedisClientConn {
93            conn,
94            stat_callback: self.stat_callback,
95        })
96    }
97
98    /// 设置命令统计回调,支持链式调用。
99    #[must_use]
100    pub fn with_stat_callback(mut self, callback: &'static RedisCmdStatCallback) -> Self {
101        self.stat_callback = Some(callback);
102        self
103    }
104
105    /// 获取连接池状态统计信息。
106    pub fn stat(&self) -> RedisStat {
107        let status = match &self.pool {
108            RedisPool::Single(p) => p.status(),
109            RedisPool::Cluster(p) => p.status(),
110        };
111        let inner = &self.hook_stat.inner;
112        RedisStat {
113            pool_max_size: status.max_size,
114            pool_size: status.size,
115            pool_available: status.available,
116            pool_waiting: status.waiting,
117            conn_created: inner.created.load(Ordering::Relaxed),
118            conn_recycled: inner.recycled.load(Ordering::Relaxed),
119            conn_idle_timeout_dropped: inner.idle_timeout_dropped.load(Ordering::Relaxed),
120            conn_max_age_dropped: inner.max_age_dropped.load(Ordering::Relaxed),
121        }
122    }
123
124    /// 关闭连接池(将连接数收缩至 0)。
125    pub fn close(&self) {
126        match &self.pool {
127            RedisPool::Single(p) => p.close(),
128            RedisPool::Cluster(p) => p.close(),
129        }
130    }
131
132    /// 是否为集群模式。
133    pub fn is_cluster(&self) -> bool {
134        matches!(self.pool, RedisPool::Cluster(_))
135    }
136}
137
138#[inline]
139fn get_command_name(cmd: &Cmd) -> &str {
140    if let Some(Arg::Simple(val)) = cmd.args_iter().next()
141        && let Ok(s) = std::str::from_utf8(val)
142    {
143        return s;
144    }
145    "unknown"
146}
147
148#[inline]
149fn wrap_with_stat<'a, 'cb, T>(
150    name: Cow<'static, str>,
151    fut: RedisFuture<'a, T>,
152    callback: &'cb RedisCmdStatCallback,
153) -> RedisFuture<'a, T>
154where
155    T: Send + 'a,
156    'cb: 'a,
157{
158    Box::pin(async move {
159        let start = std::time::Instant::now();
160        let res = fut.await;
161        let elapsed = start.elapsed();
162        let mut stat = RedisCmdStat {
163            cmd: name.into_owned(),
164            elapsed,
165            ..Default::default()
166        };
167        if let Err(e) = &res {
168            stat.error = Some(e.to_string());
169        }
170        callback(stat);
171        res
172    })
173}
174
175impl ConnectionLike for RedisClientConn {
176    /// 执行单条 Redis 命令,若设置了统计回调则记录耗时与错误。
177    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
178        if let Some(cb) = self.stat_callback {
179            let name = Cow::Owned(get_command_name(cmd).to_owned());
180            let fut = self.conn.req_packed_command(cmd);
181            wrap_with_stat(name, fut, cb)
182        } else {
183            self.conn.req_packed_command(cmd)
184        }
185    }
186
187    /// 以 pipeline 批量执行 Redis 命令,若设置了统计回调则整体计时。
188    fn req_packed_commands<'a>(
189        &'a mut self,
190        cmd: &'a Pipeline,
191        offset: usize,
192        count: usize,
193    ) -> RedisFuture<'a, Vec<Value>> {
194        if let Some(cb) = self.stat_callback {
195            let fut = self.conn.req_packed_commands(cmd, offset, count);
196            wrap_with_stat(Cow::Borrowed("pipeline"), fut, cb)
197        } else {
198            self.conn.req_packed_commands(cmd, offset, count)
199        }
200    }
201
202    /// 获取当前数据库编号,集群模式固定返回 0(不支持多 DB)。
203    fn get_db(&self) -> i64 {
204        0
205    }
206}
207
208/// HookStat 的内部共享状态,通过原子计数器记录连接生命周期事件。
209/// 所有 hook 闭包与 RedisClient 共享同一份实例。
210struct HookStatInner {
211    created: AtomicU64,
212    recycled: AtomicU64,
213    /// 因空闲超时而丢弃的连接数
214    idle_timeout_dropped: AtomicU64,
215    /// 因超过最大存活时间而丢弃的连接数
216    max_age_dropped: AtomicU64,
217}
218
219/// 封装连接池生命周期日志与统计。
220/// 内部通过 Arc 共享,克隆开销极低,可安全分发给各 hook 闭包。
221#[derive(Clone)]
222pub struct HookStat {
223    label: &'static str,
224    max_conn_age: Duration,
225    idle_timeout: Duration,
226    inner: Arc<HookStatInner>,
227}
228
229impl HookStat {
230    fn new(label: &'static str, max_conn_age: Duration, idle_timeout: Duration) -> Self {
231        Self {
232            label,
233            max_conn_age,
234            idle_timeout,
235            inner: Arc::new(HookStatInner {
236                created: AtomicU64::new(0),
237                recycled: AtomicU64::new(0),
238                idle_timeout_dropped: AtomicU64::new(0),
239                max_age_dropped: AtomicU64::new(0),
240            }),
241        }
242    }
243
244    /// 新物理连接建立后回调,累计创建计数并打印日志。
245    fn post_create(&self) {
246        self.inner.created.fetch_add(1, Ordering::Relaxed);
247        info!(target: LOG_TARGET, label = self.label, "new connection");
248    }
249
250    /// 连接回池前回调。超过空闲时限或最大存活时限时丢弃连接并返回 Err。
251    fn pre_recycle(&self, metrics: &Metrics) -> HookResult {
252        let idle = metrics.last_used();
253        if !self.idle_timeout.is_zero() && idle > self.idle_timeout {
254            self.inner
255                .idle_timeout_dropped
256                .fetch_add(1, Ordering::Relaxed);
257            info!(
258                target: LOG_TARGET,
259                label = self.label,
260                idle = idle.as_secs(),
261                "drop connection: idle timeout exceeded"
262            );
263            return Err(HookError::message("drop"));
264        }
265        let age = metrics.age();
266        if !self.max_conn_age.is_zero() && age > self.max_conn_age {
267            self.inner.max_age_dropped.fetch_add(1, Ordering::Relaxed);
268            info!(
269                target: LOG_TARGET,
270                label = self.label,
271                age = age.as_secs(),
272                "drop connection: max age exceeded"
273            );
274            return Err(HookError::message("drop"));
275        }
276        Ok(())
277    }
278
279    /// 连接成功回池后回调,累计复用计数并打印日志。
280    fn post_recycle(&self, metrics: &Metrics) {
281        self.inner.recycled.fetch_add(1, Ordering::Relaxed);
282        info!(
283            target: LOG_TARGET,
284            label = self.label,
285            age = metrics.age().as_secs(),
286            idle = metrics.last_used().as_secs(),
287            "recycle connection"
288        );
289    }
290}
291
292/// 根据配置创建 Redis 客户端(单节点或集群)。
293/// 单节点时使用 deadpool-redis 标准池,多节点时使用集群池。
294pub fn new_redis_client(config: &Config) -> Result<RedisClient> {
295    let redis_config = new_redis_config(config)?;
296    let pool_config = PoolConfig {
297        max_size: redis_config.pool_size as usize,
298        timeouts: Timeouts {
299            wait: Some(redis_config.wait_timeout),
300            create: Some(redis_config.connection_timeout),
301            recycle: Some(redis_config.recycle_timeout),
302        },
303        ..Default::default()
304    };
305
306    let password = redis_config.password.as_deref().unwrap_or_default();
307    let nodes: Vec<_> = redis_config
308        .nodes
309        .iter()
310        .map(|v| {
311            if password.is_empty() {
312                return v.to_string();
313            }
314            v.replace(password, "***")
315        })
316        .collect();
317
318    let is_single = redis_config.nodes.len() <= 1;
319    let hook_stat = HookStat::new(
320        if is_single { "single" } else { "cluster" },
321        redis_config.max_conn_age,
322        redis_config.idle_timeout,
323    );
324
325    let (pool, hook_stat) = if is_single {
326        // 单节点模式
327        let mgr =
328            deadpool_redis::Manager::new(redis_config.nodes[0].as_str()).context(RedisSnafu {
329                category: "new_pool",
330            })?;
331        let pool = deadpool_redis::Pool::builder(mgr)
332            .config(pool_config)
333            .runtime(deadpool_redis::Runtime::Tokio1)
334            .post_create(Hook::sync_fn({
335                let stat = hook_stat.clone();
336                move |_, _| {
337                    stat.post_create();
338                    Ok(())
339                }
340            }))
341            .pre_recycle(Hook::sync_fn({
342                let stat = hook_stat.clone();
343                move |_, m| stat.pre_recycle(m)
344            }))
345            .post_recycle(Hook::sync_fn({
346                let stat = hook_stat.clone();
347                move |_, m| {
348                    stat.post_recycle(m);
349                    Ok(())
350                }
351            }))
352            .build()
353            .context(SingleBuildSnafu)?;
354        (RedisPool::Single(pool), hook_stat)
355    } else {
356        // 集群模式
357        let mut cfg = deadpool_redis::cluster::Config::from_urls(redis_config.nodes.clone());
358        cfg.pool = Some(pool_config);
359        let pool = cfg
360            .builder()
361            .map_err(deadpool_redis::cluster::CreatePoolError::Config)
362            .context(ClusterBuildSnafu)?
363            .runtime(deadpool_redis::cluster::Runtime::Tokio1)
364            .post_create(ClusterHook::sync_fn({
365                let stat = hook_stat.clone();
366                move |_, _| {
367                    stat.post_create();
368                    Ok(())
369                }
370            }))
371            .pre_recycle(ClusterHook::sync_fn({
372                let stat = hook_stat.clone();
373                move |_, m| stat.pre_recycle(m)
374            }))
375            .post_recycle(ClusterHook::sync_fn({
376                let stat = hook_stat.clone();
377                move |_, m| {
378                    stat.post_recycle(m);
379                    Ok(())
380                }
381            }))
382            .build()
383            .map_err(deadpool_redis::cluster::CreatePoolError::Build)
384            .context(ClusterBuildSnafu)?;
385        (RedisPool::Cluster(pool), hook_stat)
386    };
387    info!(target: LOG_TARGET, nodes = nodes.join(","), "connect to redis");
388    Ok(RedisClient {
389        pool,
390        stat_callback: None,
391        hook_stat,
392    })
393}