rust_query/
transaction.rs

1use std::{
2    cell::RefCell, convert::Infallible, iter::zip, marker::PhantomData, sync::atomic::AtomicI64,
3};
4
5use rusqlite::ErrorCode;
6use sea_query::{
7    Alias, CommonTableExpression, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoTableRef,
8    SelectStatement, SqliteQueryBuilder, UpdateStatement, WithClause,
9};
10use sea_query_rusqlite::RusqliteBinder;
11use self_cell::{MutBorrow, self_cell};
12
13use crate::{
14    IntoExpr, IntoSelect, Table, TableRow,
15    migrate::{Schema, check_schema, schema_version, user_version},
16    private::Reader,
17    query::{Query, track_stmt},
18    rows::Rows,
19    value::{DynTypedExpr, SecretFromSql, ValueBuilder},
20    writable::TableInsert,
21};
22
/// [Database] is a proof that the database has been configured.
///
/// Creating a [Database] requires going through the steps to migrate an existing database to
/// the required schema, or creating a new database from scratch (See also [crate::migration::Config]).
/// Please see [Database::migrator] to get started.
///
/// Having done the setup to create a compatible database is sadly not a guarantee that the
/// database will stay compatible for the lifetime of the [Database] struct.
/// That is why [Database] also stores the `schema_version`. This allows detecting non-malicious
/// modifications to the schema and gives us the ability to panic when this is detected.
/// Such non-malicious modification of the schema can happen for example if another [Database]
/// instance is created with additional migrations (e.g. by another newer instance of your program).
pub struct Database<S> {
    // Connection manager; every transaction (and `rusqlite_connection`) opens a
    // fresh connection through it.
    pub(crate) manager: r2d2_sqlite::SqliteConnectionManager,
    // The last `schema_version` that was accepted. It is compared against the
    // database whenever a transaction is created and updated after a successful re-check.
    pub(crate) schema_version: AtomicI64,
    // Ties this value to the schema type `S` without storing an `S`.
    pub(crate) schema: PhantomData<S>,
    // Held by `transaction_mut` from just before the write transaction is started
    // until the user closure has finished.
    // TODO: this should technically not be required with `unlock_notify`.
    // see <https://github.com/rusqlite/rusqlite/issues/1736>
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
43
44use rusqlite::Connection;
// The transaction that borrows from the owning `Connection`.
// It is wrapped in an `Option` so that the transaction can be moved out by value
// (see `OwnedTransaction::with`), which is needed to commit or roll back.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

// Self-referential pair: the owned `Connection` together with the transaction
// that mutably borrows it.
self_cell!(
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        #[covariant]
        dependent: RTransaction,
    }
);
55
/// SAFETY:
/// `RTransaction: !Send` because it borrows from `Connection` and `Connection: !Sync`.
/// `OwnedTransaction` can be `Send` because we know that `dependent` is the only
/// borrow of `owner` and `OwnedTransaction: !Sync`, so `dependent` can not be borrowed
/// from multiple threads.
unsafe impl Send for OwnedTransaction {}

