yang_db/redis/transaction.rs
1use crate::{DbError, RedisClient, RedisValue, Result};
2use redis::FromRedisValue;
3
4/// Redis 事务构建器
5///
6/// 提供类型安全的 Redis 事务操作接口,基于 WATCH/MULTI/EXEC 机制实现乐观锁。
7///
8/// # 特性
9/// - 支持 WATCH 键监视(乐观锁)
10/// - 原子性执行所有命令
11/// - 自动处理 WATCH 冲突并重试
12/// - 支持所有 Redis 命令
13///
14/// # 示例
15/// ```no_run
16/// use yang_db::RedisClient;
17///
18/// #[tokio::main]
19/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
20/// let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
21///
22/// // 创建事务
23/// let mut tx = client.transaction();
24///
25/// // 添加命令
26/// tx.set("key1", "value1");
27/// tx.set("key2", "value2");
28/// tx.incr("counter");
29///
30/// // 执行事务
31/// let results: (String, String, i64) = tx.exec().await?;
32/// println!("事务执行结果: {:?}", results);
33///
34/// Ok(())
35/// }
36/// ```
37pub struct RedisTransaction {
38 /// Redis 客户端引用
39 client: RedisClient,
40 /// 原生 redis::Pipeline
41 pipe: redis::Pipeline,
42 /// 要监视的键列表
43 watched_keys: Vec<String>,
44}
45
46impl RedisTransaction {
47 /// 创建新的事务
48 ///
49 /// # 参数
50 /// - `client`: Redis 客户端
51 ///
52 /// # 返回
53 /// 新的事务实例
54 pub fn new(client: RedisClient) -> Self {
55 let mut pipe = redis::pipe();
56 pipe.atomic(); // 设置为原子模式(MULTI/EXEC)
57
58 Self {
59 client,
60 pipe,
61 watched_keys: Vec::new(),
62 }
63 }
64
65 /// 监视一个或多个键(用于乐观锁)
66 ///
67 /// # 参数
68 /// - `keys`: 要监视的键列表
69 ///
70 /// # 返回
71 /// 返回 self 以支持链式调用
72 ///
73 /// # 注意
74 /// - 如果被监视的键在事务执行前被修改,事务将被取消并自动重试
75 /// - 必须在添加命令之前调用
76 ///
77 /// # 示例
78 /// ```no_run
79 /// # use yang_db::RedisClient;
80 /// # #[tokio::main]
81 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
82 /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
83 /// let mut tx = client.transaction();
84 /// tx.watch(&["balance".to_string()]);
85 /// # Ok(())
86 /// # }
87 /// ```
88 pub fn watch(&mut self, keys: &[String]) -> &mut Self {
89 self.watched_keys.extend_from_slice(keys);
90 self
91 }
92
93 /// 添加 SET 命令
94 ///
95 /// # 参数
96 /// - `key`: 键
97 /// - `value`: 值
98 ///
99 /// # 返回
100 /// 返回 self 以支持链式调用
101 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
102 self.pipe.set(key.into(), value.into());
103 self
104 }
105
106 /// 添加 GET 命令
107 ///
108 /// # 参数
109 /// - `key`: 键
110 ///
111 /// # 返回
112 /// 返回 self 以支持链式调用
113 pub fn get(&mut self, key: impl Into<String>) -> &mut Self {
114 self.pipe.get(key.into());
115 self
116 }
117
118 /// 添加 DEL 命令
119 ///
120 /// # 参数
121 /// - `keys`: 要删除的键列表
122 ///
123 /// # 返回
124 /// 返回 self 以支持链式调用
125 pub fn del(&mut self, keys: &[String]) -> &mut Self {
126 self.pipe.del(keys);
127 self
128 }
129
130 /// 添加 INCR 命令
131 ///
132 /// # 参数
133 /// - `key`: 键
134 ///
135 /// # 返回
136 /// 返回 self 以支持链式调用
137 pub fn incr(&mut self, key: impl Into<String>) -> &mut Self {
138 self.pipe.incr(key.into(), 1);
139 self
140 }
141
142 /// 添加 DECRBY 命令
143 ///
144 /// # 参数
145 /// - `key`: 键
146 /// - `decrement`: 减少的数量
147 ///
148 /// # 返回
149 /// 返回 self 以支持链式调用
150 pub fn decrby(&mut self, key: impl Into<String>, decrement: i64) -> &mut Self {
151 self.pipe.decr(key.into(), decrement);
152 self
153 }
154
155 /// 添加 HSET 命令
156 ///
157 /// # 参数
158 /// - `key`: 哈希表键
159 /// - `field`: 字段名
160 /// - `value`: 字段值
161 ///
162 /// # 返回
163 /// 返回 self 以支持链式调用
164 pub fn hset(
165 &mut self,
166 key: impl Into<String>,
167 field: impl Into<String>,
168 value: impl Into<String>,
169 ) -> &mut Self {
170 self.pipe.hset(key.into(), field.into(), value.into());
171 self
172 }
173
174 /// 添加 HGET 命令
175 ///
176 /// # 参数
177 /// - `key`: 哈希表键
178 /// - `field`: 字段名
179 ///
180 /// # 返回
181 /// 返回 self 以支持链式调用
182 pub fn hget(&mut self, key: impl Into<String>, field: impl Into<String>) -> &mut Self {
183 self.pipe.hget(key.into(), field.into());
184 self
185 }
186
187 /// 添加 LPUSH 命令
188 ///
189 /// # 参数
190 /// - `key`: 列表键
191 /// - `values`: 要插入的值列表
192 ///
193 /// # 返回
194 /// 返回 self 以支持链式调用
195 pub fn lpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
196 let key_str = key.into();
197 for value in values {
198 self.pipe.lpush(&key_str, value);
199 }
200 self
201 }
202
203 /// 添加 RPUSH 命令
204 ///
205 /// # 参数
206 /// - `key`: 列表键
207 /// - `values`: 要插入的值列表
208 ///
209 /// # 返回
210 /// 返回 self 以支持链式调用
211 pub fn rpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
212 let key_str = key.into();
213 for value in values {
214 self.pipe.rpush(&key_str, value);
215 }
216 self
217 }
218
219 /// 添加 SADD 命令
220 ///
221 /// # 参数
222 /// - `key`: 集合键
223 /// - `members`: 要添加的成员列表
224 ///
225 /// # 返回
226 /// 返回 self 以支持链式调用
227 pub fn sadd(&mut self, key: impl Into<String>, members: &[String]) -> &mut Self {
228 let key_str = key.into();
229 for member in members {
230 self.pipe.sadd(&key_str, member);
231 }
232 self
233 }
234
235 /// 添加 ZADD 命令
236 ///
237 /// # 参数
238 /// - `key`: 有序集合键
239 /// - `members`: (分数, 成员) 元组列表
240 ///
241 /// # 返回
242 /// 返回 self 以支持链式调用
243 pub fn zadd(&mut self, key: impl Into<String>, members: &[(f64, String)]) -> &mut Self {
244 let key_str = key.into();
245 for (score, member) in members {
246 self.pipe.zadd(&key_str, member, *score);
247 }
248 self
249 }
250
251 /// 添加自定义命令
252 ///
253 /// # 参数
254 /// - `cmd`: Redis 命令
255 ///
256 /// # 返回
257 /// 返回 self 以支持链式调用
258 pub fn cmd(&mut self, cmd: redis::Cmd) -> &mut Self {
259 self.pipe.add_command(cmd);
260 self
261 }
262
263 /// 执行事务(类型化版本)
264 ///
265 /// # 类型参数
266 /// - `T`: 实现了 `FromRedisValue` 的类型
267 ///
268 /// # 返回
269 /// - `Ok(T)`: 事务执行成功,返回结果
270 /// - `Err(DbError)`: 事务执行失败
271 ///
272 /// # 错误处理
273 /// - 如果 WATCH 的键被修改,自动重试(最多 100 次)
274 /// - 如果其他错误,直接返回
275 ///
276 /// # 示例
277 /// ```no_run
278 /// # use yang_db::RedisClient;
279 /// # #[tokio::main]
280 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
281 /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
282 /// let mut tx = client.transaction();
283 /// tx.set("key1", "value1")
284 /// .set("key2", "value2")
285 /// .get("key1");
286 ///
287 /// // 获取类型化结果
288 /// let results: (String, String, String) = tx.exec().await?;
289 /// # Ok(())
290 /// # }
291 /// ```
292 pub async fn exec<T: FromRedisValue>(self) -> Result<T> {
293 let mut conn = self
294 .client
295 .pool()
296 .get()
297 .await
298 .map_err(|e| DbError::RedisPoolError(format!("获取连接失败: {}", e)))?;
299
300 // 最大重试次数
301 const MAX_RETRIES: usize = 100;
302 let mut retries = 0;
303
304 loop {
305 // WATCH 指定的键
306 if !self.watched_keys.is_empty() {
307 let mut watch_cmd = redis::cmd("WATCH");
308 for key in &self.watched_keys {
309 watch_cmd.arg(key);
310 }
311 watch_cmd
312 .query_async::<()>(&mut *conn)
313 .await
314 .map_err(|e| DbError::RedisCommandError(format!("WATCH 命令失败: {}", e)))?;
315 }
316
317 // 执行事务
318 match self.pipe.query_async::<T>(&mut *conn).await {
319 Ok(result) => {
320 return Ok(result);
321 }
322 Err(e) => {
323 // 检查是否是 WATCH 冲突导致的失败
324 let err_msg = e.to_string();
325 if (err_msg.contains("EXECABORT") || err_msg.contains("nil"))
326 && !self.watched_keys.is_empty()
327 {
328 retries += 1;
329 if retries >= MAX_RETRIES {
330 return Err(DbError::RedisCommandError(format!(
331 "事务执行失败,已重试 {} 次: {}",
332 MAX_RETRIES, e
333 )));
334 }
335 // WATCH 冲突,重试
336 continue;
337 } else {
338 // 其他错误,直接返回
339 return Err(DbError::RedisCommandError(format!("事务执行失败: {}", e)));
340 }
341 }
342 }
343 }
344 }
345
346 /// 执行事务(兼容模式)
347 ///
348 /// # 返回
349 /// - `Ok(Vec<RedisValue>)`: 事务执行成功,返回结果列表
350 /// - `Err(DbError)`: 事务执行失败
351 ///
352 /// # 示例
353 /// ```no_run
354 /// # use yang_db::RedisClient;
355 /// # #[tokio::main]
356 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
357 /// # let client = RedisClient::connect("redis://127.0.0.1:6379").await?;
358 /// let mut tx = client.transaction();
359 /// tx.set("key1", "value1")
360 /// .set("key2", "value2")
361 /// .incr("counter");
362 ///
363 /// let results = tx.execute().await?;
364 /// println!("事务执行结果: {:?}", results);
365 /// # Ok(())
366 /// # }
367 /// ```
368 pub async fn execute(self) -> Result<Vec<RedisValue>> {
369 let results: Vec<redis::Value> = self.exec().await?;
370 Ok(results.into_iter().map(RedisValue::from).collect())
371 }
372
373 /// 获取事务中的命令数量
374 ///
375 /// # 返回
376 /// 事务中的命令数量
377 pub fn len(&self) -> usize {
378 self.pipe.cmd_iter().count()
379 }
380
381 /// 检查事务是否为空
382 ///
383 /// # 返回
384 /// - `true`: 事务为空
385 /// - `false`: 事务不为空
386 pub fn is_empty(&self) -> bool {
387 self.len() == 0
388 }
389}
390
391#[cfg(test)]
392mod tests {
393 #[test]
394 fn test_transaction_creation() {
395 // 注意:这里只测试结构体创建,不测试实际连接
396 // 实际连接测试在集成测试中进行
397 }
398}