Skip to main content

rustio_core/
orm.rs

1//! SQLite-backed ORM.
2//!
3//! Implement [`Model`] on your struct to get `find / all / create / update /
4//! delete` for free. SQLx is used internally; user code never references it.
5//!
6//! Phase 1 supports `i32`, `i64`, `String`, and `bool` field types; the id
7//! column is required to be `i64`.
8
9use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
10use sqlx::Row as _;
11
12use crate::error::Error;
13
14#[derive(Clone)]
15pub struct Db {
16    pool: SqlitePool,
17}
18
19impl Db {
20    pub async fn connect(url: &str) -> Result<Self, Error> {
21        let pool = SqlitePool::connect(url).await?;
22        Ok(Self { pool })
23    }
24
25    pub async fn memory() -> Result<Self, Error> {
26        let pool = SqlitePoolOptions::new()
27            .max_connections(1)
28            .connect("sqlite::memory:")
29            .await?;
30        Ok(Self { pool })
31    }
32
33    pub async fn execute(&self, sql: &str) -> Result<(), Error> {
34        sqlx::query(sql).execute(&self.pool).await?;
35        Ok(())
36    }
37
38    pub(crate) fn pool(&self) -> &SqlitePool {
39        &self.pool
40    }
41}
42
43pub struct Row<'a> {
44    inner: &'a SqliteRow,
45}
46
47impl<'a> Row<'a> {
48    pub(crate) fn new(inner: &'a SqliteRow) -> Self {
49        Self { inner }
50    }
51
52    pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
53        self.inner.try_get(name).map_err(Error::from)
54    }
55
56    pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
57        self.inner.try_get(name).map_err(Error::from)
58    }
59
60    pub fn get_string(&self, name: &str) -> Result<String, Error> {
61        self.inner.try_get(name).map_err(Error::from)
62    }
63
64    pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
65        self.inner.try_get(name).map_err(Error::from)
66    }
67}
68
/// A parameter bound into a generated `INSERT` or `UPDATE` statement.
///
/// Returned by `Model::insert_values`, typically built through `Value`'s `From` conversions.
/// `Clone` and `PartialEq` let callers duplicate and compare bound values (e.g. when testing an
/// `insert_values` implementation).
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit signed integer.
    I64(i64),
    /// Owned text.
    String(String),
    /// Boolean.
    Bool(bool),
    /// SQL `NULL`.
    Null,
}
77
78impl From<i32> for Value {
79    fn from(v: i32) -> Self {
80        Value::I32(v)
81    }
82}
83
84impl From<i64> for Value {
85    fn from(v: i64) -> Self {
86        Value::I64(v)
87    }
88}
89
90impl From<String> for Value {
91    fn from(v: String) -> Self {
92        Value::String(v)
93    }
94}
95
96impl From<&str> for Value {
97    fn from(v: &str) -> Self {
98        Value::String(v.to_owned())
99    }
100}
101
102impl From<bool> for Value {
103    fn from(v: bool) -> Self {
104        Value::Bool(v)
105    }
106}
107
108pub trait Model: Sized + Send + Sync + Unpin + 'static {
109    const TABLE: &'static str;
110    const COLUMNS: &'static [&'static str];
111    const INSERT_COLUMNS: &'static [&'static str];
112
113    fn id(&self) -> i64;
114    fn from_row(row: Row<'_>) -> Result<Self, Error>;
115    fn insert_values(&self) -> Vec<Value>;
116
117    fn find(
118        db: &Db,
119        id: i64,
120    ) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
121    where
122        Self: Send,
123    {
124        async move {
125            let sql = format!(
126                "SELECT {} FROM {} WHERE id = ?",
127                Self::COLUMNS.join(", "),
128                Self::TABLE,
129            );
130            let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
131            match row {
132                Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
133                None => Ok(None),
134            }
135        }
136    }
137
138    fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
139        async move {
140            let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
141            let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
142            rows.iter().map(|r| Self::from_row(Row::new(r))).collect()
143        }
144    }
145
146    fn create<'a>(
147        &'a self,
148        db: &'a Db,
149    ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
150        async move {
151            let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
152            let sql = format!(
153                "INSERT INTO {} ({}) VALUES ({})",
154                Self::TABLE,
155                Self::INSERT_COLUMNS.join(", "),
156                placeholders,
157            );
158            let mut query = sqlx::query(&sql);
159            for v in self.insert_values() {
160                query = bind_value(query, v);
161            }
162            let result = query.execute(db.pool()).await?;
163            Ok(result.last_insert_rowid())
164        }
165    }
166
167    fn update<'a>(
168        &'a self,
169        db: &'a Db,
170    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
171        async move {
172            let assignments: Vec<String> = Self::INSERT_COLUMNS
173                .iter()
174                .map(|c| format!("{c} = ?"))
175                .collect();
176            let sql = format!(
177                "UPDATE {} SET {} WHERE id = ?",
178                Self::TABLE,
179                assignments.join(", "),
180            );
181            let mut query = sqlx::query(&sql);
182            for v in self.insert_values() {
183                query = bind_value(query, v);
184            }
185            query = query.bind(self.id());
186            query.execute(db.pool()).await?;
187            Ok(())
188        }
189    }
190
191    fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
192        async move {
193            let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
194            sqlx::query(&sql).bind(id).execute(db.pool()).await?;
195            Ok(())
196        }
197    }
198}
199
200fn bind_value<'q>(
201    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
202    value: Value,
203) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
204    match value {
205        Value::I32(v) => query.bind(v),
206        Value::I64(v) => query.bind(v),
207        Value::String(v) => query.bind(v),
208        Value::Bool(v) => query.bind(v),
209        Value::Null => query.bind(Option::<i64>::None),
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                name: row.get_string("name")?,
                is_admin: row.get_bool("is_admin")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![self.name.clone().into(), self.is_admin.into()]
        }
    }

    /// Fresh in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    /// Inserts a user through `Model::create` and returns the id SQLite assigned to it.
    async fn insert_user(db: &Db, name: &str, is_admin: bool) -> i64 {
        User {
            id: 0,
            name: name.into(),
            is_admin,
        }
        .create(db)
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let id = insert_user(&db, "Alice", false).await;
        assert!(id >= 1);
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "Alice");
        assert!(!back.is_admin);
        assert_eq!(back.id, id);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        assert!(User::find(&db, 42).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        insert_user(&db, "a", false).await;
        insert_user(&db, "b", true).await;
        insert_user(&db, "c", false).await;
        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);
        let names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = insert_user(&db, "old", false).await;
        let updated = User {
            id,
            name: "new".into(),
            is_admin: true,
        };
        updated.update(&db).await.unwrap();
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = insert_user(&db, "x", false).await;
        assert!(User::find(&db, id).await.unwrap().is_some());
        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        insert_user(&db, "a", false).await;
        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);
        assert!(wrapped.get_i64("id").is_ok());
        // A TEXT value requested as an integer must surface as an `Err`, not a panic.
        assert!(wrapped.get_i64("name").is_err());
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }
}