// Compile-time check of the `!Sync` property that the safety argument above relies on.
assert_not_impl_any! {OwnedTransaction: Sync}
64
thread_local! {
    // The transaction that belongs to the current thread, if any.
    // It is stored by `Transaction::new_checked` and borrowed by the query and
    // mutation methods, which is why `Transaction` itself does not hold a connection.
    pub(crate) static TXN: RefCell<Option<OwnedTransaction>> = const { RefCell::new(None) };
}
68
69impl OwnedTransaction {
70    pub fn get(&self) -> &rusqlite::Transaction<'_> {
71        self.borrow_dependent().as_ref().unwrap()
72    }
73
74    pub fn with(mut self, f: impl FnOnce(rusqlite::Transaction<'_>)) {
75        self.with_dependent_mut(|_, b| f(b.take().unwrap()))
76    }
77}
78
79impl<S: Send + Sync + Schema> Database<S> {
80    /// Create a [Transaction]. This operation always completes immediately as it does not need to wait on other transactions.
81    ///
82    /// This function will panic if the schema was modified compared to when the [Database] value
83    /// was created. This can happen for example by running another instance of your program with
84    /// additional migrations.
85    pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
86        let res = std::thread::scope(|scope| {
87            scope
88                .spawn(|| {
89                    use r2d2::ManageConnection;
90
91                    let conn = self.manager.connect().unwrap();
92                    let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
93                        Some(conn.borrow_mut().transaction().unwrap())
94                    });
95
96                    f(Transaction::new_checked(owned, &self.schema_version))
97                })
98                .join()
99        });
100        match res {
101            Ok(val) => val,
102            Err(payload) => std::panic::resume_unwind(payload),
103        }
104    }
105
106    /// Create a mutable [Transaction].
107    /// This operation needs to wait for all other mutable [Transaction]s for this database to be finished.
108    ///
109    /// Whether the transaction is commited depends on the result of the closure.
110    /// The transaction is only commited if the closure return [Ok]. In the case that it returns [Err]
111    /// or when the closure panics, a rollback is performed.
112    ///
113    /// The implementation uses the [unlock_notify](https://sqlite.org/unlock_notify.html) feature of sqlite.
114    /// This makes it work across processes.
115    /// Note: you can create a deadlock if you are holding on to another lock while trying to
116    /// get a mutable transaction!
117    ///
118    /// This function will panic if the schema was modified compared to when the [Database] value
119    /// was created. This can happen for example by running another instance of your program with
120    /// additional migrations.
121    pub fn transaction_mut<O: Send, E: Send>(
122        &self,
123        f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
124    ) -> Result<O, E> {
125        use r2d2::ManageConnection;
126        let conn = self.manager.connect().unwrap();
127
128        // Acquire the lock just before creating the transaction
129        let guard = self.mut_lock.lock();
130
131        let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
132            let txn = conn
133                .borrow_mut()
134                .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
135                .unwrap();
136            Some(txn)
137        });
138        let join_res = std::thread::scope(|scope| {
139            scope
140                .spawn(|| {
141                    let res = f(Transaction::new_checked(owned, &self.schema_version));
142                    let owned = TXN.take().unwrap();
143                    (res, owned)
144                })
145                .join()
146        });
147
148        // Drop the guard before commiting to let sqlite go to the next transaction
149        // more quickly while guaranteeing that the database will unlock soon.
150        drop(guard);
151
152        let (res, owned) = match join_res {
153            Ok(val) => val,
154            Err(payload) => std::panic::resume_unwind(payload),
155        };
156
157        if res.is_ok() {
158            owned.with(|x| x.commit().unwrap());
159        } else {
160            owned.with(|x| x.rollback().unwrap());
161        }
162        res
163    }
164
165    /// Same as [Self::transaction_mut], but always commits the transaction.
166    ///
167    /// The only exception is that if the closure panics, a rollback is performed.
168    pub fn transaction_mut_ok<R: Send>(
169        &self,
170        f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
171    ) -> R {
172        self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
173            .unwrap()
174    }
175
176    /// Create a new [rusqlite::Connection] to the database.
177    ///
178    /// You can do (almost) anything you want with this connection as it is almost completely isolated from all other
179    /// [rust_query] connections. The only thing you should not do here is changing the schema.
180    /// Schema changes are detected with the `schema_version` pragma and will result in a panic when creating a new
181    /// [rust_query] transaction.
182    ///
183    /// The `foreign_keys` pragma is always enabled here, even if [crate::migrate::ForeignKeys::SQLite] is not used.
184    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
185        use r2d2::ManageConnection;
186        let conn = self.manager.connect().unwrap();
187        conn.pragma_update(None, "foreign_keys", "ON").unwrap();
188        conn
189    }
190}
191
/// [Transaction] can be used to query and update the database.
///
/// From the perspective of a [Transaction] each other [Transaction] is fully applied or not at all.
/// Furthermore, the effects of [Transaction]s have a global order.
/// So if we have mutations `A` and then `B`, it is impossible for a [Transaction] to see the effect of `B` without seeing the effect of `A`.
pub struct Transaction<S> {
    // Ties the transaction to the schema type `S` without storing an `S`.
    pub(crate) _p2: PhantomData<S>,
    // A raw pointer is neither `Send` nor `Sync`, so this marker keeps the
    // transaction on the thread that it was created on.
    pub(crate) _local: PhantomData<*const ()>,
}
201
202impl<S> Transaction<S> {
203    pub(crate) fn new() -> Self {
204        Self {
205            _p2: PhantomData,
206            _local: PhantomData,
207        }
208    }
209}
210
211impl<S: Schema> Transaction<S> {
212    /// This will check the schema version and panic if it is not as expected
213    pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
214        let schema_version = schema_version(txn.get());
215        // If the schema version is not the expected version then we
216        // check if the changes are acceptable.
217        if schema_version != expected.load(std::sync::atomic::Ordering::Relaxed) {
218            if user_version(txn.get()).unwrap() != S::VERSION {
219                panic!("The database user_version changed unexpectedly")
220            }
221
222            TXN.set(Some(txn));
223            check_schema::<S>();
224            expected.store(schema_version, std::sync::atomic::Ordering::Relaxed);
225        } else {
226            TXN.set(Some(txn));
227        }
228
229        const {
230            assert!(size_of::<Self>() == 0);
231        }
232        // no memory is leaked because Self is zero sized
233        Box::leak(Box::new(Self::new()))
234    }
235}
236
237impl<S> Transaction<S> {
238    /// Execute a query with multiple results.
239    ///
240    /// ```
241    /// # use rust_query::{private::doctest::*};
242    /// # get_txn(|txn| {
243    /// let user_names = txn.query(|rows| {
244    ///     let user = rows.join(User);
245    ///     rows.into_vec(&user.name)
246    /// });
247    /// assert_eq!(user_names, vec!["Alice".to_owned()]);
248    /// # });
249    /// ```
250    pub fn query<F, R>(&self, f: F) -> R
251    where
252        F: for<'inner> FnOnce(&mut Query<'inner, S>) -> R,
253    {
254        // Execution already happens in a [Transaction].
255        // and thus any [TransactionMut] that it might be borrowed
256        // from is borrowed immutably, which means the rows can not change.
257
258        TXN.with_borrow(|txn| {
259            let conn = txn.as_ref().unwrap().get();
260            let q = Rows {
261                phantom: PhantomData,
262                ast: Default::default(),
263                _p: PhantomData,
264            };
265            f(&mut Query {
266                q,
267                phantom: PhantomData,
268                conn,
269            })
270        })
271    }
272
273    /// Retrieve a single result from the database.
274    ///
275    /// ```
276    /// # use rust_query::{private::doctest::*, IntoExpr};
277    /// # rust_query::private::doctest::get_txn(|txn| {
278    /// let res = txn.query_one("test".into_expr());
279    /// assert_eq!(res, "test");
280    /// # });
281    /// ```
282    ///
283    /// Instead of using [Self::query_one] in a loop, it is better to
284    /// call [Self::query] and return all results at once.
285    pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
286        self.query(|e| e.into_iter(val.into_select()).next().unwrap())
287    }
288}
289
290impl<S: 'static> Transaction<S> {
291    /// Try inserting a value into the database.
292    ///
293    /// Returns [Ok] with a reference to the new inserted value or an [Err] with conflict information.
294    /// The type of conflict information depends on the number of unique constraints on the table:
295    /// - 0 unique constraints => [Infallible]
296    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
297    /// - 2+ unique constraints => `()` no further information is provided.
298    ///
299    /// ```
300    /// # use rust_query::{private::doctest::*, IntoExpr};
301    /// # rust_query::private::doctest::get_txn(|mut txn| {
302    /// let res = txn.insert(User {
303    ///     name: "Bob",
304    /// });
305    /// assert!(res.is_ok());
306    /// let res = txn.insert(User {
307    ///     name: "Bob",
308    /// });
309    /// assert!(res.is_err(), "there is a unique constraint on the name");
310    /// # });
311    /// ```
312    pub fn insert<T: Table<Schema = S>>(
313        &mut self,
314        val: impl TableInsert<T = T>,
315    ) -> Result<TableRow<T>, T::Conflict> {
316        try_insert_private(T::NAME.into_table_ref(), None, val.into_insert())
317    }
318
319    /// This is a convenience function to make using [Transaction::insert]
320    /// easier for tables without unique constraints.
321    ///
322    /// The new row is added to the table and the row reference is returned.
323    pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
324        &mut self,
325        val: impl TableInsert<T = T>,
326    ) -> TableRow<T> {
327        let Ok(row) = self.insert(val);
328        row
329    }
330
331    /// This is a convenience function to make using [Transaction::insert]
332    /// easier for tables with exactly one unique constraints.
333    ///
334    /// The new row is inserted and the reference to the row is returned OR
335    /// an existing row is found which conflicts with the new row and a reference
336    /// to the conflicting row is returned.
337    ///
338    /// ```
339    /// # use rust_query::{private::doctest::*, IntoExpr};
340    /// # rust_query::private::doctest::get_txn(|mut txn| {
341    /// let bob = txn.insert(User {
342    ///     name: "Bob",
343    /// }).unwrap();
344    /// let bob2 = txn.find_or_insert(User {
345    ///     name: "Bob", // this will conflict with the existing row.
346    /// });
347    /// assert_eq!(bob, bob2);
348    /// # });
349    /// ```
350    pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
351        &mut self,
352        val: impl TableInsert<T = T>,
353    ) -> TableRow<T> {
354        match self.insert(val) {
355            Ok(row) => row,
356            Err(row) => row,
357        }
358    }
359
360    /// Try updating a row in the database to have new column values.
361    ///
362    /// Updating can fail just like [Transaction::insert] because of unique constraint conflicts.
363    /// This happens when the new values are in conflict with an existing different row.
364    ///
365    /// When the update succeeds, this function returns [Ok], when it fails it returns [Err] with one of
366    /// three conflict types:
367    /// - 0 unique constraints => [Infallible]
368    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
369    /// - 2+ unique constraints => `()` no further information is provided.
370    ///
371    /// ```
372    /// # use rust_query::{private::doctest::*, IntoExpr, Update};
373    /// # rust_query::private::doctest::get_txn(|mut txn| {
374    /// let bob = txn.insert(User {
375    ///     name: "Bob",
376    /// }).unwrap();
377    /// txn.update(bob, User {
378    ///     name: Update::set("New Bob"),
379    /// }).unwrap();
380    /// # });
381    /// ```
382    pub fn update<T: Table<Schema = S>>(
383        &mut self,
384        row: impl IntoExpr<'static, S, Typ = T>,
385        val: T::Update,
386    ) -> Result<(), T::Conflict> {
387        let mut id = ValueBuilder::default();
388        let row = row.into_expr();
389        let (id, _) = id.simple_one(DynTypedExpr::erase(&row));
390
391        let val = T::apply_try_update(val, row);
392        let mut reader = Reader::default();
393        T::read(&val, &mut reader);
394        let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
395
396        let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
397        let cte = CommonTableExpression::new()
398            .query(select)
399            .columns(col_fields.clone())
400            .table_name(Alias::new("cte"))
401            .to_owned();
402        let with_clause = WithClause::new().cte(cte).to_owned();
403
404        let mut update = UpdateStatement::new()
405            .table(("main", T::NAME))
406            .cond_where(Expr::col(("main", T::NAME, T::ID)).in_subquery(id))
407            .to_owned();
408
409        for (name, field) in zip(col_names, col_fields) {
410            let select = SelectStatement::new()
411                .from(Alias::new("cte"))
412                .column(field)
413                .to_owned();
414            let value = sea_query::Expr::SubQuery(
415                None,
416                Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
417            );
418            update.value(Alias::new(name), value);
419        }
420
421        let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
422
423        TXN.with_borrow(|txn| {
424            let txn = txn.as_ref().unwrap().get();
425
426            let mut stmt = txn.prepare_cached(&query).unwrap();
427            match stmt.execute(&*args.as_params()) {
428                Ok(1) => Ok(()),
429                Ok(n) => panic!("unexpected number of updates: {n}"),
430                Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
431                    if kind.code == ErrorCode::ConstraintViolation =>
432                {
433                    // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
434                    Err(T::get_conflict_unchecked(self, &val))
435                }
436                Err(err) => panic!("{err:?}"),
437            }
438        })
439    }
440
441    /// This is a convenience function to use [Transaction::update] for updates
442    /// that can not cause unique constraint violations.
443    ///
444    /// This method can be used for all tables, it just does not allow modifying
445    /// columns that are part of unique constraints.
446    pub fn update_ok<T: Table<Schema = S>>(
447        &mut self,
448        row: impl IntoExpr<'static, S, Typ = T>,
449        val: T::UpdateOk,
450    ) {
451        match self.update(row, T::update_into_try_update(val)) {
452            Ok(val) => val,
453            Err(_) => {
454                unreachable!("update can not fail")
455            }
456        }
457    }
458
459    /// Convert the [Transaction] into a [TransactionWeak] to allow deletions.
460    pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
461        // TODO: clean this up
462        Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
463    }
464}
465
/// This is the weak version of [Transaction].
///
/// The reason that it is called `weak` is because [TransactionWeak] can not guarantee
/// that [TableRow]s prove the existence of their particular row.
///
/// [TransactionWeak] is useful because it allows deleting rows.
pub struct TransactionWeak<S> {
    // Marker only: no `Transaction` value is stored.
    inner: PhantomData<Transaction<S>>,
}
475
476impl<S: Schema> TransactionWeak<S> {
477    /// Try to delete a row from the database.
478    ///
479    /// This will return an [Err] if there is a row that references the row that is being deleted.
480    /// When this method returns [Ok] it will contain a [bool] that is either
481    /// - `true` if the row was just deleted.
482    /// - `false` if the row was deleted previously in this transaction.
483    pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
484        let schema = crate::hash::Schema::new::<S>();
485
486        // This is a manual check that foreign key constraints are not violated.
487        // We do this manually because we don't want to enabled foreign key constraints for the whole
488        // transaction (and is not possible to enable for part of a transaction).
489        let mut checks = vec![];
490        for (table_name, table) in &*schema.tables {
491            for col in table.columns.iter().filter_map(|col| {
492                col.fk
493                    .as_ref()
494                    .is_some_and(|(t, c)| t == T::NAME && c == T::ID)
495                    .then_some(&col.name)
496            }) {
497                let stmt = SelectStatement::new()
498                    .expr(
499                        val.in_subquery(
500                            SelectStatement::new()
501                                .from(Alias::new(table_name))
502                                .column(Alias::new(col))
503                                .take(),
504                        ),
505                    )
506                    .take();
507                checks.push(stmt.build_rusqlite(SqliteQueryBuilder));
508            }
509        }
510
511        let stmt = DeleteStatement::new()
512            .from_table(("main", T::NAME))
513            .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
514            .take();
515
516        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
517
518        TXN.with_borrow(|txn| {
519            let txn = txn.as_ref().unwrap().get();
520
521            for (query, args) in checks {
522                let mut stmt = txn.prepare_cached(&query).unwrap();
523                match stmt.query_one(&*args.as_params(), |r| r.get(0)) {
524                    Ok(true) => return Err(T::get_referer_unchecked()),
525                    Ok(false) => {}
526                    Err(err) => panic!("{err:?}"),
527                }
528            }
529
530            let mut stmt = txn.prepare_cached(&query).unwrap();
531            match stmt.execute(&*args.as_params()) {
532                Ok(0) => Ok(false),
533                Ok(1) => Ok(true),
534                Ok(n) => {
535                    panic!("unexpected number of deletes {n}")
536                }
537                Err(err) => panic!("{err:?}"),
538            }
539        })
540    }
541
542    /// Delete a row from the database.
543    ///
544    /// This is the infallible version of [TransactionWeak::delete].
545    ///
546    /// To be able to use this method you have to mark the table as `#[no_reference]` in the schema.
547    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
548        &mut self,
549        val: TableRow<T>,
550    ) -> bool {
551        let Ok(res) = self.delete(val);
552        res
553    }
554
555    /// This allows you to do (almost) anything you want with the internal [rusqlite::Transaction].
556    ///
557    /// Note that there are some things that you should not do with the transaction, such as:
558    /// - Changes to the schema, these will result in a panic as described in [Database].
559    /// - Making changes that violate foreign-key constraints (see below).
560    ///
561    /// Sadly it is not possible to enable (or disable) the `foreign_keys` pragma during a transaction.
562    /// This means that whether this pragma is enabled depends on which [crate::migrate::ForeignKeys]
563    /// option is used and can not be changed.
564    pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
565        TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
566    }
567}
568
569pub fn try_insert_private<T: Table>(
570    table: sea_query::TableRef,
571    idx: Option<i64>,
572    val: T::Insert,
573) -> Result<TableRow<T>, T::Conflict> {
574    let mut reader = Reader::default();
575    T::read(&val, &mut reader);
576    if let Some(idx) = idx {
577        reader.col(T::ID, idx);
578    }
579    let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
580    let is_empty = col_names.is_empty();
581
582    let (select, _) = ValueBuilder::default().simple(col_exprs);
583
584    let mut insert = InsertStatement::new();
585    insert.into_table(table);
586    insert.columns(col_names.into_iter().map(Alias::new));
587    if is_empty {
588        // select always has at least one column, so we leave it out when there are no columns
589        insert.or_default_values();
590    } else {
591        insert.select_from(select).unwrap();
592    }
593    insert.returning_col(T::ID);
594
595    let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
596
597    TXN.with_borrow(|txn| {
598        let txn = txn.as_ref().unwrap().get();
599        track_stmt(txn, &sql, &values);
600
601        let mut statement = txn.prepare_cached(&sql).unwrap();
602        let mut res = statement
603            .query_map(&*values.as_params(), |row| {
604                Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
605            })
606            .unwrap();
607
608        match res.next().unwrap() {
609            Ok(id) => Ok(id),
610            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
611                if kind.code == ErrorCode::ConstraintViolation =>
612            {
613                // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
614                Err(T::get_conflict_unchecked(&Transaction::new(), &val))
615            }
616            Err(err) => panic!("{err:?}"),
617        }
618    })
619}