1#[cfg(feature = "redis-graph")]
2use crate::commands::GraphCommands;
3use crate::{
4 client::{BatchPreparedCommand, Client, PreparedCommand},
5 commands::{
6 BitmapCommands, BloomCommands, CountMinSketchCommands, CuckooCommands, GenericCommands,
7 GeoCommands, HashCommands, HyperLogLogCommands, JsonCommands, ListCommands,
8 ScriptingCommands, SearchCommands, ServerCommands, SetCommands, SortedSetCommands,
9 StreamCommands, StringCommands, TDigestCommands, TimeSeriesCommands, TopKCommands,
10 VectorSetCommands,
11 },
12 resp::{cmd, Command, RespDeserializer, Response},
13 Error, Result,
14};
15use serde::{
16 de::{self, DeserializeOwned, DeserializeSeed, IgnoredAny, SeqAccess, Visitor},
17 forward_to_deserialize_any, Deserializer,
18};
19use std::{fmt, marker::PhantomData};
20
21pub struct Transaction {
23 client: Client,
24 commands: Vec<Command>,
25 forget_flags: Vec<bool>,
26 retry_on_error: Option<bool>,
27}
28
29impl Transaction {
30 pub(crate) fn new(client: Client) -> Self {
31 Self {
32 client,
33 commands: vec![cmd("MULTI")],
34 forget_flags: Vec::new(),
35 retry_on_error: None,
36 }
37 }
38
39 pub fn retry_on_error(&mut self, retry_on_error: bool) {
43 self.retry_on_error = Some(retry_on_error);
44 }
45
46 pub fn queue(&mut self, command: Command) {
48 self.commands.push(command);
49 self.forget_flags.push(false);
50 }
51
52 pub fn forget(&mut self, command: Command) {
54 self.commands.push(command);
55 self.forget_flags.push(true);
56 }
57
58 pub async fn execute<T: DeserializeOwned>(mut self) -> Result<T> {
94 self.commands.push(cmd("EXEC"));
95
96 let num_commands = self.commands.len();
97
98 let results = self
99 .client
100 .send_batch(self.commands, self.retry_on_error)
101 .await?;
102
103 let mut iter = results.into_iter();
104
105 for _ in 0..num_commands - 1 {
107 if let Some(resp_buf) = iter.next() {
108 resp_buf.to::<()>()?;
109 }
110 }
111
112 if let Some(result) = iter.next() {
114 let mut deserializer = RespDeserializer::new(&result);
115 match TransactionResultSeed::new(self.forget_flags).deserialize(&mut deserializer) {
116 Ok(Some(t)) => Ok(t),
117 Ok(None) => Err(Error::Aborted),
118 Err(e) => Err(e),
119 }
120 } else {
121 Err(Error::Client(
122 "Unexpected result for transaction".to_owned(),
123 ))
124 }
125 }
126}
127
/// Seed that turns the `EXEC` reply of a transaction into `Option<T>`,
/// skipping the results of forgotten commands.
///
/// The `T: DeserializeOwned` requirement lives on the impls that need it,
/// not on the type itself.
struct TransactionResultSeed<T> {
    /// Ties the seed to its target type without storing a value of it.
    phantom: PhantomData<T>,
    /// One flag per queued command, in order; `true` means its result is skipped.
    forget_flags: Vec<bool>,
}
132
133impl<T: DeserializeOwned> TransactionResultSeed<T> {
134 pub fn new(forget_flags: Vec<bool>) -> Self {
135 Self {
136 phantom: PhantomData,
137 forget_flags,
138 }
139 }
140}
141
142impl<'de, T: DeserializeOwned> DeserializeSeed<'de> for TransactionResultSeed<T> {
143 type Value = Option<T>;
144
145 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
146 where
147 D: serde::Deserializer<'de>,
148 {
149 deserializer.deserialize_any(self)
150 }
151}
152
153impl<'de, T: DeserializeOwned> Visitor<'de> for TransactionResultSeed<T> {
154 type Value = Option<T>;
155
156 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
157 formatter.write_str("Option<T>")
158 }
159
160 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
161 where
162 A: serde::de::SeqAccess<'de>,
163 {
164 if self
165 .forget_flags
166 .iter()
167 .fold(0, |acc, flag| if *flag { acc } else { acc + 1 })
168 == 1
169 {
170 for forget in &self.forget_flags {
171 if *forget {
172 seq.next_element::<IgnoredAny>()?;
173 } else {
174 return seq.next_element::<T>();
175 }
176 }
177 Ok(None)
178 } else {
179 let deserializer = SeqAccessDeserializer {
180 forget_flags: self.forget_flags.into_iter(),
181 seq_access: seq,
182 };
183
184 T::deserialize(deserializer)
185 .map(Some)
186 .map_err(de::Error::custom)
187 }
188 }
189
190 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
191 where
192 E: serde::de::Error,
193 {
194 Ok(None)
195 }
196}
197
/// View over the `EXEC` reply sequence with forgotten elements hidden.
///
/// It is used both as a `Deserializer` (so `T` can be built from it) and as a
/// `SeqAccess` that skips elements whose forget flag is `true`.
struct SeqAccessDeserializer<A> {
    /// Remaining forget flags, consumed in step with `seq_access`.
    forget_flags: std::vec::IntoIter<bool>,
    /// The underlying sequence of `EXEC` results.
    seq_access: A,
}
202
203impl<'de, A> Deserializer<'de> for SeqAccessDeserializer<A>
204where
205 A: serde::de::SeqAccess<'de>,
206{
207 type Error = Error;
208
209 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
210 where
211 V: Visitor<'de>,
212 {
213 self.deserialize_seq(visitor)
214 }
215
216 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
217 where
218 V: Visitor<'de>,
219 {
220 visitor.visit_seq(self)
221 }
222
223 forward_to_deserialize_any! {
224 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str
225 bytes byte_buf unit_struct newtype_struct string tuple
226 tuple_struct map struct enum identifier ignored_any unit option
227 }
228}
229
230impl<'de, A> SeqAccess<'de> for SeqAccessDeserializer<A>
231where
232 A: serde::de::SeqAccess<'de>,
233{
234 type Error = Error;
235
236 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
237 where
238 T: DeserializeSeed<'de>,
239 {
240 for forget in self.forget_flags.by_ref() {
241 if forget {
242 self.seq_access
243 .next_element::<IgnoredAny>()
244 .map_err::<Error, _>(de::Error::custom)?;
245 } else {
246 return self
247 .seq_access
248 .next_element_seed(seed)
249 .map_err(de::Error::custom);
250 }
251 }
252 Ok(None)
253 }
254}
255
256impl<'a, R: Response> BatchPreparedCommand for PreparedCommand<'a, &'a mut Transaction, R> {
257 fn queue(self) {
259 self.executor.queue(self.command)
260 }
261
262 fn forget(self) {
264 self.executor.forget(self.command)
265 }
266}
267
268impl<'a> BitmapCommands<'a> for &'a mut Transaction {}
269impl<'a> BloomCommands<'a> for &'a mut Transaction {}
270impl<'a> CountMinSketchCommands<'a> for &'a mut Transaction {}
271impl<'a> CuckooCommands<'a> for &'a mut Transaction {}
272impl<'a> GenericCommands<'a> for &'a mut Transaction {}
273impl<'a> GeoCommands<'a> for &'a mut Transaction {}
274#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
275#[cfg(feature = "redis-graph")]
276impl<'a> GraphCommands<'a> for &'a mut Transaction {}
277impl<'a> HashCommands<'a> for &'a mut Transaction {}
278impl<'a> HyperLogLogCommands<'a> for &'a mut Transaction {}
279impl<'a> JsonCommands<'a> for &'a mut Transaction {}
280impl<'a> ListCommands<'a> for &'a mut Transaction {}
281impl<'a> SearchCommands<'a> for &'a mut Transaction {}
282impl<'a> SetCommands<'a> for &'a mut Transaction {}
283impl<'a> ScriptingCommands<'a> for &'a mut Transaction {}
284impl<'a> ServerCommands<'a> for &'a mut Transaction {}
285impl<'a> SortedSetCommands<'a> for &'a mut Transaction {}
286impl<'a> StreamCommands<'a> for &'a mut Transaction {}
287impl<'a> StringCommands<'a> for &'a mut Transaction {}
288impl<'a> TDigestCommands<'a> for &'a mut Transaction {}
289impl<'a> TimeSeriesCommands<'a> for &'a mut Transaction {}
290impl<'a> TopKCommands<'a> for &'a mut Transaction {}
291impl<'a> VectorSetCommands<'a> for &'a Transaction {}