spanner_rs/
client.rs

1use std::future::Future;
2use std::pin::Pin;
3
4use bb8::{Pool, PooledConnection};
5use tonic::Code;
6
7use crate::result_set::ResultSet;
8use crate::statement::Statement;
9use crate::TimestampBound;
10use crate::ToSpanner;
11use crate::{session::SessionManager, ConfigBuilder, Connection, Error, TransactionSelector};
12
13/// An asynchronous Cloud Spanner client.
14pub struct Client {
15    connection: Box<dyn Connection>,
16    session_pool: Pool<SessionManager>,
17}
18
19impl Client {
20    /// Returns a new [`ConfigBuilder`] which can be used to configure how to connect to a Cloud Spanner instance and database.
21    pub fn configure() -> ConfigBuilder {
22        ConfigBuilder::default()
23    }
24}
25
26impl Client {
27    pub(crate) fn connect(
28        connection: Box<dyn Connection>,
29        session_pool: Pool<SessionManager>,
30    ) -> Self {
31        Self {
32            connection,
33            session_pool,
34        }
35    }
36
37    /// Returns a [`ReadContext`] that can be used to read data out of Cloud Spanner.
38    /// The returned context uses [`TimestampBound::Strong`] consistency for each individual read.
39    pub fn read_only(&self) -> impl ReadContext {
40        ReadOnly {
41            connection: self.connection.clone(),
42            bound: None,
43            session_pool: self.session_pool.clone(),
44        }
45    }
46
47    /// Returns a [`ReadContext`] that can be used to read data out of Cloud Spanner.
48    /// The returned context uses the specified bounded consistency for each individual read.
49    pub fn read_only_with_bound(&self, bound: TimestampBound) -> impl ReadContext {
50        ReadOnly {
51            connection: self.connection.clone(),
52            bound: Some(bound),
53            session_pool: self.session_pool.clone(),
54        }
55    }
56
57    /// Returns a [`TxRunner`] that can be used to execute transactions using a [`TransactionContext`]
58    /// to read and write data from/into Cloud Spanner.
59    pub fn read_write(&self) -> TxRunner {
60        TxRunner {
61            connection: self.connection.clone(),
62            session_pool: self.session_pool.clone(),
63        }
64    }
65}
66
67/// Defines the interface to read data out of Cloud Spanner.
68#[async_trait::async_trait]
69pub trait ReadContext {
70    /// Execute a read-only SQL statement and returns a [ResultSet].
71    ///
72    /// # Parameters
73    ///
74    /// As per the [Cloud Spanner documentation](https://cloud.google.com/spanner/docs/sql-best-practices#query-parameters), the statement may contain named parameters, e.g.: `@param_name`.
75    /// When such parameters are present in the SQL query, their value must be provided in the second argument to this function.
76    ///
77    /// See [`ToSpanner`] to determine how Rust values can be mapped to Cloud Spanner values.
78    ///
79    /// If the parameter values do not line up with parameters in the statement, an [Error] is returned.
80    ///
81    /// # Example
82    ///
83    ///  ```no_run
84    /// # use spanner_rs::{Client, Error, ReadContext};
85    /// # #[tokio::main]
86    /// # async fn main() -> Result<(), Error> {
87    /// # let mut client = Client::configure().connect().await?;
88    /// let my_id = 42;
89    /// let rs = client.read_only().execute_query(
90    ///     "SELECT id FROM person WHERE id > @my_id",
91    ///     &[("my_id", &my_id)],
92    /// ).await?;
93    /// for row in rs.iter() {
94    ///     let id: u32 = row.get("id")?;
95    ///     println!("id: {}", id);
96    /// }
97    /// # Ok(()) }
98    ///  ```
99    async fn execute_query(
100        &mut self,
101        statement: &str,
102        parameters: &[(&str, &(dyn ToSpanner + Sync))],
103    ) -> Result<ResultSet, Error>;
104}
105
106struct ReadOnly {
107    connection: Box<dyn Connection>,
108    bound: Option<TimestampBound>,
109    session_pool: Pool<SessionManager>,
110}
111
112#[async_trait::async_trait]
113impl ReadContext for ReadOnly {
114    async fn execute_query(
115        &mut self,
116        statement: &str,
117        parameters: &[(&str, &(dyn ToSpanner + Sync))],
118    ) -> Result<ResultSet, Error> {
119        let session = self.session_pool.get().await?;
120        let result = self
121            .connection
122            .execute_sql(
123                &session,
124                &TransactionSelector::SingleUse(self.bound.clone()),
125                statement,
126                parameters,
127                None,
128            )
129            .await?;
130
131        Ok(result)
132    }
133}
134
135/// Defines the interface to read from and write into Cloud Spanner.
136///
137/// This extends [`ReadContext`] to provide additional write functionalities.
138#[async_trait::async_trait]
139pub trait TransactionContext: ReadContext {
140    /// Execute a DML SQL statement and returns the number of affected rows.
141    ///
142    /// # Parameters
143    ///
144    /// Like its [`ReadContext::execute_sql`] counterpart, this function also supports query parameters.
145    ///
146    /// # Example
147    ///
148    /// ```no_run
149    /// # use spanner_rs::{Client, Error, TransactionContext};
150    /// # #[tokio::main]
151    /// # async fn main() -> Result<(), Error> {
152    /// # let mut client = Client::configure().connect().await?;
153    /// let id = 42;
154    /// let name = "ferris";
155    /// let rows = client
156    ///     .read_write()
157    ///     .run(|tx| {
158    ///         Box::pin(async move {
159    ///             tx.execute_update(
160    ///                 "INSERT INTO person(id, name) VALUES (@id, @name)",
161    ///                 &[("id", &id), ("name", &name)],
162    ///             )
163    ///             .await
164    ///         })
165    ///     })
166    ///     .await?;
167    ///
168    /// println!("Inserted {} row", rows);
169    /// # Ok(()) }
170    /// ```
171    async fn execute_update(
172        &mut self,
173        statement: &str,
174        parameters: &[(&str, &(dyn ToSpanner + Sync))],
175    ) -> Result<i64, Error>;
176
177    /// Execute a batch of DML SQL statements and returns the number of affected rows for each statement.
178    ///
179    /// # Statements
180    ///
181    /// Each DML statement has its own SQL statement and parameters. See [`Statement`] for more details.
182    ///
183    /// # Example
184    ///
185    /// ```no_run
186    /// # use spanner_rs::{Client, Error, Statement, TransactionContext};
187    /// # #[tokio::main]
188    /// # async fn main() -> Result<(), Error> {
189    /// # let mut client = Client::configure().connect().await?;
190    /// let id = 42;
191    /// let name = "ferris";
192    /// let new_name = "ferris";
193    /// let rows = client
194    ///     .read_write()
195    ///     .run(|tx| {
196    ///         Box::pin(async move {
197    ///             tx.execute_updates(&[
198    ///                 &Statement {
199    ///                     sql: "INSERT INTO person(id, name) VALUES (@id, @name)",
200    ///                     params: &[("id", &id), ("name", &name)],
201    ///                 },
202    ///                 &Statement {
203    ///                     sql: "UPDATE person SET name = @name WHERE id = 42",
204    ///                     params: &[("name", &new_name)],
205    ///                 },
206    ///             ])
207    ///             .await
208    ///         })
209    ///     })
210    ///     .await?;
211    ///
212    /// // each statement modified a single row
213    /// assert_eq!(rows, vec![1, 1]);
214    ///
215    /// # Ok(()) }
216    /// ```
217    async fn execute_updates(&mut self, statements: &[&Statement]) -> Result<Vec<i64>, Error>;
218}
219
220struct Tx<'a> {
221    connection: Box<dyn Connection>,
222    session: PooledConnection<'a, SessionManager>,
223    selector: TransactionSelector,
224    seqno: i64,
225}
226
227#[async_trait::async_trait]
228impl<'a> ReadContext for Tx<'a> {
229    async fn execute_query(
230        &mut self,
231        statement: &str,
232        parameters: &[(&str, &(dyn ToSpanner + Sync))],
233    ) -> Result<ResultSet, Error> {
234        // seqno is required on DML queries and ignored otherwise. Specifying it on every query is fine.
235        self.seqno += 1;
236        let result_set = self
237            .connection
238            .execute_sql(
239                &self.session,
240                &self.selector,
241                statement,
242                parameters,
243                Some(self.seqno),
244            )
245            .await?;
246
247        // TODO: this is brittle, if we forget to do this in some other method, then we risk not committing.
248        if let TransactionSelector::Begin = self.selector {
249            if let Some(tx) = result_set.transaction.as_ref() {
250                self.selector = TransactionSelector::Id(tx.clone());
251            }
252        }
253
254        Ok(result_set)
255    }
256}
257
258#[async_trait::async_trait]
259impl<'a> TransactionContext for Tx<'a> {
260    async fn execute_update(
261        &mut self,
262        statement: &str,
263        parameters: &[(&str, &(dyn ToSpanner + Sync))],
264    ) -> Result<i64, Error> {
265        self.execute_query(statement, parameters).await?
266            .stats
267            .row_count
268            .ok_or_else(|| Error::Client("no row count available. This may be the result of using execute_update on a statement that did not contain DML.".to_string()))
269    }
270
271    async fn execute_updates(&mut self, statements: &[&Statement]) -> Result<Vec<i64>, Error> {
272        self.seqno += 1;
273        let result_sets = self
274            .connection
275            .execute_batch_dml(&self.session, &self.selector, statements, self.seqno)
276            .await?;
277
278        // TODO: this is brittle, if we forget to do this in some other method, then we risk not committing.
279        if let TransactionSelector::Begin = self.selector {
280            if let Some(tx) = result_sets.get(0).and_then(|rs| rs.transaction.as_ref()) {
281                self.selector = TransactionSelector::Id(tx.clone());
282            }
283        }
284
285        result_sets.iter()
286            .map(|rs| {
287                rs.stats
288                .row_count
289                .ok_or_else(|| Error::Client("no row count available. This may be the result of using execute_update on a statement that did not contain DML.".to_string()))
290            })
291            .collect()
292    }
293}
294
295/// Allows running read/write transactions against Cloud Spanner.
296pub struct TxRunner {
297    connection: Box<dyn Connection>,
298    session_pool: Pool<SessionManager>,
299}
300
301impl TxRunner {
302    /// Runs abitrary read / write operations against Cloud Spanner.
303    ///
304    /// This function encapsulates the read/write transaction management concerns, allowing the application to minimize boilerplate.
305    ///
306    /// # Begin
307    ///
308    /// The underlying transaction is only lazily created. If the provided closure does no work against Cloud Spanner,
309    /// then no transaction is created.
310    ///
311    /// # Commit / Rollback
312    ///
313    /// The underlying transaction will be committed if the provided closure returns `Ok`.
314    /// Conversely, any `Err` returned will initiate a rollback.
315    ///
316    /// If the commit or rollback operation returns an unexpected error, then this function will return that error.
317    ///
318    /// # Retries
319    ///
320    /// When committing, Cloud Spanner may reject the transaction due to conflicts with another transaction.
321    /// In these situations, Cloud Spanner allows retrying the transaction which will have a higher priority and potentially successfully commit.
322    ///
323    /// **NOTE:** the consequence of retyring is that the provided closure may be invoked multiple times.
324    /// It is important to avoid doing any additional side effects within this closure as they will also potentially occur more than once.
325    ///
326    /// # Example
327    ///
328    /// ```no_run
329    /// # use spanner_rs::{Client, Error, ReadContext, TransactionContext};
330    /// async fn bump_version(id: u32) -> Result<u32, Error> {
331    /// # let mut client = Client::configure().connect().await?;
332    ///     client
333    ///         .read_write()
334    ///         .run(|tx| {
335    ///             Box::pin(async move {
336    ///                 let rs = tx
337    ///                     .execute_query(
338    ///                         "SELECT MAX(version) FROM versions WHERE id = @id",
339    ///                         &[("id", &id)],
340    ///                     )
341    ///                     .await?;
342    ///                 let latest_version: u32 = rs.iter().next().unwrap().get(0)?;
343    ///                 let next_version = latest_version + 1;
344    ///                 tx.execute_update(
345    ///                     "INSERT INTO versions(id, version) VALUES(@id, @next_version)",
346    ///                     &[("id", &id), ("next_version", &next_version)],
347    ///                 )
348    ///                 .await?;
349    ///                 Ok(next_version)
350    ///             })
351    ///         })
352    ///         .await
353    /// }
354    /// # #[tokio::main]
355    /// # async fn main() -> Result<(), Error> {
356    /// # bump_version(42).await?;
357    /// # Ok(()) }
358    /// ```
359    pub async fn run<'b, O, F>(&'b mut self, mut work: F) -> Result<O, Error>
360    where
361        F: for<'a> FnMut(
362            &'a mut dyn TransactionContext,
363        ) -> Pin<Box<dyn Future<Output = Result<O, Error>> + 'a>>,
364    {
365        let session = self.session_pool.get().await?;
366        let mut ctx = Tx {
367            connection: self.connection.clone(),
368            session,
369            selector: TransactionSelector::Begin,
370            seqno: 0,
371        };
372
373        loop {
374            ctx.selector = TransactionSelector::Begin;
375            ctx.seqno = 0;
376            let result = work(&mut ctx).await;
377
378            let commit_result = if let TransactionSelector::Id(tx) = ctx.selector {
379                if result.is_ok() {
380                    self.connection.commit(&ctx.session, tx).await
381                } else {
382                    self.connection.rollback(&ctx.session, tx).await
383                }
384            } else {
385                Ok(())
386            };
387
388            match commit_result {
389                Err(Error::Status(status)) if status.code() == Code::Aborted => continue,
390                Err(err) => break Err(err),
391                _ => break result,
392            }
393        }
394    }
395}