Skip to main content

rust_query/
error.rs

1use std::{borrow::Cow, convert::Infallible, fmt::Debug, marker::PhantomData, rc::Rc};
2
3use crate::{
4    Table, TableRow,
5    db::TableRowInner,
6    lower::{
7        self,
8        emit::{self, IndexMap},
9        ord_rc::OrdRc,
10    },
11};
12
13/// Error type that is used by [crate::Transaction::insert] and [crate::Mutable::unique] when
14/// there are at least two unique constraints.
15///
16/// The source of the error is the message received from sqlite. It contains the column
17/// names that were conflicted.
18pub struct Conflict<T: Table> {
19    _p: PhantomData<T>,
20    msg: Box<dyn std::error::Error>,
21}
22
23#[cfg_attr(feature = "__mutants", mutants::skip)]
24impl<T: Table> Debug for Conflict<T> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("Conflict")
27            .field("table", &T::NAME)
28            .field("msg", &self.msg)
29            .finish()
30    }
31}
32
33impl<T: Table> std::fmt::Display for Conflict<T> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "Conflict in table `{}`", T::NAME)
36    }
37}
38
39impl<T: Table> std::error::Error for Conflict<T> {
40    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41        Some(&*self.msg)
42    }
43}
44
45impl<T: Table<Conflict = Self>> std::fmt::Display for TableRow<T> {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        let unique_columns = get_unique_columns::<T>().join(", ");
48        write!(
49            f,
50            "Row exists in table `{}`, with unique constraint on ({})",
51            T::NAME,
52            unique_columns
53        )
54    }
55}
56
57impl<T: Table<Conflict = Self>> std::error::Error for TableRow<T> {}
58
/// Constructs a table's conflict value after a write was rejected because of a conflict.
///
/// Implemented in this file for `Infallible`, `Conflict<T>` and `TableRow<T>`.
/// NOTE(review): the call sites are presumably the insert/update paths that map a sqlite
/// constraint error into `Table::Conflict`; confirm there.
pub(crate) trait FromConflict {
    /// Parameters, as used by the implementations in this file:
    /// - `txn`: transaction in which the conflicting write happened; `TableRow<T>` runs a
    ///   lookup query through it.
    /// - `table`: the table that was written to; `TableRow<T>` joins it to find the
    ///   existing row.
    /// - `cols`: `(column name, value)` pairs of the write. `TableRow<T>` keeps only the
    ///   unique-constraint columns among them.
    /// - `msg`: the error message text; `Conflict<T>` keeps it as its error source.
    fn from_conflict(
        txn: &rusqlite::Transaction<'_>,
        table: lower::JoinableTable,
        cols: Vec<(&'static str, OrdRc<rusqlite::types::Value>)>,
        msg: String,
    ) -> Self;
}
67
68impl FromConflict for Infallible {
69    fn from_conflict(
70        _txn: &rusqlite::Transaction<'_>,
71        _table: lower::JoinableTable,
72        _cols: Vec<(&'static str, OrdRc<rusqlite::types::Value>)>,
73        _msg: String,
74    ) -> Self {
75        unreachable!()
76    }
77}
78
79impl<T: Table<Conflict = Self>> FromConflict for Conflict<T> {
80    fn from_conflict(
81        _txn: &rusqlite::Transaction<'_>,
82        _table: lower::JoinableTable,
83        _cols: Vec<(&'static str, OrdRc<rusqlite::types::Value>)>,
84        msg: String,
85    ) -> Self {
86        Self {
87            _p: PhantomData,
88            msg: msg.into(),
89        }
90    }
91}
92
93pub(crate) fn get_unique_columns<T: Table<Conflict = TableRow<T>>>() -> Vec<Cow<'static, str>> {
94    // TODO: optimize to const
95    let schema = crate::schema::from_macro::Table::new::<T>();
96    let [index] = schema
97        .indices
98        .into_iter()
99        .filter(|x| x.def.unique)
100        .collect::<Vec<_>>()
101        .try_into()
102        .unwrap();
103    index.def.columns
104}
105
106impl<T: Table<Conflict = Self>> FromConflict for TableRow<T> {
107    fn from_conflict(
108        txn: &rusqlite::Transaction<'_>,
109        table: lower::JoinableTable,
110        mut cols: Vec<(&'static str, OrdRc<rusqlite::types::Value>)>,
111        _msg: String,
112    ) -> Self {
113        let unique_columns = get_unique_columns::<T>();
114
115        cols.retain(|(name, _val)| unique_columns.contains(&Cow::Borrowed(*name)));
116        assert_eq!(cols.len(), unique_columns.len());
117
118        let mut rows = lower::Rows::default();
119        let join = rows.join(table);
120
121        for (col, val) in cols {
122            let table_val = Rc::new(lower::Expr::RowIndex(
123                lower::RowLike::Join(join.clone()),
124                col,
125            ));
126            let val = Rc::new(lower::Expr::Parameter(val));
127            rows.filter(Rc::new(lower::Expr::Infix(val, "=", table_val)));
128        }
129
130        let mut selected = IndexMap::default();
131
132        let id = Rc::new(lower::Expr::RowIndex(lower::RowLike::Join(join), T::ID));
133        let (idx, _) = selected.insert_with(id, |_| ());
134
135        let mut stmt = emit::Stmt::default();
136        let forwarded = rows.emit(&mut stmt, false, &selected);
137        assert!(forwarded.is_empty());
138
139        let mut cached = txn.prepare_cached(&stmt.sql).unwrap();
140        cached
141            .query_one(rusqlite::params_from_iter(stmt.params), |row| {
142                Ok(Self {
143                    _local: PhantomData,
144                    inner: TableRowInner {
145                        _p: PhantomData,
146                        idx: row.get(&*format!("s{idx}"))?,
147                    },
148                })
149            })
150            .unwrap()
151    }
152}