1use std::{cell::RefCell, convert::Infallible, marker::PhantomData, sync::atomic::AtomicI64};
2
3use rusqlite::ErrorCode;
4use sea_query::{
5 Alias, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoIden, SelectStatement,
6 SqliteQueryBuilder, UpdateStatement,
7};
8use sea_query_rusqlite::RusqliteBinder;
9use self_cell::{MutBorrow, self_cell};
10
11use crate::{
12 IntoExpr, IntoSelect, Table, TableRow,
13 error::FromConflict,
14 migrate::{Schema, check_schema, schema_version, user_version},
15 migration::Config,
16 mutable::Mutable,
17 pool::Pool,
18 private::{IntoJoinable, Reader},
19 query::{OwnedRows, Query, track_stmt},
20 rows::Rows,
21 value::{DbTyp, OptTable},
22};
23
/// Handle to a SQLite database whose schema is described by `S`.
pub struct Database<S> {
    /// Pool of connections; the transaction methods pop a connection and push it back afterwards.
    pub(crate) manager: Pool,
    /// Last SQLite `schema_version` that was validated against `S`.
    /// `Transaction::new_checked` re-runs the schema check whenever the database reports a
    /// different value.
    pub(crate) schema_version: AtomicI64,
    pub(crate) schema: PhantomData<S>,
    /// Serializes write transactions started by `transaction_mut_local`.
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
42
43impl<S: Schema> Database<S> {
44 pub fn new(config: Config) -> Self {
49 let Some(m) = Self::migrator(config) else {
50 panic!("schema version {}, but got an older version", S::VERSION)
51 };
52 let Some(m) = m.finish() else {
53 panic!("schema version {}, but got a new version", S::VERSION)
54 };
55 m
56 }
57}
58
59use rusqlite::Connection;
// Dependent half of `OwnedTransaction`: `Some` while the transaction is live, `None` once it
// has been moved out by `OwnedTransaction::with`.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

// Owns a `Connection` together with the rusqlite transaction that borrows it, so that both
// can be stored (e.g. in the thread-local `TXN`) and moved as a single value.
self_cell!(
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        #[covariant]
        dependent: RTransaction,
    }
);
70
// SAFETY (reviewer note — confirm): rusqlite's `Transaction` holds a `&Connection`, which is
// why it is not `Send` on its own. Here the only borrow of the `Connection` is the transaction
// stored in the same self-cell, and `get()` only hands out borrows tied to `&self`. Moving an
// `OwnedTransaction` therefore moves the connection and every reference to it together.
// `Connection` itself is `Send`. Sharing across threads would still be unsound, so the type
// is asserted to be `!Sync` at compile time.
unsafe impl Send for OwnedTransaction {}
assert_not_impl_any! {OwnedTransaction: Sync}
78
thread_local! {
    /// The transaction currently in use on this thread, if any.
    ///
    /// Filled by `Transaction::new_checked` and taken back out at the end of
    /// `Database::transaction_local` / `Database::transaction_mut_local`. The query, insert,
    /// update and delete helpers borrow it to reach the underlying `rusqlite::Transaction`.
    pub(crate) static TXN: RefCell<Option<TransactionWithRows>> = const { RefCell::new(None) };
}
82
83impl OwnedTransaction {
84 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
85 self.borrow_dependent().as_ref().unwrap()
86 }
87
88 pub(crate) fn with(
89 mut self,
90 f: impl FnOnce(rusqlite::Transaction<'_>),
91 ) -> rusqlite::Connection {
92 self.with_dependent_mut(|_, b| f(b.take().unwrap()));
93 self.into_owner().into_inner()
94 }
95}
96
// Slab of `OwnedRows` values that borrow from the owning transaction
// (presumably result sets kept alive across calls — confirm in `crate::query`).
type OwnedRowsVec<'x> = slab::Slab<OwnedRows<'x>>;

