Skip to main content

rust_query/
error.rs

1use std::{borrow::Cow, convert::Infallible, fmt::Debug, marker::PhantomData};
2
3use sea_query::{ExprTrait, SelectStatement, SqliteQueryBuilder};
4use sea_query_rusqlite::RusqliteBinder;
5
6use crate::{Table, TableRow, db::TableRowInner};
7
8/// Error type that is used by [crate::Transaction::insert] and [crate::Mutable::unique] when
9/// there are at least two unique constraints.
10///
11/// The source of the error is the message received from sqlite. It contains the column
12/// names that were conflicted.
13pub struct Conflict<T: Table> {
14    _p: PhantomData<T>,
15    msg: Box<dyn std::error::Error>,
16}
17
18#[cfg_attr(test, mutants::skip)]
19impl<T: Table> Debug for Conflict<T> {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Conflict")
22            .field("table", &T::NAME)
23            .field("msg", &self.msg)
24            .finish()
25    }
26}
27
28impl<T: Table> std::fmt::Display for Conflict<T> {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "Conflict in table `{}`", T::NAME)
31    }
32}
33
34impl<T: Table> std::error::Error for Conflict<T> {
35    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
36        Some(&*self.msg)
37    }
38}
39
40impl<T: Table<Conflict = Self>> std::fmt::Display for TableRow<T> {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        let unique_columns = get_unique_columns::<T>().join(", ");
43        write!(
44            f,
45            "Row exists in table `{}`, with unique constraint on ({})",
46            T::NAME,
47            unique_columns
48        )
49    }
50}
51
52impl<T: Table<Conflict = Self>> std::error::Error for TableRow<T> {}
53
/// Builds a table's conflict value after a write failed on a unique constraint.
///
/// Implemented for `Infallible`, for [`Conflict<T>`] and for `TableRow<T>`; the latter two
/// are only implemented when they are the table's `Conflict` associated type.
pub(crate) trait FromConflict {
    /// Creates the conflict value.
    ///
    /// * `txn` - transaction that can be used to query the database (the `TableRow`
    ///   implementation uses it to look up the conflicting row).
    /// * `table` - identifier of the table that was written to.
    /// * `cols` - `(column name, value)` pairs; presumably the values of the attempted
    ///   write — TODO confirm against callers.
    /// * `msg` - the error message received from sqlite.
    fn from_conflict(
        txn: &rusqlite::Transaction<'_>,
        table: sea_query::DynIden,
        cols: Vec<(&'static str, sea_query::Expr)>,
        msg: String,
    ) -> Self;
}
62
63impl FromConflict for Infallible {
64    fn from_conflict(
65        _txn: &rusqlite::Transaction<'_>,
66        _table: sea_query::DynIden,
67        _cols: Vec<(&'static str, sea_query::Expr)>,
68        _msg: String,
69    ) -> Self {
70        unreachable!()
71    }
72}
73
74impl<T: Table<Conflict = Self>> FromConflict for Conflict<T> {
75    fn from_conflict(
76        _txn: &rusqlite::Transaction<'_>,
77        _table: sea_query::DynIden,
78        _cols: Vec<(&'static str, sea_query::Expr)>,
79        msg: String,
80    ) -> Self {
81        Self {
82            _p: PhantomData,
83            msg: msg.into(),
84        }
85    }
86}
87
88pub(crate) fn get_unique_columns<T: Table<Conflict = TableRow<T>>>() -> Vec<Cow<'static, str>> {
89    // TODO: optimize to const
90    let schema = crate::schema::from_macro::Table::new::<T>();
91    let [index] = schema
92        .indices
93        .into_iter()
94        .filter(|x| x.def.unique)
95        .collect::<Vec<_>>()
96        .try_into()
97        .unwrap();
98    index.def.columns
99}
100
101impl<T: Table<Conflict = Self>> FromConflict for TableRow<T> {
102    fn from_conflict(
103        txn: &rusqlite::Transaction<'_>,
104        table: sea_query::DynIden,
105        mut cols: Vec<(&'static str, sea_query::Expr)>,
106        _msg: String,
107    ) -> Self {
108        let unique_columns = get_unique_columns::<T>();
109
110        cols.retain(|(name, _val)| unique_columns.contains(&Cow::Borrowed(*name)));
111        assert_eq!(cols.len(), unique_columns.len());
112
113        let mut select = SelectStatement::new()
114            .from(("main", table.clone()))
115            .column((table.clone(), T::ID))
116            .take();
117
118        for (col, val) in cols {
119            select.cond_where(val.equals((table.clone(), col)));
120        }
121
122        let (query, args) = select.build_rusqlite(SqliteQueryBuilder);
123
124        let mut stmt = txn.prepare_cached(&query).unwrap();
125        stmt.query_one(&*args.as_params(), |row| {
126            Ok(Self {
127                _local: PhantomData,
128                inner: TableRowInner {
129                    _p: PhantomData,
130                    idx: row.get(0)?,
131                },
132            })
133        })
134        .unwrap()
135    }
136}