Skip to main content

rustio_core/
orm.rs

1use sqlx::Row as _;
2use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
3
4use crate::error::Error;
5
6#[derive(Clone)]
7pub struct Db {
8    pool: SqlitePool,
9}
10
11impl Db {
12    pub async fn connect(url: &str) -> Result<Self, Error> {
13        let pool = SqlitePool::connect(url).await?;
14        Ok(Self { pool })
15    }
16
17    pub async fn memory() -> Result<Self, Error> {
18        let pool = SqlitePoolOptions::new()
19            .max_connections(1)
20            .connect("sqlite::memory:")
21            .await?;
22        Ok(Self { pool })
23    }
24
25    pub async fn execute(&self, sql: &str) -> Result<(), Error> {
26        sqlx::query(sql).execute(&self.pool).await?;
27        Ok(())
28    }
29
30    pub(crate) fn pool(&self) -> &SqlitePool {
31        &self.pool
32    }
33}
34
35pub struct Row<'a> {
36    inner: &'a SqliteRow,
37}
38
39impl<'a> Row<'a> {
40    pub(crate) fn new(inner: &'a SqliteRow) -> Self {
41        Self { inner }
42    }
43
44    pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
45        self.inner.try_get(name).map_err(Error::from)
46    }
47
48    pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
49        self.inner.try_get(name).map_err(Error::from)
50    }
51
52    pub fn get_string(&self, name: &str) -> Result<String, Error> {
53        self.inner.try_get(name).map_err(Error::from)
54    }
55
56    pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
57        self.inner.try_get(name).map_err(Error::from)
58    }
59}
60
/// A dynamically-typed SQL parameter, as produced by [`Model::insert_values`].
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// A 32-bit integer.
    I32(i32),
    /// A 64-bit integer.
    I64(i64),
    /// Text.
    String(String),
    /// A boolean.
    Bool(bool),
    /// SQL `NULL`.
    Null,
}

impl From<i32> for Value {
    fn from(v: i32) -> Self {
        Value::I32(v)
    }
}

impl From<i64> for Value {
    fn from(v: i64) -> Self {
        Value::I64(v)
    }
}

impl From<String> for Value {
    fn from(v: String) -> Self {
        Value::String(v)
    }
}

impl From<&str> for Value {
    fn from(v: &str) -> Self {
        Value::String(v.to_owned())
    }
}

impl From<bool> for Value {
    fn from(v: bool) -> Self {
        Value::Bool(v)
    }
}

/// Lets nullable model fields convert directly.
///
/// `Some(v)` converts exactly as `v` would, and `None` becomes [`Value::Null`].
impl<T: Into<Value>> From<Option<T>> for Value {
    fn from(v: Option<T>) -> Self {
        v.map_or(Value::Null, Into::into)
    }
}
99
100pub trait Model: Sized + Send + Sync + Unpin + 'static {
101    const TABLE: &'static str;
102    const COLUMNS: &'static [&'static str];
103    const INSERT_COLUMNS: &'static [&'static str];
104
105    fn id(&self) -> i64;
106    fn from_row(row: Row<'_>) -> Result<Self, Error>;
107    fn insert_values(&self) -> Vec<Value>;
108
109    fn find(db: &Db, id: i64) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
110    where
111        Self: Send,
112    {
113        async move {
114            let sql = format!(
115                "SELECT {} FROM {} WHERE id = ?",
116                Self::COLUMNS.join(", "),
117                Self::TABLE,
118            );
119            let row = sqlx::query(&sql)
120                .bind(id)
121                .fetch_optional(db.pool())
122                .await?;
123            match row {
124                Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
125                None => Ok(None),
126            }
127        }
128    }
129
130    fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
131        async move {
132            let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
133            let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
134            rows.iter()
135                .map(|r| Self::from_row(Row::new(r)))
136                .collect()
137        }
138    }
139
140    fn create<'a>(
141        &'a self,
142        db: &'a Db,
143    ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
144        async move {
145            let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
146            let sql = format!(
147                "INSERT INTO {} ({}) VALUES ({})",
148                Self::TABLE,
149                Self::INSERT_COLUMNS.join(", "),
150                placeholders,
151            );
152            let mut query = sqlx::query(&sql);
153            for v in self.insert_values() {
154                query = bind_value(query, v);
155            }
156            let result = query.execute(db.pool()).await?;
157            Ok(result.last_insert_rowid())
158        }
159    }
160
161    fn update<'a>(
162        &'a self,
163        db: &'a Db,
164    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
165        async move {
166            let assignments: Vec<String> = Self::INSERT_COLUMNS
167                .iter()
168                .map(|c| format!("{c} = ?"))
169                .collect();
170            let sql = format!(
171                "UPDATE {} SET {} WHERE id = ?",
172                Self::TABLE,
173                assignments.join(", "),
174            );
175            let mut query = sqlx::query(&sql);
176            for v in self.insert_values() {
177                query = bind_value(query, v);
178            }
179            query = query.bind(self.id());
180            query.execute(db.pool()).await?;
181            Ok(())
182        }
183    }
184
185    fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
186        async move {
187            let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
188            sqlx::query(&sql).bind(id).execute(db.pool()).await?;
189            Ok(())
190        }
191    }
192}
193
194fn bind_value<'q>(
195    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
196    value: Value,
197) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
198    match value {
199        Value::I32(v) => query.bind(v),
200        Value::I64(v) => query.bind(v),
201        Value::String(v) => query.bind(v),
202        Value::Bool(v) => query.bind(v),
203        Value::Null => query.bind(Option::<i64>::None),
204    }
205}
206
#[cfg(test)]
mod tests {
    //! End-to-end tests for the ORM helpers. Each test runs against its own fresh
    //! in-memory database created by [`setup`].

    use super::*;

    /// Minimal model mapped onto the `users` table created by [`setup`].
    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                name: row.get_string("name")?,
                is_admin: row.get_bool("is_admin")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![self.name.clone().into(), self.is_admin.into()]
        }
    }

    /// Creates an in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let u = User {
            id: 0,
            name: "Alice".into(),
            is_admin: false,
        };
        let id = u.create(&db).await.unwrap();
        assert!(id >= 1);
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "Alice");
        assert!(!back.is_admin);
        assert_eq!(back.id, id);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        assert!(User::find(&db, 42).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        User { id: 0, name: "a".into(), is_admin: false }.create(&db).await.unwrap();
        User { id: 0, name: "b".into(), is_admin: true }.create(&db).await.unwrap();
        User { id: 0, name: "c".into(), is_admin: false }.create(&db).await.unwrap();
        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);
        let names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = User { id: 0, name: "old".into(), is_admin: false }
            .create(&db)
            .await
            .unwrap();
        let updated = User { id, name: "new".into(), is_admin: true };
        updated.update(&db).await.unwrap();
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = User { id: 0, name: "x".into(), is_admin: false }
            .create(&db)
            .await
            .unwrap();
        assert!(User::find(&db, id).await.unwrap().is_some());
        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        User { id: 0, name: "a".into(), is_admin: false }
            .create(&db)
            .await
            .unwrap();
        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);
        assert!(wrapped.get_i64("id").is_ok());
        // A TEXT value must be reported as an error by an integer getter, not coerced.
        assert!(wrapped.get_i64("name").is_err());
        // An unknown column is an error too, not a panic.
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }
}