// Pairs an `OwnedTransaction` with the `OwnedRows` borrowing from it; this is the value
// stored in the thread-local `TXN`.
self_cell!(
    pub struct TransactionWithRows {
        owner: OwnedTransaction,

        #[not_covariant]
        dependent: OwnedRowsVec,
    }
);
106
107impl TransactionWithRows {
108 pub(crate) fn new_empty(txn: OwnedTransaction) -> Self {
109 Self::new(txn, |_| slab::Slab::new())
110 }
111
112 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
113 self.borrow_owner().get()
114 }
115}
116
impl<S: Send + Sync + Schema> Database<S> {
    #[doc = include_str!("database/transaction.md")]
    pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
        // `f` runs on a new scoped thread, whose thread-local `TXN` starts out as `None`.
        let res = std::thread::scope(|scope| scope.spawn(|| self.transaction_local(f)).join());
        match res {
            Ok(val) => val,
            // Re-raise a panic from the worker thread on the calling thread.
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Run `f` against a read transaction on the current thread.
    ///
    /// A connection is popped from the pool and a transaction is opened with rusqlite's
    /// default behaviour (deferred). The transaction is never committed. Extracting the
    /// connection drops it, which rolls it back under rusqlite's default drop behaviour, and
    /// the connection is then pushed back into the pool.
    ///
    /// NOTE(review): if `f` panics, the connection is not pushed back into the pool.
    pub(crate) fn transaction_local<R>(&self, f: impl FnOnce(&'static Transaction<S>) -> R) -> R {
        let conn = self.manager.pop();

        let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(conn.borrow_mut().transaction().unwrap())
        });

        // `new_checked` installs the transaction in the thread-local `TXN`.
        let res = f(Transaction::new_checked(owned, &self.schema_version));

        // Take the transaction back out of `TXN`, drop it and recycle the connection.
        let owned = TXN.take().unwrap().into_owner();
        self.manager.push(owned.into_owner().into_inner());

        res
    }

    #[doc = include_str!("database/transaction_mut.md")]
    pub fn transaction_mut<O: Send, E: Send>(
        &self,
        f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
    ) -> Result<O, E> {
        // Same scheme as `transaction`: run on a fresh scoped thread with an empty `TXN`.
        let join_res =
            std::thread::scope(|scope| scope.spawn(|| self.transaction_mut_local(f)).join());

        match join_res {
            Ok(val) => val,
            // Re-raise a panic from the worker thread on the calling thread.
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Run `f` against a write transaction on the current thread.
    ///
    /// Commits when `f` returns `Ok` and rolls back when it returns `Err`. The connection is
    /// pushed back into the pool in both cases. Panics if SQLite fails to begin, commit or
    /// roll back the transaction.
    ///
    /// NOTE(review): if `f` panics, the lock guard is released by unwinding but the
    /// connection is not pushed back into the pool.
    pub(crate) fn transaction_mut_local<O, E>(
        &self,
        f: impl FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
    ) -> Result<O, E> {
        // At most one write transaction per `Database` at a time. `FairMutex` is parking_lot's
        // fair lock. It is taken before a connection is popped from the pool.
        let guard = self.mut_lock.lock();

        let conn = self.manager.pop();

        // `Immediate` makes SQLite take the write lock when the transaction begins instead of
        // at the first write statement.
        let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            let txn = conn
                .borrow_mut()
                .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
                .unwrap();
            Some(txn)
        });
        let res = f(Transaction::new_checked(owned, &self.schema_version));

        // NOTE(review): the lock is released *before* the commit/rollback below, so the next
        // writer may issue `BEGIN IMMEDIATE` while this transaction is still finishing. This
        // presumably relies on SQLite's busy handling — confirm this ordering is intended.
        drop(guard);

        let owned = TXN.take().unwrap().into_owner();

        // Finish the transaction according to `f`'s outcome and recover the connection.
        let conn = if res.is_ok() {
            owned.with(|x| x.commit().unwrap())
        } else {
            owned.with(|x| x.rollback().unwrap())
        };
        self.manager.push(conn);

        res
    }

    #[doc = include_str!("database/transaction_mut_ok.md")]
    pub fn transaction_mut_ok<R: Send>(
        &self,
        f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
    ) -> R {
        // The error type is uninhabited: the transaction always commits and `unwrap` cannot fail.
        self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
            .unwrap()
    }

