rust_query/transaction.rs

1use std::{cell::RefCell, convert::Infallible, iter::zip, marker::PhantomData};
2
3use rusqlite::ErrorCode;
4use sea_query::{
5    Alias, CommonTableExpression, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoTableRef,
6    SelectStatement, SqliteQueryBuilder, UpdateStatement, WithClause,
7};
8use sea_query_rusqlite::RusqliteBinder;
9use self_cell::{MutBorrow, self_cell};
10
11use crate::{
12    IntoExpr, IntoSelect, Table, TableRow,
13    migrate::schema_version,
14    private::Reader,
15    query::Query,
16    rows::Rows,
17    value::{SecretFromSql, ValueBuilder},
18    writable::TableInsert,
19};
20
21/// [Database] is a proof that the database has been configured.
22///
23/// Creating a [Database] requires going through the steps to migrate an existing database to
24/// the required schema, or creating a new database from scratch (See also [crate::migration::Config]).
25/// Please see [Database::migrator] to get started.
26///
27/// Having done the setup to create a compatible database is sadly not a guarantee that the
28/// database will stay compatible for the lifetime of the [Database] struct.
29/// That is why [Database] also stores the `schema_version`. This allows detecting non-malicious
30/// modifications to the schema and gives us the ability to panic when this is detected.
31/// Such non-malicious modification of the schema can happen for example if another [Database]
32/// instance is created with additional migrations (e.g. by another newer instance of your program).
33pub struct Database<S> {
34    pub(crate) manager: r2d2_sqlite::SqliteConnectionManager,
35    pub(crate) schema_version: i64,
36    pub(crate) schema: PhantomData<S>,
37}
38
use rusqlite::Connection;
/// The transaction that borrows from the connection owned by [OwnedTransaction].
///
/// It is wrapped in an [Option] so that [OwnedTransaction::with] can move the transaction out
/// (to commit or roll it back by value) while the owning connection is still alive.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

self_cell!(
    // Self-referential pair of a connection and a transaction that borrows from it.
    // Bundling them lets the transaction be stored without a borrowed lifetime
    // (it is kept in the thread-local `TXN`).
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        #[covariant]
        dependent: RTransaction,
    }
);
50
/// SAFETY:
/// `RTransaction: !Send` because it borrows from `Connection` and `Connection: !Sync`.
/// `OwnedTransaction` can be `Send` because we know that `dependent` is the only
/// borrow of `owner` and `OwnedTransaction: !Sync` so `dependent` can not be borrowed
/// from multiple threads.
/// Sending the whole cell to another thread moves the connection together with the only
/// reference to it, so the connection is never accessed from two threads at once.
unsafe impl Send for OwnedTransaction {}

