rustis/client/
transaction.rs

1#[cfg(feature = "redis-graph")]
2use crate::commands::GraphCommands;
3use crate::{
4    client::{BatchPreparedCommand, Client, PreparedCommand},
5    commands::{
6        BitmapCommands, BloomCommands, CountMinSketchCommands, CuckooCommands, GenericCommands,
7        GeoCommands, HashCommands, HyperLogLogCommands, JsonCommands, ListCommands,
8        ScriptingCommands, SearchCommands, ServerCommands, SetCommands, SortedSetCommands,
9        StreamCommands, StringCommands, TDigestCommands, TimeSeriesCommands, TopKCommands,
10        VectorSetCommands,
11    },
12    resp::{cmd, Command, RespDeserializer, Response},
13    Error, Result,
14};
15use serde::{
16    de::{self, DeserializeOwned, DeserializeSeed, IgnoredAny, SeqAccess, Visitor},
17    forward_to_deserialize_any, Deserializer,
18};
19use std::{fmt, marker::PhantomData};
20
21/// Represents an on-going [`transaction`](https://redis.io/docs/manual/transactions/) on a specific client instance.
22pub struct Transaction {
23    client: Client,
24    commands: Vec<Command>,
25    forget_flags: Vec<bool>,
26    retry_on_error: Option<bool>,
27}
28
29impl Transaction {
30    pub(crate) fn new(client: Client) -> Self {
31        Self {
32            client,
33            commands: vec![cmd("MULTI")],
34            forget_flags: Vec::new(),
35            retry_on_error: None,
36        }
37    }
38
39    /// Set a flag to override default `retry_on_error` behavior.
40    ///
41    /// See [Config::retry_on_error](crate::client::Config::retry_on_error)
42    pub fn retry_on_error(&mut self, retry_on_error: bool) {
43        self.retry_on_error = Some(retry_on_error);
44    }
45
46    /// Queue a command into the transaction.
47    pub fn queue(&mut self, command: Command) {
48        self.commands.push(command);
49        self.forget_flags.push(false);
50    }
51
52    /// Queue a command into the transaction and forget its response.
53    pub fn forget(&mut self, command: Command) {
54        self.commands.push(command);
55        self.forget_flags.push(true);
56    }
57
58    /// Execute the transaction by the sending the queued command
59    /// as a whole batch to the Redis server.
60    ///
61    /// # Return
62    /// It is the caller responsability to use the right type to cast the server response
63    /// to the right tuple or collection depending on which command has been
64    /// [queued](BatchPreparedCommand::queue) or [forgotten](BatchPreparedCommand::forget).
65    ///
66    /// The most generic type that can be requested as a result is `Vec<resp::Value>`
67    ///
68    /// # Example
69    /// ```
70    /// use rustis::{
71    ///     client::{Client, Transaction, BatchPreparedCommand},
72    ///     commands::StringCommands,
73    ///     resp::{cmd, Value}, Result,
74    /// };
75    ///
76    /// #[cfg_attr(feature = "tokio-runtime", tokio::main)]
77    /// #[cfg_attr(feature = "async-std-runtime", async_std::main)]
78    /// async fn main() -> Result<()> {
79    ///     let client = Client::connect("127.0.0.1:6379").await?;
80    ///
81    ///     let mut transaction = client.create_transaction();
82    ///
83    ///     transaction.set("key1", "value1").forget();
84    ///     transaction.set("key2", "value2").forget();
85    ///     transaction.get::<_, String>("key1").queue();
86    ///     let value: String = transaction.execute().await?;
87    ///
88    ///     assert_eq!("value1", value);
89    ///
90    ///     Ok(())
91    /// }
92    /// ```
93    pub async fn execute<T: DeserializeOwned>(mut self) -> Result<T> {
94        self.commands.push(cmd("EXEC"));
95
96        let num_commands = self.commands.len();
97
98        let results = self
99            .client
100            .send_batch(self.commands, self.retry_on_error)
101            .await?;
102
103        let mut iter = results.into_iter();
104
105        // MULTI + QUEUED commands
106        for _ in 0..num_commands - 1 {
107            if let Some(resp_buf) = iter.next() {
108                resp_buf.to::<()>()?;
109            }
110        }
111
112        // EXEC
113        if let Some(result) = iter.next() {
114            let mut deserializer = RespDeserializer::new(&result);
115            match TransactionResultSeed::new(self.forget_flags).deserialize(&mut deserializer) {
116                Ok(Some(t)) => Ok(t),
117                Ok(None) => Err(Error::Aborted),
118                Err(e) => Err(e),
119            }
120        } else {
121            Err(Error::Client(
122                "Unexpected result for transaction".to_owned(),
123            ))
124        }
125    }
126}
127
128struct TransactionResultSeed<T: DeserializeOwned> {
129    phantom: PhantomData<T>,
130    forget_flags: Vec<bool>,
131}
132
133impl<T: DeserializeOwned> TransactionResultSeed<T> {
134    pub fn new(forget_flags: Vec<bool>) -> Self {
135        Self {
136            phantom: PhantomData,
137            forget_flags,
138        }
139    }
140}
141
142impl<'de, T: DeserializeOwned> DeserializeSeed<'de> for TransactionResultSeed<T> {
143    type Value = Option<T>;
144
145    fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
146    where
147        D: serde::Deserializer<'de>,
148    {
149        deserializer.deserialize_any(self)
150    }
151}
152
153impl<'de, T: DeserializeOwned> Visitor<'de> for TransactionResultSeed<T> {
154    type Value = Option<T>;
155
156    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
157        formatter.write_str("Option<T>")
158    }
159
160    fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
161    where
162        A: serde::de::SeqAccess<'de>,
163    {
164        if self
165            .forget_flags
166            .iter()
167            .fold(0, |acc, flag| if *flag { acc } else { acc + 1 })
168            == 1
169        {
170            for forget in &self.forget_flags {
171                if *forget {
172                    seq.next_element::<IgnoredAny>()?;
173                } else {
174                    return seq.next_element::<T>();
175                }
176            }
177            Ok(None)
178        } else {
179            let deserializer = SeqAccessDeserializer {
180                forget_flags: self.forget_flags.into_iter(),
181                seq_access: seq,
182            };
183
184            T::deserialize(deserializer)
185                .map(Some)
186                .map_err(de::Error::custom)
187        }
188    }
189
190    fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
191    where
192        E: serde::de::Error,
193    {
194        Ok(None)
195    }
196}
197
/// Adapter exposing an underlying sequence as a deserializer (and as a sequence)
/// from which the elements flagged as forgotten are skipped.
struct SeqAccessDeserializer<A> {
    /// Remaining forget flags, consumed as elements are read.
    forget_flags: std::vec::IntoIter<bool>,
    /// Underlying sequence the elements are read from.
    seq_access: A,
}
202
203impl<'de, A> Deserializer<'de> for SeqAccessDeserializer<A>
204where
205    A: serde::de::SeqAccess<'de>,
206{
207    type Error = Error;
208
209    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
210    where
211        V: Visitor<'de>,
212    {
213        self.deserialize_seq(visitor)
214    }
215
216    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
217    where
218        V: Visitor<'de>,
219    {
220        visitor.visit_seq(self)
221    }
222
223    forward_to_deserialize_any! {
224        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str
225        bytes byte_buf unit_struct newtype_struct string tuple
226        tuple_struct map struct enum identifier ignored_any unit option
227    }
228}
229
230impl<'de, A> SeqAccess<'de> for SeqAccessDeserializer<A>
231where
232    A: serde::de::SeqAccess<'de>,
233{
234    type Error = Error;
235
236    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
237    where
238        T: DeserializeSeed<'de>,
239    {
240        for forget in self.forget_flags.by_ref() {
241            if forget {
242                self.seq_access
243                    .next_element::<IgnoredAny>()
244                    .map_err::<Error, _>(de::Error::custom)?;
245            } else {
246                return self
247                    .seq_access
248                    .next_element_seed(seed)
249                    .map_err(de::Error::custom);
250            }
251        }
252        Ok(None)
253    }
254}
255
256impl<'a, R: Response> BatchPreparedCommand for PreparedCommand<'a, &'a mut Transaction, R> {
257    /// Queue a command into the transaction.
258    fn queue(self) {
259        self.executor.queue(self.command)
260    }
261
262    /// Queue a command into the transaction and forget its response.
263    fn forget(self) {
264        self.executor.forget(self.command)
265    }
266}
267
268impl<'a> BitmapCommands<'a> for &'a mut Transaction {}
269impl<'a> BloomCommands<'a> for &'a mut Transaction {}
270impl<'a> CountMinSketchCommands<'a> for &'a mut Transaction {}
271impl<'a> CuckooCommands<'a> for &'a mut Transaction {}
272impl<'a> GenericCommands<'a> for &'a mut Transaction {}
273impl<'a> GeoCommands<'a> for &'a mut Transaction {}
274#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
275#[cfg(feature = "redis-graph")]
276impl<'a> GraphCommands<'a> for &'a mut Transaction {}
277impl<'a> HashCommands<'a> for &'a mut Transaction {}
278impl<'a> HyperLogLogCommands<'a> for &'a mut Transaction {}
279impl<'a> JsonCommands<'a> for &'a mut Transaction {}
280impl<'a> ListCommands<'a> for &'a mut Transaction {}
281impl<'a> SearchCommands<'a> for &'a mut Transaction {}
282impl<'a> SetCommands<'a> for &'a mut Transaction {}
283impl<'a> ScriptingCommands<'a> for &'a mut Transaction {}
284impl<'a> ServerCommands<'a> for &'a mut Transaction {}
285impl<'a> SortedSetCommands<'a> for &'a mut Transaction {}
286impl<'a> StreamCommands<'a> for &'a mut Transaction {}
287impl<'a> StringCommands<'a> for &'a mut Transaction {}
288impl<'a> TDigestCommands<'a> for &'a mut Transaction {}
289impl<'a> TimeSeriesCommands<'a> for &'a mut Transaction {}
290impl<'a> TopKCommands<'a> for &'a mut Transaction {}
291impl<'a> VectorSetCommands<'a> for &'a Transaction {}