Skip to main content

yang_db/redis/
transaction.rs

1use crate::{DbError, RedisClient, RedisValue, Result};
2use redis::FromRedisValue;
3
4/// Redis 事务构建器
5///
6/// 提供类型安全的 Redis 事务操作接口,基于 WATCH/MULTI/EXEC 机制实现乐观锁。
7///
8/// # 特性
9/// - 支持 WATCH 键监视(乐观锁)
10/// - 原子性执行所有命令
11/// - 自动处理 WATCH 冲突并重试
12/// - 支持所有 Redis 命令
13///
14/// # 示例
15/// ```no_run
16/// use yang_db::RedisClient;
17///
18/// #[tokio::main]
19/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
20///     let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
21///     
22///     // 创建事务
23///     let mut tx = client.transaction();
24///     
25///     // 添加命令
26///     tx.set("key1", "value1");
27///     tx.set("key2", "value2");
28///     tx.incr("counter");
29///     
30///     // 执行事务
31///     let results: (String, String, i64) = tx.exec().await?;
32///     println!("事务执行结果: {:?}", results);
33///     
34///     Ok(())
35/// }
36/// ```
37pub struct RedisTransaction {
38    /// Redis 客户端引用
39    client: RedisClient,
40    /// 原生 redis::Pipeline
41    pipe: redis::Pipeline,
42    /// 要监视的键列表
43    watched_keys: Vec<String>,
44}
45
46impl RedisTransaction {
47    /// 创建新的事务
48    ///
49    /// # 参数
50    /// - `client`: Redis 客户端
51    ///
52    /// # 返回
53    /// 新的事务实例
54    pub fn new(client: RedisClient) -> Self {
55        let mut pipe = redis::pipe();
56        pipe.atomic(); // 设置为原子模式(MULTI/EXEC)
57
58        Self {
59            client,
60            pipe,
61            watched_keys: Vec::new(),
62        }
63    }
64
65    /// 监视一个或多个键(用于乐观锁)
66    ///
67    /// # 参数
68    /// - `keys`: 要监视的键列表
69    ///
70    /// # 返回
71    /// 返回 self 以支持链式调用
72    ///
73    /// # 注意
74    /// - 如果被监视的键在事务执行前被修改,事务将被取消并自动重试
75    /// - 必须在添加命令之前调用
76    ///
77    /// # 示例
78    /// ```no_run
79    /// # use yang_db::RedisClient;
80    /// # #[tokio::main]
81    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
82    /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
83    /// let mut tx = client.transaction();
84    /// tx.watch(&["balance".to_string()]);
85    /// # Ok(())
86    /// # }
87    /// ```
88    pub fn watch(&mut self, keys: &[String]) -> &mut Self {
89        self.watched_keys.extend_from_slice(keys);
90        self
91    }
92
93    /// 添加 SET 命令
94    ///
95    /// # 参数
96    /// - `key`: 键
97    /// - `value`: 值
98    ///
99    /// # 返回
100    /// 返回 self 以支持链式调用
101    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
102        self.pipe.set(key.into(), value.into());
103        self
104    }
105
106    /// 添加 GET 命令
107    ///
108    /// # 参数
109    /// - `key`: 键
110    ///
111    /// # 返回
112    /// 返回 self 以支持链式调用
113    pub fn get(&mut self, key: impl Into<String>) -> &mut Self {
114        self.pipe.get(key.into());
115        self
116    }
117
118    /// 添加 DEL 命令
119    ///
120    /// # 参数
121    /// - `keys`: 要删除的键列表
122    ///
123    /// # 返回
124    /// 返回 self 以支持链式调用
125    pub fn del(&mut self, keys: &[String]) -> &mut Self {
126        self.pipe.del(keys);
127        self
128    }
129
130    /// 添加 INCR 命令
131    ///
132    /// # 参数
133    /// - `key`: 键
134    ///
135    /// # 返回
136    /// 返回 self 以支持链式调用
137    pub fn incr(&mut self, key: impl Into<String>) -> &mut Self {
138        self.pipe.incr(key.into(), 1);
139        self
140    }
141
142    /// 添加 DECRBY 命令
143    ///
144    /// # 参数
145    /// - `key`: 键
146    /// - `decrement`: 减少的数量
147    ///
148    /// # 返回
149    /// 返回 self 以支持链式调用
150    pub fn decrby(&mut self, key: impl Into<String>, decrement: i64) -> &mut Self {
151        self.pipe.decr(key.into(), decrement);
152        self
153    }
154
155    /// 添加 HSET 命令
156    ///
157    /// # 参数
158    /// - `key`: 哈希表键
159    /// - `field`: 字段名
160    /// - `value`: 字段值
161    ///
162    /// # 返回
163    /// 返回 self 以支持链式调用
164    pub fn hset(
165        &mut self,
166        key: impl Into<String>,
167        field: impl Into<String>,
168        value: impl Into<String>,
169    ) -> &mut Self {
170        self.pipe.hset(key.into(), field.into(), value.into());
171        self
172    }
173
174    /// 添加 HGET 命令
175    ///
176    /// # 参数
177    /// - `key`: 哈希表键
178    /// - `field`: 字段名
179    ///
180    /// # 返回
181    /// 返回 self 以支持链式调用
182    pub fn hget(&mut self, key: impl Into<String>, field: impl Into<String>) -> &mut Self {
183        self.pipe.hget(key.into(), field.into());
184        self
185    }
186
187    /// 添加 LPUSH 命令
188    ///
189    /// # 参数
190    /// - `key`: 列表键
191    /// - `values`: 要插入的值列表
192    ///
193    /// # 返回
194    /// 返回 self 以支持链式调用
195    pub fn lpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
196        let key_str = key.into();
197        for value in values {
198            self.pipe.lpush(&key_str, value);
199        }
200        self
201    }
202
203    /// 添加 RPUSH 命令
204    ///
205    /// # 参数
206    /// - `key`: 列表键
207    /// - `values`: 要插入的值列表
208    ///
209    /// # 返回
210    /// 返回 self 以支持链式调用
211    pub fn rpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
212        let key_str = key.into();
213        for value in values {
214            self.pipe.rpush(&key_str, value);
215        }
216        self
217    }
218
219    /// 添加 SADD 命令
220    ///
221    /// # 参数
222    /// - `key`: 集合键
223    /// - `members`: 要添加的成员列表
224    ///
225    /// # 返回
226    /// 返回 self 以支持链式调用
227    pub fn sadd(&mut self, key: impl Into<String>, members: &[String]) -> &mut Self {
228        let key_str = key.into();
229        for member in members {
230            self.pipe.sadd(&key_str, member);
231        }
232        self
233    }
234
235    /// 添加 ZADD 命令
236    ///
237    /// # 参数
238    /// - `key`: 有序集合键
239    /// - `members`: (分数, 成员) 元组列表
240    ///
241    /// # 返回
242    /// 返回 self 以支持链式调用
243    pub fn zadd(&mut self, key: impl Into<String>, members: &[(f64, String)]) -> &mut Self {
244        let key_str = key.into();
245        for (score, member) in members {
246            self.pipe.zadd(&key_str, member, *score);
247        }
248        self
249    }
250
251    /// 添加自定义命令
252    ///
253    /// # 参数
254    /// - `cmd`: Redis 命令
255    ///
256    /// # 返回
257    /// 返回 self 以支持链式调用
258    pub fn cmd(&mut self, cmd: redis::Cmd) -> &mut Self {
259        self.pipe.add_command(cmd);
260        self
261    }
262
263    /// 执行事务(类型化版本)
264    ///
265    /// # 类型参数
266    /// - `T`: 实现了 `FromRedisValue` 的类型
267    ///
268    /// # 返回
269    /// - `Ok(T)`: 事务执行成功,返回结果
270    /// - `Err(DbError)`: 事务执行失败
271    ///
272    /// # 错误处理
273    /// - 如果 WATCH 的键被修改,自动重试(最多 100 次)
274    /// - 如果其他错误,直接返回
275    ///
276    /// # 示例
277    /// ```no_run
278    /// # use yang_db::RedisClient;
279    /// # #[tokio::main]
280    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
281    /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
282    /// let mut tx = client.transaction();
283    /// tx.set("key1", "value1")
284    ///   .set("key2", "value2")
285    ///   .get("key1");
286    ///
287    /// // 获取类型化结果
288    /// let results: (String, String, String) = tx.exec().await?;
289    /// # Ok(())
290    /// # }
291    /// ```
292    pub async fn exec<T: FromRedisValue>(self) -> Result<T> {
293        let mut conn = self
294            .client
295            .pool()
296            .get()
297            .await
298            .map_err(|e| DbError::RedisPoolError(format!("获取连接失败: {}", e)))?;
299
300        // 最大重试次数
301        const MAX_RETRIES: usize = 100;
302        let mut retries = 0;
303
304        loop {
305            // WATCH 指定的键
306            if !self.watched_keys.is_empty() {
307                let mut watch_cmd = redis::cmd("WATCH");
308                for key in &self.watched_keys {
309                    watch_cmd.arg(key);
310                }
311                watch_cmd
312                    .query_async::<()>(&mut *conn)
313                    .await
314                    .map_err(|e| DbError::RedisCommandError(format!("WATCH 命令失败: {}", e)))?;
315            }
316
317            // 执行事务
318            match self.pipe.query_async::<T>(&mut *conn).await {
319                Ok(result) => {
320                    return Ok(result);
321                }
322                Err(e) => {
323                    // 检查是否是 WATCH 冲突导致的失败
324                    let err_msg = e.to_string();
325                    if (err_msg.contains("EXECABORT") || err_msg.contains("nil"))
326                        && !self.watched_keys.is_empty()
327                    {
328                        retries += 1;
329                        if retries >= MAX_RETRIES {
330                            return Err(DbError::RedisCommandError(format!(
331                                "事务执行失败,已重试 {} 次: {}",
332                                MAX_RETRIES, e
333                            )));
334                        }
335                        // WATCH 冲突,重试
336                        continue;
337                    } else {
338                        // 其他错误,直接返回
339                        return Err(DbError::RedisCommandError(format!("事务执行失败: {}", e)));
340                    }
341                }
342            }
343        }
344    }
345
346    /// 执行事务(兼容模式)
347    ///
348    /// # 返回
349    /// - `Ok(Vec<RedisValue>)`: 事务执行成功,返回结果列表
350    /// - `Err(DbError)`: 事务执行失败
351    ///
352    /// # 示例
353    /// ```no_run
354    /// # use yang_db::RedisClient;
355    /// # #[tokio::main]
356    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
357    /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
358    /// let mut tx = client.transaction();
359    /// tx.set("key1", "value1")
360    ///   .set("key2", "value2")
361    ///   .incr("counter");
362    ///
363    /// let results = tx.execute().await?;
364    /// println!("事务执行结果: {:?}", results);
365    /// # Ok(())
366    /// # }
367    /// ```
368    pub async fn execute(self) -> Result<Vec<RedisValue>> {
369        let results: Vec<redis::Value> = self.exec().await?;
370        Ok(results.into_iter().map(RedisValue::from).collect())
371    }
372
373    /// 获取事务中的命令数量
374    ///
375    /// # 返回
376    /// 事务中的命令数量
377    pub fn len(&self) -> usize {
378        self.pipe.cmd_iter().count()
379    }
380
381    /// 检查事务是否为空
382    ///
383    /// # 返回
384    /// - `true`: 事务为空
385    /// - `false`: 事务不为空
386    pub fn is_empty(&self) -> bool {
387        self.len() == 0
388    }
389}
390
391#[cfg(test)]
392mod tests {
393    #[test]
394    fn test_transaction_creation() {
395        // 注意:这里只测试结构体创建,不测试实际连接
396        // 实际连接测试在集成测试中进行
397    }
398}