Skip to main content

rust_query/
error.rs

1use std::{borrow::Cow, convert::Infallible, fmt::Debug, marker::PhantomData};
2
3use sea_query::{ExprTrait, SelectStatement, SqliteQueryBuilder};
4use sea_query_rusqlite::RusqliteBinder;
5
6use crate::{Table, TableRow, db::TableRowInner};
7
8/// Error type that is used by [crate::Transaction::insert] and [crate::Mutable::unique] when
9/// there are at least two unique constraints.
10///
11/// The source of the error is the message received from sqlite. It contains the column
12/// names that were conflicted.
13pub struct Conflict<T: Table> {
14    _p: PhantomData<T>,
15    msg: Box<dyn std::error::Error>,
16}
17
18impl<T: Table> Debug for Conflict<T> {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("Conflict")
21            .field("table", &T::NAME)
22            .field("msg", &self.msg)
23            .finish()
24    }
25}
26
27impl<T: Table> std::fmt::Display for Conflict<T> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "Conflict in table `{}`", T::NAME)
30    }
31}
32
33impl<T: Table> std::error::Error for Conflict<T> {
34    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
35        Some(&*self.msg)
36    }
37}
38
39impl<T: Table<Conflict = Self>> std::fmt::Display for TableRow<T> {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        let unique_columns = get_unique_columns::<T>().join(", ");
42        write!(
43            f,
44            "Row exists in table `{}`, with unique constraint on ({})",
45            T::NAME,
46            unique_columns
47        )
48    }
49}
50
51impl<T: Table<Conflict = Self>> std::error::Error for TableRow<T> {}
52
/// Builds a value of a table's `Conflict` type after sqlite reported a conflict.
///
/// Implemented in this module for [`Infallible`], [`Conflict`] and [`TableRow`].
pub(crate) trait FromConflict {
    /// Constructs the conflict value.
    ///
    /// - `txn`: transaction that can be queried further (the [`TableRow`] impl runs a
    ///   `SELECT` on it); presumably the one in which the conflicting write happened.
    /// - `table`: identifier of the table the conflict is looked up in.
    /// - `cols`: `(column name, value)` pairs; presumably the values of the attempted
    ///   write — TODO confirm against callers.
    /// - `msg`: the error message received from sqlite.
    fn from_conflict(
        txn: &rusqlite::Transaction<'_>,
        table: sea_query::DynIden,
        cols: Vec<(&'static str, sea_query::Expr)>,
        msg: String,
    ) -> Self;
}
61
62impl FromConflict for Infallible {
63    fn from_conflict(
64        _txn: &rusqlite::Transaction<'_>,
65        _table: sea_query::DynIden,
66        _cols: Vec<(&'static str, sea_query::Expr)>,
67        _msg: String,
68    ) -> Self {
69        unreachable!()
70    }
71}
72
73impl<T: Table<Conflict = Self>> FromConflict for Conflict<T> {
74    fn from_conflict(
75        _txn: &rusqlite::Transaction<'_>,
76        _table: sea_query::DynIden,
77        _cols: Vec<(&'static str, sea_query::Expr)>,
78        msg: String,
79    ) -> Self {
80        Self {
81            _p: PhantomData,
82            msg: msg.into(),
83        }
84    }
85}
86
87pub(crate) fn get_unique_columns<T: Table<Conflict = TableRow<T>>>() -> Vec<Cow<'static, str>> {
88    // TODO: optimize to const
89    let schema = crate::schema::from_macro::Table::new::<T>();
90    let [index] = schema
91        .indices
92        .into_iter()
93        .filter(|x| x.def.unique)
94        .collect::<Vec<_>>()
95        .try_into()
96        .unwrap();
97    index.def.columns
98}
99
100impl<T: Table<Conflict = Self>> FromConflict for TableRow<T> {
101    fn from_conflict(
102        txn: &rusqlite::Transaction<'_>,
103        table: sea_query::DynIden,
104        mut cols: Vec<(&'static str, sea_query::Expr)>,
105        _msg: String,
106    ) -> Self {
107        let unique_columns = get_unique_columns::<T>();
108
109        cols.retain(|(name, _val)| unique_columns.contains(&Cow::Borrowed(*name)));
110        assert_eq!(cols.len(), unique_columns.len());
111
112        let mut select = SelectStatement::new()
113            .from(("main", table.clone()))
114            .column((table.clone(), T::ID))
115            .take();
116
117        for (col, val) in cols {
118            select.cond_where(val.equals((table.clone(), col)));
119        }
120
121        let (query, args) = select.build_rusqlite(SqliteQueryBuilder);
122
123        let mut stmt = txn.prepare_cached(&query).unwrap();
124        stmt.query_one(&*args.as_params(), |row| {
125            Ok(Self {
126                _local: PhantomData,
127                inner: TableRowInner {
128                    _p: PhantomData,
129                    idx: row.get(0)?,
130                },
131            })
132        })
133        .unwrap()
134    }
135}