rust_query/
transaction.rs

1use std::{convert::Infallible, marker::PhantomData, ops::Deref, rc::Rc};
2
3use rusqlite::ErrorCode;
4use sea_query::{
5    Alias, CommonTableExpression, DeleteStatement, Expr, InsertStatement, IntoTableRef,
6    SelectStatement, SimpleExpr, SqliteQueryBuilder, UpdateStatement, WithClause,
7};
8use sea_query_rusqlite::RusqliteBinder;
9
10use crate::{
11    IntoExpr, IntoSelect, Table, TableRow, ast::MySelect, client::LocalClient,
12    migrate::schema_version, private::Reader, query::Query, rows::Rows, value::SecretFromSql,
13    writable::TableInsert,
14};
15
16/// [Database] is a proof that the database has been configured.
17///
18/// Creating a [Database] requires going through the steps to migrate an existing database to
19/// the required schema, or creating a new database from scratch (See also [crate::migration::Config]).
20/// Having done the setup to create a compatible database is sadly not a guarantee that the
21/// database will stay compatible for the lifetime of the [Database] struct.
22///
23/// That is why [Database] also stores the `schema_version`. This allows detecting non-malicious
24/// modifications to the schema and gives us the ability to panic when this is detected.
25/// Such non-malicious modification of the schema can happen for example if another [Database]
26/// instance is created with additional migrations (e.g. by another newer instance of your program).
27///
28/// For information on how to create transactions, please refer to [LocalClient].
29pub struct Database<S> {
30    pub(crate) manager: r2d2_sqlite::SqliteConnectionManager,
31    pub(crate) schema_version: i64,
32    pub(crate) schema: PhantomData<S>,
33}
34
35impl<S> Database<S> {
36    /// Create a new [rusqlite::Connection] to the database.
37    ///
38    /// You can do (almost) anything you want with this connection as it is almost completely isolated from all other
39    /// [rust_query] connections. The only thing you should not do here is changing the schema.
40    /// Schema changes are detected with the `schema_version` pragma and will result in a panic when creating a new
41    /// transaction.
42    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
43        use r2d2::ManageConnection;
44        self.manager.connect().unwrap()
45    }
46}
47
48/// [Transaction] can be used to query the database.
49///
50/// From the perspective of a [Transaction] each [TransactionMut] is fully applied or not at all.
51/// Futhermore, the effects of [TransactionMut]s have a global order.
52/// So if we have mutations `A` and then `B`, it is impossible for a [Transaction] to see the effect of `B` without seeing the effect of `A`.
53///
54/// All [TableRow] references retrieved from the database live for at most `'a`.
55/// This makes these references effectively local to this [Transaction].
56pub struct Transaction<'t, S> {
57    pub(crate) transaction: Rc<rusqlite::Transaction<'t>>,
58    pub(crate) _p: PhantomData<fn(&'t ()) -> &'t ()>,
59    pub(crate) _p2: PhantomData<S>,
60    pub(crate) _local: PhantomData<LocalClient>,
61}
62
63impl<'t, S> Transaction<'t, S> {
64    pub(crate) fn new(raw: Rc<rusqlite::Transaction<'t>>) -> Self {
65        Self {
66            transaction: raw,
67            _p: PhantomData,
68            _p2: PhantomData,
69            _local: PhantomData,
70        }
71    }
72}
73
74/// Same as [Transaction], but allows inserting new rows.
75///
76/// [TransactionMut] always uses the latest version of the database, with the effects of all previous [TransactionMut]s applied.
77///
78/// To make mutations to the database permanent you need to use [TransactionMut::commit].
79/// This is to make sure that if a function panics while holding a mutable transaction, it will roll back those changes.
80pub struct TransactionMut<'t, S> {
81    pub(crate) inner: Transaction<'t, S>,
82}
83
84impl<'t, S> Deref for TransactionMut<'t, S> {
85    type Target = Transaction<'t, S>;
86
87    fn deref(&self) -> &Self::Target {
88        &self.inner
89    }
90}
91
92impl<'t, S> Transaction<'t, S> {
93    /// This will check the schema version and panic if it is not as expected
94    pub(crate) fn new_checked(txn: rusqlite::Transaction<'t>, expected: i64) -> Self {
95        if schema_version(&txn) != expected {
96            panic!("The database schema was updated unexpectedly")
97        }
98
99        Self::new(Rc::new(txn))
100    }
101
102    /// Execute a query with multiple results.
103    ///
104    /// ```
105    /// # use rust_query::{private::doctest::*, Table};
106    /// # let mut client = get_client();
107    /// # let txn = get_txn(&mut client);
108    /// let user_names = txn.query(|rows| {
109    ///     let user = User::join(rows);
110    ///     rows.into_vec(user.name())
111    /// });
112    /// assert_eq!(user_names, vec!["Alice".to_owned()]);
113    /// ```
114    pub fn query<F, R>(&self, f: F) -> R
115    where
116        F: for<'inner> FnOnce(&mut Query<'t, 'inner, S>) -> R,
117    {
118        // Execution already happens in a [Transaction].
119        // and thus any [TransactionMut] that it might be borrowed
120        // from is borrowed immutably, which means the rows can not change.
121        let conn: &rusqlite::Connection = &self.transaction;
122        let ast = MySelect::default();
123        let q = Rows {
124            phantom: PhantomData,
125            ast,
126            _p: PhantomData,
127        };
128        f(&mut Query {
129            q,
130            phantom: PhantomData,
131            conn,
132        })
133    }
134
135    /// Retrieve a single result from the database.
136    ///
137    /// ```
138    /// # use rust_query::{private::doctest::*, IntoExpr};
139    /// # let mut client = rust_query::private::doctest::get_client();
140    /// # let txn = rust_query::private::doctest::get_txn(&mut client);
141    /// let res = txn.query_one("test".into_expr());
142    /// assert_eq!(res, "test");
143    /// ```
144    ///
145    /// Instead of using [Self::query_one] in a loop, it is better to
146    /// call [Self::query] and return all results at once.
147    pub fn query_one<'e, O>(&self, val: impl IntoSelect<'t, 't, S, Out = O>) -> O {
148        // Theoretically this doesn't even need to be in a transaction.
149        // We already have one though, so we must use it.
150        let mut res = self.query(|e| {
151            // Cast the static lifetime to any lifetime necessary, this is fine because we know the static lifetime
152            // can not be guaranteed by a query scope.
153            e.into_vec_private(val)
154        });
155        res.pop().unwrap()
156    }
157}
158
159impl<'t, S: 'static> TransactionMut<'t, S> {
160    /// Try inserting a value into the database.
161    ///
162    /// Returns [Ok] with a reference to the new inserted value or an [Err] with conflict information.
163    /// The type of conflict information depends on the number of unique constraints on the table:
164    /// - 0 unique constraints => [Infallible]
165    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
166    /// - 2+ unique constraints => `()` no further information is provided.
167    ///
168    /// ```
169    /// # use rust_query::{private::doctest::*, IntoExpr};
170    /// # let mut client = rust_query::private::doctest::get_client();
171    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
172    /// let res = txn.insert(User {
173    ///     name: "Bob",
174    /// });
175    /// assert!(res.is_ok());
176    /// let res = txn.insert(User {
177    ///     name: "Bob",
178    /// });
179    /// assert!(res.is_err(), "there is a unique constraint on the name");
180    /// ```
181    pub fn insert<T: Table<Schema = S>>(
182        &mut self,
183        val: impl TableInsert<'t, T = T>,
184    ) -> Result<TableRow<'t, T>, T::Conflict<'t>> {
185        try_insert_private(
186            &self.transaction,
187            Alias::new(T::NAME).into_table_ref(),
188            None,
189            val.into_insert(),
190        )
191    }
192
193    /// This is a convenience function to make using [TransactionMut::insert]
194    /// easier for tables without unique constraints.
195    ///
196    /// The new row is added to the table and the row reference is returned.
197    pub fn insert_ok<T: Table<Schema = S, Conflict<'t> = Infallible>>(
198        &mut self,
199        val: impl TableInsert<'t, T = T>,
200    ) -> TableRow<'t, T> {
201        let Ok(row) = self.insert(val);
202        row
203    }
204
205    /// This is a convenience function to make using [TransactionMut::insert]
206    /// easier for tables with exactly one unique constraints.
207    ///
208    /// The new row is inserted and the reference to the row is returned OR
209    /// an existing row is found which conflicts with the new row and a reference
210    /// to the conflicting row is returned.
211    ///
212    /// ```
213    /// # use rust_query::{private::doctest::*, IntoExpr};
214    /// # let mut client = rust_query::private::doctest::get_client();
215    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
216    /// let bob = txn.insert(User {
217    ///     name: "Bob",
218    /// }).unwrap();
219    /// let bob2 = txn.find_or_insert(User {
220    ///     name: "Bob", // this will conflict with the existing row.
221    /// });
222    /// assert_eq!(bob, txn.query_one(bob2));
223    /// ```
224    pub fn find_or_insert<T: Table<Schema = S, Conflict<'t> = crate::Expr<'t, S, T>>>(
225        &mut self,
226        val: impl TableInsert<'t, T = T>,
227    ) -> crate::Expr<'t, S, T> {
228        match self.insert(val) {
229            Ok(row) => row.into_expr(),
230            Err(row) => row,
231        }
232    }
233
234    /// Try updating a row in the database to have new column values.
235    ///
236    /// Updating can fail just like [TransactionMut::insert] because of unique constraint conflicts.
237    /// This happens when the new values are in conflict with an existing different row.
238    ///
239    /// When the update succeeds, this function returns [Ok<()>], when it fails it returns [Err] with one of
240    /// three conflict types:
241    /// - 0 unique constraints => [Infallible]
242    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
243    /// - 2+ unique constraints => `()` no further information is provided.
244    ///
245    /// ```
246    /// # use rust_query::{private::doctest::*, IntoExpr, Update};
247    /// # let mut client = rust_query::private::doctest::get_client();
248    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
249    /// let bob = txn.insert(User {
250    ///     name: "Bob",
251    /// }).unwrap();
252    /// txn.update(bob, User {
253    ///     name: Update::set("New Bob"),
254    /// }).unwrap();
255    /// ```
256    pub fn update<T: Table<Schema = S>>(
257        &mut self,
258        row: impl IntoExpr<'t, S, Typ = T>,
259        val: T::Update<'t>,
260    ) -> Result<(), T::Conflict<'t>> {
261        let id = MySelect::default();
262        Reader::new(&id).col(T::ID, &row);
263        let id = id.build_select(false);
264
265        let val = T::apply_try_update(val, row.into_expr());
266        let ast = MySelect::default();
267        T::read(&val, Reader::new(&ast));
268
269        let select = ast.build_select(false);
270        let cte = CommonTableExpression::new()
271            .query(select)
272            .columns(ast.select.iter().map(|x| x.1))
273            .table_name(Alias::new("cte"))
274            .to_owned();
275        let with_clause = WithClause::new().cte(cte).to_owned();
276
277        let mut update = UpdateStatement::new()
278            .table(Alias::new(T::NAME))
279            .cond_where(Expr::col(Alias::new(T::ID)).in_subquery(id))
280            .to_owned();
281
282        for (_, col) in ast.select.iter() {
283            let select = SelectStatement::new()
284                .from(Alias::new("cte"))
285                .column(*col)
286                .to_owned();
287            let value = SimpleExpr::SubQuery(
288                None,
289                Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
290            );
291            update.value(*col, value);
292        }
293
294        let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
295
296        let mut stmt = self.transaction.prepare_cached(&query).unwrap();
297        match stmt.execute(&*args.as_params()) {
298            Ok(1) => Ok(()),
299            Ok(n) => panic!("unexpected number of updates: {n}"),
300            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
301                if kind.code == ErrorCode::ConstraintViolation =>
302            {
303                // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
304                Err(T::get_conflict_unchecked(&val))
305            }
306            Err(err) => Err(err).unwrap(),
307        }
308    }
309
310    /// This is a convenience function to use [TransactionMut::update] for updates
311    /// that can not cause unique constraint violations.
312    ///
313    /// This method can be used for all tables, it just does not allow modifying
314    /// columns that are part of unique constraints.
315    pub fn update_ok<T: Table<Schema = S>>(
316        &mut self,
317        row: impl IntoExpr<'t, S, Typ = T>,
318        val: T::UpdateOk<'t>,
319    ) {
320        match self.update(row, T::update_into_try_update(val)) {
321            Ok(val) => val,
322            Err(_) => {
323                unreachable!("update can not fail")
324            }
325        }
326    }
327
328    /// Make the changes made in this [TransactionMut] permanent.
329    ///
330    /// If the [TransactionMut] is dropped without calling this function, then the changes are rolled back.
331    pub fn commit(self) {
332        Rc::into_inner(self.inner.transaction)
333            .unwrap()
334            .commit()
335            .unwrap();
336    }
337
338    /// Convert the [TransactionMut] into a [TransactionWeak] to allow deletions.
339    pub fn downgrade(self) -> TransactionWeak<'t, S> {
340        TransactionWeak { inner: self }
341    }
342}
343
344/// This is the weak version of [TransactionMut].
345///
346/// The reason that it is called `weak` is because [TransactionWeak] can not guarantee
347/// that [TableRow]s prove the existence of their particular row.
348///
349/// [TransactionWeak] is useful because it allowes deleting rows.
350pub struct TransactionWeak<'t, S> {
351    inner: TransactionMut<'t, S>,
352}
353
354impl<'t, S: 'static> TransactionWeak<'t, S> {
355    /// Try to delete a row from the database.
356    ///
357    /// This will return an [Err] if there is a row that references the row that is being deleted.
358    /// When this method returns [Ok] it will contain a [bool] that is either
359    /// - `true` if the row was just deleted.
360    /// - `false` if the row was deleted previously in this transaction.
361    pub fn delete<T: Table<Schema = S>>(
362        &mut self,
363        val: TableRow<'t, T>,
364    ) -> Result<bool, T::Referer> {
365        let stmt = DeleteStatement::new()
366            .from_table(Alias::new(T::NAME))
367            .cond_where(Expr::col(Alias::new(T::ID)).eq(val.inner.idx))
368            .to_owned();
369
370        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
371        let mut stmt = self.inner.transaction.prepare_cached(&query).unwrap();
372
373        match stmt.execute(&*args.as_params()) {
374            Ok(0) => Ok(false),
375            Ok(1) => Ok(true),
376            Ok(n) => {
377                panic!("unexpected number of deletes {n}")
378            }
379            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
380                if kind.code == ErrorCode::ConstraintViolation =>
381            {
382                // Some foreign key constraint got violated
383                Err(T::get_referer_unchecked())
384            }
385            Err(err) => Err(err).unwrap(),
386        }
387    }
388
389    /// Delete a row from the database.
390    ///
391    /// This is the infallible version of [TransactionWeak::delete].
392    ///
393    /// To be able to use this method you have to mark the table as `#[no_reference]` in the schema.
394    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
395        &mut self,
396        val: TableRow<'t, T>,
397    ) -> bool {
398        let Ok(res) = self.delete(val);
399        res
400    }
401
402    /// This allows you to do (almost) anything you want with the internal [rusqlite::Transaction].
403    ///
404    /// Note that there are some things that you should not do with the transaction, such as:
405    /// - Changes to the schema, these will result in a panic as described in [Database].
406    /// - Changes to the connection configuration such as disabling foreign key checks.
407    ///
408    /// **When this method is used to break [rust_query] invariants, all other [rust_query] function calls
409    /// may result in a panic.**
410    pub fn rusqlite_transaction(&mut self) -> &rusqlite::Transaction {
411        &self.inner.transaction
412    }
413
414    /// Make the changes made in this [TransactionWeak] permanent.
415    ///
416    /// If the [TransactionWeak] is dropped without calling this function, then the changes are rolled back.
417    pub fn commit(self) {
418        self.inner.commit();
419    }
420}
421
422pub fn try_insert_private<'t, T: Table>(
423    transaction: &Rc<rusqlite::Transaction<'t>>,
424    table: sea_query::TableRef,
425    idx: Option<i64>,
426    val: T::Insert<'t>,
427) -> Result<TableRow<'t, T>, T::Conflict<'t>> {
428    let ast = MySelect::default();
429    let reader = Reader::new(&ast);
430    T::read(&val, reader);
431    if let Some(idx) = idx {
432        reader.col(T::ID, idx);
433    }
434
435    let select = ast.simple();
436
437    let mut insert = InsertStatement::new();
438    let names = ast.select.iter().map(|(_field, name)| *name);
439    insert.into_table(table);
440    insert.columns(names);
441    insert.select_from(select).unwrap();
442    insert.returning_col(Alias::new(T::ID));
443
444    let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
445
446    let mut statement = transaction.prepare_cached(&sql).unwrap();
447    let mut res = statement
448        .query_map(&*values.as_params(), |row| {
449            Ok(TableRow::<'_, T>::from_sql(row.get_ref(T::ID)?)?)
450        })
451        .unwrap();
452
453    match res.next().unwrap() {
454        Ok(id) => Ok(id),
455        Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
456            if kind.code == ErrorCode::ConstraintViolation =>
457        {
458            // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
459            Err(T::get_conflict_unchecked(&val))
460        }
461        Err(err) => Err(err).unwrap(),
462    }
463}