// Compile-time check of the `!Sync` property that the `Send` impl above relies on.
assert_not_impl_any! {OwnedTransaction: Sync}
59
thread_local! {
    /// The transaction that is active on the current thread, if any.
    ///
    /// It is set by `Transaction::new_checked`, read through `TXN.with_borrow` by the query
    /// and mutation methods in this file, and taken back out by `Database::transaction_mut`
    /// to commit or roll back. The [Transaction] handles themselves are zero-sized; this slot
    /// is what owns the actual connection.
    pub(crate) static TXN: RefCell<Option<OwnedTransaction>> = const { RefCell::new(None) };
}
63
64impl OwnedTransaction {
65    pub fn get(&self) -> &rusqlite::Transaction<'_> {
66        self.borrow_dependent().as_ref().unwrap()
67    }
68
69    pub fn with(mut self, f: impl FnOnce(rusqlite::Transaction<'_>)) {
70        self.with_dependent_mut(|_, b| f(b.take().unwrap()))
71    }
72}
73
74impl<S: Send + Sync + 'static> Database<S> {
75    /// Create a [Transaction]. This operation always completes immediately as it does not need to wait on other transactions.
76    ///
77    /// This function will panic if the schema was modified compared to when the [Database] value
78    /// was created. This can happen for example by running another instance of your program with
79    /// additional migrations.
80    pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
81        let res = std::thread::scope(|scope| {
82            scope
83                .spawn(|| {
84                    use r2d2::ManageConnection;
85
86                    let conn = self.manager.connect().unwrap();
87                    let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
88                        Some(conn.borrow_mut().transaction().unwrap())
89                    });
90
91                    f(Transaction::new_checked(owned, self.schema_version))
92                })
93                .join()
94        });
95        match res {
96            Ok(val) => val,
97            Err(payload) => std::panic::resume_unwind(payload),
98        }
99    }
100
101    /// Create a mutable [Transaction].
102    /// This operation needs to wait for all other mutable [Transaction]s for this database to be finished.
103    ///
104    /// Whether the transaction is commited depends on the result of the closure.
105    /// The transaction is only commited if the closure return [Ok]. In the case that it returns [Err]
106    /// or when the closure panics, a rollback is performed.
107    ///
108    /// The implementation uses the [unlock_notify](https://sqlite.org/unlock_notify.html) feature of sqlite.
109    /// This makes it work across processes.
110    /// Note: you can create a deadlock if you are holding on to another lock while trying to
111    /// get a mutable transaction!
112    ///
113    /// This function will panic if the schema was modified compared to when the [Database] value
114    /// was created. This can happen for example by running another instance of your program with
115    /// additional migrations.
116    pub fn transaction_mut<O: Send, E: Send>(
117        &self,
118        f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
119    ) -> Result<O, E> {
120        let res = std::thread::scope(|scope| {
121            scope
122                .spawn(|| {
123                    use r2d2::ManageConnection;
124
125                    let conn = self.manager.connect().unwrap();
126                    let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
127                        let txn = conn
128                            .borrow_mut()
129                            .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
130                            .unwrap();
131                        Some(txn)
132                    });
133
134                    let res = f(Transaction::new_checked(owned, self.schema_version));
135
136                    let owned = TXN.take().unwrap();
137
138                    if res.is_ok() {
139                        owned.with(|x| x.commit().unwrap());
140                    } else {
141                        owned.with(|x| x.rollback().unwrap());
142                    }
143                    res
144                })
145                .join()
146        });
147        match res {
148            Ok(val) => val,
149            Err(payload) => std::panic::resume_unwind(payload),
150        }
151    }
152
153    /// Same as [Self::transaction_mut], but always commits the transaction.
154    ///
155    /// The only exception is that if the closure panics, a rollback is performed.
156    pub fn transaction_mut_ok<R: Send>(
157        &self,
158        f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
159    ) -> R {
160        self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
161            .unwrap()
162    }
163
164    /// Create a new [rusqlite::Connection] to the database.
165    ///
166    /// You can do (almost) anything you want with this connection as it is almost completely isolated from all other
167    /// [rust_query] connections. The only thing you should not do here is changing the schema.
168    /// Schema changes are detected with the `schema_version` pragma and will result in a panic when creating a new
169    /// [rust_query] transaction.
170    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
171        use r2d2::ManageConnection;
172        self.manager.connect().unwrap()
173    }
174}
175
/// [Transaction] can be used to query and update the database.
///
/// From the perspective of a [Transaction], every other [Transaction] is either fully applied or not applied at all.
/// Furthermore, the effects of [Transaction]s have a global order.
/// So if we have mutations `A` and then `B`, it is impossible for a [Transaction] to see the effect of `B` without seeing the effect of `A`.
pub struct Transaction<S> {
    /// Ties the handle to the schema type `S`.
    pub(crate) _p2: PhantomData<S>,
    /// `*const ()` makes [Transaction] neither [Send] nor [Sync], so a handle cannot be moved
    /// to or shared with a thread other than the one it was created on.
    pub(crate) _local: PhantomData<*const ()>,
}
185
186impl<S> Transaction<S> {
187    pub(crate) fn new() -> Self {
188        Self {
189            _p2: PhantomData,
190            _local: PhantomData,
191        }
192    }
193}
194
195impl<S> Transaction<S> {
196    /// This will check the schema version and panic if it is not as expected
197    pub(crate) fn new_checked(txn: OwnedTransaction, expected: i64) -> &'static mut Self {
198        if schema_version(txn.get()) != expected {
199            panic!("The database schema was updated unexpectedly")
200        }
201        TXN.set(Some(txn));
202
203        const {
204            assert!(size_of::<Self>() == 0);
205        }
206        // no memory is leaked because Self is zero sized
207        Box::leak(Box::new(Self::new()))
208    }
209
210    /// Execute a query with multiple results.
211    ///
212    /// ```
213    /// # use rust_query::{private::doctest::*};
214    /// # get_txn(|txn| {
215    /// let user_names = txn.query(|rows| {
216    ///     let user = rows.join(User);
217    ///     rows.into_vec(&user.name)
218    /// });
219    /// assert_eq!(user_names, vec!["Alice".to_owned()]);
220    /// # });
221    /// ```
222    pub fn query<F, R>(&self, f: F) -> R
223    where
224        F: for<'inner> FnOnce(&mut Query<'inner, S>) -> R,
225    {
226        // Execution already happens in a [Transaction].
227        // and thus any [TransactionMut] that it might be borrowed
228        // from is borrowed immutably, which means the rows can not change.
229
230        TXN.with_borrow(|txn| {
231            let conn = txn.as_ref().unwrap().get();
232            let q = Rows {
233                phantom: PhantomData,
234                ast: Default::default(),
235                _p: PhantomData,
236            };
237            f(&mut Query {
238                q,
239                phantom: PhantomData,
240                conn,
241            })
242        })
243    }
244
245    /// Retrieve a single result from the database.
246    ///
247    /// ```
248    /// # use rust_query::{private::doctest::*, IntoExpr};
249    /// # rust_query::private::doctest::get_txn(|txn| {
250    /// let res = txn.query_one("test".into_expr());
251    /// assert_eq!(res, "test");
252    /// # });
253    /// ```
254    ///
255    /// Instead of using [Self::query_one] in a loop, it is better to
256    /// call [Self::query] and return all results at once.
257    pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
258        self.query(|e| e.into_iter(val.into_select()).next().unwrap())
259    }
260}
261
262impl<S: 'static> Transaction<S> {
263    /// Try inserting a value into the database.
264    ///
265    /// Returns [Ok] with a reference to the new inserted value or an [Err] with conflict information.
266    /// The type of conflict information depends on the number of unique constraints on the table:
267    /// - 0 unique constraints => [Infallible]
268    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
269    /// - 2+ unique constraints => `()` no further information is provided.
270    ///
271    /// ```
272    /// # use rust_query::{private::doctest::*, IntoExpr};
273    /// # rust_query::private::doctest::get_txn(|mut txn| {
274    /// let res = txn.insert(User {
275    ///     name: "Bob",
276    /// });
277    /// assert!(res.is_ok());
278    /// let res = txn.insert(User {
279    ///     name: "Bob",
280    /// });
281    /// assert!(res.is_err(), "there is a unique constraint on the name");
282    /// # });
283    /// ```
284    pub fn insert<T: Table<Schema = S>>(
285        &mut self,
286        val: impl TableInsert<T = T>,
287    ) -> Result<TableRow<T>, T::Conflict> {
288        try_insert_private(T::NAME.into_table_ref(), None, val.into_insert())
289    }
290
291    /// This is a convenience function to make using [Transaction::insert]
292    /// easier for tables without unique constraints.
293    ///
294    /// The new row is added to the table and the row reference is returned.
295    pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
296        &mut self,
297        val: impl TableInsert<T = T>,
298    ) -> TableRow<T> {
299        let Ok(row) = self.insert(val);
300        row
301    }
302
303    /// This is a convenience function to make using [Transaction::insert]
304    /// easier for tables with exactly one unique constraints.
305    ///
306    /// The new row is inserted and the reference to the row is returned OR
307    /// an existing row is found which conflicts with the new row and a reference
308    /// to the conflicting row is returned.
309    ///
310    /// ```
311    /// # use rust_query::{private::doctest::*, IntoExpr};
312    /// # rust_query::private::doctest::get_txn(|mut txn| {
313    /// let bob = txn.insert(User {
314    ///     name: "Bob",
315    /// }).unwrap();
316    /// let bob2 = txn.find_or_insert(User {
317    ///     name: "Bob", // this will conflict with the existing row.
318    /// });
319    /// assert_eq!(bob, bob2);
320    /// # });
321    /// ```
322    pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
323        &mut self,
324        val: impl TableInsert<T = T>,
325    ) -> TableRow<T> {
326        match self.insert(val) {
327            Ok(row) => row,
328            Err(row) => row,
329        }
330    }
331
332    /// Try updating a row in the database to have new column values.
333    ///
334    /// Updating can fail just like [Transaction::insert] because of unique constraint conflicts.
335    /// This happens when the new values are in conflict with an existing different row.
336    ///
337    /// When the update succeeds, this function returns [Ok], when it fails it returns [Err] with one of
338    /// three conflict types:
339    /// - 0 unique constraints => [Infallible]
340    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
341    /// - 2+ unique constraints => `()` no further information is provided.
342    ///
343    /// ```
344    /// # use rust_query::{private::doctest::*, IntoExpr, Update};
345    /// # rust_query::private::doctest::get_txn(|mut txn| {
346    /// let bob = txn.insert(User {
347    ///     name: "Bob",
348    /// }).unwrap();
349    /// txn.update(bob, User {
350    ///     name: Update::set("New Bob"),
351    /// }).unwrap();
352    /// # });
353    /// ```
354    pub fn update<T: Table<Schema = S>>(
355        &mut self,
356        row: impl IntoExpr<'static, S, Typ = T>,
357        val: T::Update,
358    ) -> Result<(), T::Conflict> {
359        let mut id = ValueBuilder::default();
360        let row = row.into_expr();
361        let (id, _) = id.simple_one(row.inner.clone().erase());
362
363        let val = T::apply_try_update(val, row);
364        let mut reader = Reader::default();
365        T::read(&val, &mut reader);
366        let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
367
368        let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
369        let cte = CommonTableExpression::new()
370            .query(select)
371            .columns(col_fields.clone())
372            .table_name(Alias::new("cte"))
373            .to_owned();
374        let with_clause = WithClause::new().cte(cte).to_owned();
375
376        let mut update = UpdateStatement::new()
377            .table(("main", T::NAME))
378            .cond_where(Expr::col(("main", T::NAME, T::ID)).in_subquery(id))
379            .to_owned();
380
381        for (name, field) in zip(col_names, col_fields) {
382            let select = SelectStatement::new()
383                .from(Alias::new("cte"))
384                .column(field)
385                .to_owned();
386            let value = sea_query::Expr::SubQuery(
387                None,
388                Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
389            );
390            update.value(Alias::new(name), value);
391        }
392
393        let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
394
395        TXN.with_borrow(|txn| {
396            let txn = txn.as_ref().unwrap().get();
397
398            let mut stmt = txn.prepare_cached(&query).unwrap();
399            match stmt.execute(&*args.as_params()) {
400                Ok(1) => Ok(()),
401                Ok(n) => panic!("unexpected number of updates: {n}"),
402                Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
403                    if kind.code == ErrorCode::ConstraintViolation =>
404                {
405                    // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
406                    Err(T::get_conflict_unchecked(self, &val))
407                }
408                Err(err) => panic!("{err:?}"),
409            }
410        })
411    }
412
413    /// This is a convenience function to use [Transaction::update] for updates
414    /// that can not cause unique constraint violations.
415    ///
416    /// This method can be used for all tables, it just does not allow modifying
417    /// columns that are part of unique constraints.
418    pub fn update_ok<T: Table<Schema = S>>(
419        &mut self,
420        row: impl IntoExpr<'static, S, Typ = T>,
421        val: T::UpdateOk,
422    ) {
423        match self.update(row, T::update_into_try_update(val)) {
424            Ok(val) => val,
425            Err(_) => {
426                unreachable!("update can not fail")
427            }
428        }
429    }
430
431    /// Convert the [Transaction] into a [TransactionWeak] to allow deletions.
432    pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
433        // TODO: clean this up
434        Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
435    }
436}
437
438/// This is the weak version of [Transaction].
439///
440/// The reason that it is called `weak` is because [TransactionWeak] can not guarantee
441/// that [TableRow]s prove the existence of their particular row.
442///
443/// [TransactionWeak] is useful because it allowes deleting rows.
444pub struct TransactionWeak<S> {
445    inner: PhantomData<Transaction<S>>,
446}
447
448impl<S: 'static> TransactionWeak<S> {
449    /// Try to delete a row from the database.
450    ///
451    /// This will return an [Err] if there is a row that references the row that is being deleted.
452    /// When this method returns [Ok] it will contain a [bool] that is either
453    /// - `true` if the row was just deleted.
454    /// - `false` if the row was deleted previously in this transaction.
455    pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
456        let stmt = DeleteStatement::new()
457            .from_table(("main", T::NAME))
458            .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
459            .to_owned();
460
461        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
462
463        TXN.with_borrow(|txn| {
464            let txn = txn.as_ref().unwrap().get();
465            let mut stmt = txn.prepare_cached(&query).unwrap();
466
467            match stmt.execute(&*args.as_params()) {
468                Ok(0) => Ok(false),
469                Ok(1) => Ok(true),
470                Ok(n) => {
471                    panic!("unexpected number of deletes {n}")
472                }
473                Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
474                    if kind.code == ErrorCode::ConstraintViolation =>
475                {
476                    // Some foreign key constraint got violated
477                    Err(T::get_referer_unchecked())
478                }
479                Err(err) => panic!("{err:?}"),
480            }
481        })
482    }
483
484    /// Delete a row from the database.
485    ///
486    /// This is the infallible version of [TransactionWeak::delete].
487    ///
488    /// To be able to use this method you have to mark the table as `#[no_reference]` in the schema.
489    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
490        &mut self,
491        val: TableRow<T>,
492    ) -> bool {
493        let Ok(res) = self.delete(val);
494        res
495    }
496
497    /// This allows you to do (almost) anything you want with the internal [rusqlite::Transaction].
498    ///
499    /// Note that there are some things that you should not do with the transaction, such as:
500    /// - Changes to the schema, these will result in a panic as described in [Database].
501    /// - Changes to the connection configuration such as disabling foreign key checks.
502    ///
503    /// **When this method is used to break [rust_query] invariants, all other [rust_query] function calls
504    /// may result in a panic.**
505    pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
506        TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
507    }
508}
509
510pub fn try_insert_private<T: Table>(
511    table: sea_query::TableRef,
512    idx: Option<i64>,
513    val: T::Insert,
514) -> Result<TableRow<T>, T::Conflict> {
515    let mut reader = Reader::default();
516    T::read(&val, &mut reader);
517    if let Some(idx) = idx {
518        reader.col(T::ID, idx);
519    }
520    let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
521    let is_empty = col_names.is_empty();
522
523    let (select, _) = ValueBuilder::default().simple(col_exprs);
524
525    let mut insert = InsertStatement::new();
526    insert.into_table(table);
527    insert.columns(col_names.into_iter().map(Alias::new));
528    if is_empty {
529        // select always has at least one column, so we leave it out when there are no columns
530        insert.or_default_values();
531    } else {
532        insert.select_from(select).unwrap();
533    }
534    insert.returning_col(T::ID);
535
536    let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
537
538    TXN.with_borrow(|txn| {
539        let txn = txn.as_ref().unwrap().get();
540
541        let mut statement = txn.prepare_cached(&sql).unwrap();
542        let mut res = statement
543            .query_map(&*values.as_params(), |row| {
544                Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
545            })
546            .unwrap();
547
548        match res.next().unwrap() {
549            Ok(id) => Ok(id),
550            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
551                if kind.code == ErrorCode::ConstraintViolation =>
552            {
553                // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
554                Err(T::get_conflict_unchecked(&Transaction::new(), &val))
555            }
556            Err(err) => panic!("{err:?}"),
557        }
558    })
559}