    /// Take a raw connection from the pool with `PRAGMA foreign_keys = ON` applied.
    ///
    /// This method does not return the connection to the pool.
    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
        let conn = self.manager.pop();
        conn.pragma_update(None, "foreign_keys", "ON").unwrap();
        conn
    }
}
221
/// Zero-sized handle to the transaction stored in this thread's `TXN` slot.
pub struct Transaction<S> {
    pub(crate) _p2: PhantomData<S>,
    /// `*const ()` makes the handle `!Send` and `!Sync`, so it stays on the thread whose
    /// `TXN` it refers to.
    pub(crate) _local: PhantomData<*const ()>,
}
231
232impl<S> Transaction<S> {
233 pub(crate) fn new() -> Self {
234 Self {
235 _p2: PhantomData,
236 _local: PhantomData,
237 }
238 }
239
240 pub(crate) fn copy(&self) -> Self {
241 Self::new()
242 }
243
244 pub(crate) fn new_ref() -> &'static mut Self {
245 Box::leak(Box::new(Self::new()))
247 }
248}
249
250impl<S: Schema> Transaction<S> {
251 pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
253 let schema_version = schema_version(txn.get());
254 if schema_version != expected.load(std::sync::atomic::Ordering::Relaxed) {
257 if user_version(txn.get()).unwrap() != S::VERSION {
258 panic!("The database user_version changed unexpectedly")
259 }
260
261 TXN.set(Some(TransactionWithRows::new_empty(txn)));
262 check_schema::<S>(Self::new_ref());
263 expected.store(schema_version, std::sync::atomic::Ordering::Relaxed);
264 } else {
265 TXN.set(Some(TransactionWithRows::new_empty(txn)));
266 }
267
268 const {
269 assert!(size_of::<Self>() == 0);
270 }
271 Self::new_ref()
272 }
273}
274
275impl<S> Transaction<S> {
276 pub fn query<'t, R>(&'t self, f: impl FnOnce(&mut Query<'t, '_, S>) -> R) -> R {
289 let q = Rows {
294 phantom: PhantomData,
295 ast: Default::default(),
296 _p: PhantomData,
297 };
298 f(&mut Query {
299 q,
300 phantom: PhantomData,
301 })
302 }
303
304 pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
317 let mut query = self.query(|e| e.into_iter(val.into_select()));
318 let res = query.next().unwrap();
319 debug_assert!(query.next().is_none(), "query should return one row");
320 res
321 }
322
323 pub fn lazy<'t, T: OptTable<Schema = S>>(
331 &'t self,
332 val: impl IntoExpr<'static, S, Typ = T>,
333 ) -> T::Lazy<'t> {
334 T::out_to_lazy(self.query_one(val.into_expr()))
335 }
336
337 pub fn lazy_iter<'t, T: Table<Schema = S>>(
341 &'t self,
342 val: impl IntoJoinable<'static, S, Typ = TableRow<T>>,
343 ) -> LazyIter<'t, T> {
344 let val = val.into_joinable();
345 self.query(|rows| {
346 let table = rows.join(val);
347 LazyIter {
348 txn: self,
349 iter: rows.into_iter(table),
350 }
351 })
352 }
353
354 pub fn mutable<'t, T: OptTable<Schema = S>>(
358 &'t mut self,
359 val: impl IntoExpr<'static, S, Typ = T>,
360 ) -> T::Mutable<'t> {
361 let x = self.query_one(T::select_opt_mutable(val.into_expr()));
362 T::into_mutable(x)
363 }
364
365 pub fn mutable_vec<'t, T: Table<Schema = S>>(
369 &'t mut self,
370 val: impl IntoJoinable<'static, S, Typ = TableRow<T>>,
371 ) -> Vec<Mutable<'t, T>> {
372 let val = val.into_joinable();
373 self.query(|rows| {
374 let val = rows.join(val);
375 rows.into_vec((T::into_select(val.clone()), val))
376 .into_iter()
377 .map(TableRow::<T>::into_mutable)
378 .collect()
379 })
380 }
381}
382
/// Iterator returned by `Transaction::lazy_iter`; yields a lazy handle per row.
pub struct LazyIter<'t, T: Table> {
    /// Transaction used to turn each row into its lazy form.
    txn: &'t Transaction<T::Schema>,
    /// Underlying iterator over the matching row ids.
    iter: crate::query::Iter<'t, TableRow<T>>,
}
387
388impl<'t, T: Table> Iterator for LazyIter<'t, T> {
389 type Item = crate::Lazy<'t, T>;
390
391 fn next(&mut self) -> Option<Self::Item> {
392 self.iter.next().map(|x| self.txn.lazy(x))
393 }
394}
395
396impl<S: 'static> Transaction<S> {
397 pub fn insert<T: Table<Schema = S>>(&mut self, val: T) -> Result<TableRow<T>, T::Conflict> {
419 try_insert_private(T::NAME.into_iden(), None, val)
420 }
421
422 pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
427 &mut self,
428 val: T,
429 ) -> TableRow<T> {
430 let Ok(row) = self.insert(val);
431 row
432 }
433
434 pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
454 &mut self,
455 val: T,
456 ) -> TableRow<T> {
457 match self.insert(val) {
458 Ok(row) => row,
459 Err(row) => row,
460 }
461 }
462
463 pub(crate) fn update<T: Table<Schema = S>>(
464 &mut self,
465 row: TableRow<T>,
466 val: T::Mutable,
467 ) -> Result<(), T::Conflict> {
468 let val = T::mutable_into_insert(val);
469 let mut reader = Reader::default();
470 T::read(&val, &mut reader);
471
472 let (query, args) = UpdateStatement::new()
473 .table(("main", T::NAME))
474 .values(reader.builder.clone())
475 .cond_where(Expr::col((T::NAME, T::ID)).eq(row.inner.idx))
476 .build_rusqlite(SqliteQueryBuilder);
477
478 let res = TXN.with_borrow(|txn| {
479 let txn = txn.as_ref().unwrap().get();
480
481 let mut stmt = txn.prepare_cached(&query).unwrap();
482 stmt.execute(&*args.as_params())
483 });
484
485 match res {
486 Ok(1) => Ok(()),
487 Ok(n) => panic!("unexpected number of updates: {n}"),
488 Err(rusqlite::Error::SqliteFailure(kind, Some(msg)))
489 if kind.code == ErrorCode::ConstraintViolation =>
490 {
491 let res = TXN.with_borrow(|txn| {
493 let txn = txn.as_ref().unwrap().get();
494 <T::Conflict as FromConflict>::from_conflict(
495 txn,
496 T::NAME.into_iden(),
497 reader.builder,
498 msg,
499 )
500 });
501 Err(res)
502 }
503 Err(err) => panic!("{err:?}"),
504 }
505 }
506
507 pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
509 Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
511 }
512}
513
/// Handle obtained from `Transaction::downgrade`; provides row deletion and raw rusqlite access.
pub struct TransactionWeak<S> {
    /// Inherits `Transaction`'s `!Send`/`!Sync` markers without storing any data.
    inner: PhantomData<Transaction<S>>,
}
523
524impl<S: Schema> TransactionWeak<S> {
525 pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
532 let schema = crate::schema::from_macro::Schema::new::<S>();
533
534 let mut checks = vec![];
538 for (&table_name, table) in &schema.tables {
539 for col in table.columns.iter().filter_map(|(col_name, col)| {
540 let col = &col.def;
541 col.fk
542 .as_ref()
543 .is_some_and(|(t, c)| t == T::NAME && c == T::ID)
544 .then_some(col_name)
545 }) {
546 let stmt = SelectStatement::new()
547 .expr(
548 val.inner.idx.in_subquery(
549 SelectStatement::new()
550 .from(table_name)
551 .column(Alias::new(col))
552 .take(),
553 ),
554 )
555 .take();
556 checks.push(stmt.build_rusqlite(SqliteQueryBuilder));
557 }
558 }
559
560 let stmt = DeleteStatement::new()
561 .from_table(("main", T::NAME))
562 .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
563 .take();
564
565 let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
566
567 TXN.with_borrow(|txn| {
568 let txn = txn.as_ref().unwrap().get();
569
570 for (query, args) in checks {
571 let mut stmt = txn.prepare_cached(&query).unwrap();
572 match stmt.query_one(&*args.as_params(), |r| r.get(0)) {
573 Ok(true) => return Err(T::get_referer_unchecked()),
574 Ok(false) => {}
575 Err(err) => panic!("{err:?}"),
576 }
577 }
578
579 let mut stmt = txn.prepare_cached(&query).unwrap();
580 match stmt.execute(&*args.as_params()) {
581 Ok(0) => Ok(false),
582 Ok(1) => Ok(true),
583 Ok(n) => {
584 panic!("unexpected number of deletes {n}")
585 }
586 Err(err) => panic!("{err:?}"),
587 }
588 })
589 }
590
591 pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
597 &mut self,
598 val: TableRow<T>,
599 ) -> bool {
600 let Ok(res) = self.delete(val);
601 res
602 }
603
604 pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
614 TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
615 }
616}
617
618pub fn try_insert_private<T: Table>(
619 table: sea_query::DynIden,
620 idx: Option<i64>,
621 val: T,
622) -> Result<TableRow<T>, T::Conflict> {
623 let mut reader = Reader::default();
624 T::read(&val, &mut reader);
625 if let Some(idx) = idx {
626 reader.col::<i64>(T::ID, idx);
627 }
628 let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.clone().into_iter().collect();
629 let is_empty = col_names.is_empty();
630
631 let mut insert = InsertStatement::new();
632 insert.into_table(("main", table.clone()));
633 insert.columns(col_names);
634 if is_empty {
635 insert.or_default_values();
637 } else {
638 insert.values(col_exprs).unwrap();
639 }
640 insert.returning_col(T::ID);
641
642 let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
643
644 let res = TXN.with_borrow(|txn| {
645 let txn = txn.as_ref().unwrap().get();
646 track_stmt(txn, &sql, &values);
647
648 let mut statement = txn.prepare_cached(&sql).unwrap();
649 let mut res = statement
650 .query_map(&*values.as_params(), |row| {
651 Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
652 })
653 .unwrap();
654
655 res.next().unwrap()
656 });
657
658 match res {
659 Ok(id) => {
660 if let Some(idx) = idx {
661 assert_eq!(idx, id.inner.idx);
662 }
663 Ok(id)
664 }
665 Err(rusqlite::Error::SqliteFailure(kind, Some(msg)))
666 if kind.code == ErrorCode::ConstraintViolation =>
667 {
668 let res = TXN.with_borrow(|txn| {
670 let txn = txn.as_ref().unwrap().get();
671 <T::Conflict as FromConflict>::from_conflict(txn, table, reader.builder, msg)
672 });
673 Err(res)
674 }
675 Err(err) => panic!("{err:?}"),
676 }
677}