Skip to main content

rustio_core/
orm.rs

//! SQLite-backed ORM.
//!
//! Implement [`Model`] on your struct to get `find / all / create / update /
//! delete` for free. SQLx is used internally; user code never references it.
//!
//! Phase 1 supports `i32`, `i64`, `String`, `bool`, and `DateTime<Utc>`
//! field types, each optionally wrapped in `Option` for nullable columns;
//! the id column is required to be `i64`.
8
9use std::str::FromStr;
10
11use chrono::{DateTime, Utc};
12use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions, SqliteRow};
13use sqlx::Row as _;
14
15use crate::error::Error;
16
17#[derive(Clone)]
18pub struct Db {
19    pool: SqlitePool,
20}
21
22impl Db {
23    /// Open a pool against the given SQLite URL.
24    ///
25    /// Foreign-key enforcement is **always on** (`PRAGMA foreign_keys = ON`
26    /// applied on every connection via sqlx's connect-time hook). SQLite
27    /// ignores FK constraints unless this pragma is set per-connection,
28    /// and relying on user configuration to enable it is unsafe —
29    /// `ON DELETE CASCADE` in the schema would silently do nothing.
30    pub async fn connect(url: &str) -> Result<Self, Error> {
31        let opts = SqliteConnectOptions::from_str(url)
32            .map_err(|e| Error::Internal(format!("invalid database URL: {e}")))?
33            .foreign_keys(true);
34        let pool = SqlitePoolOptions::new().connect_with(opts).await?;
35        Ok(Self { pool })
36    }
37
38    /// Open an in-memory pool with FK enforcement on.
39    ///
40    /// Limited to a single connection because each `:memory:` connection
41    /// opens its *own* database; multiple would not share rows.
42    pub async fn memory() -> Result<Self, Error> {
43        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
44            .map_err(|e| Error::Internal(format!("invalid database URL: {e}")))?
45            .foreign_keys(true);
46        let pool = SqlitePoolOptions::new()
47            .max_connections(1)
48            .connect_with(opts)
49            .await?;
50        Ok(Self { pool })
51    }
52
53    pub async fn execute(&self, sql: &str) -> Result<(), Error> {
54        sqlx::query(sql).execute(&self.pool).await?;
55        Ok(())
56    }
57
58    pub(crate) fn pool(&self) -> &SqlitePool {
59        &self.pool
60    }
61}
62
63pub struct Row<'a> {
64    inner: &'a SqliteRow,
65}
66
67impl<'a> Row<'a> {
68    pub(crate) fn new(inner: &'a SqliteRow) -> Self {
69        Self { inner }
70    }
71
72    pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
73        self.inner.try_get(name).map_err(Error::from)
74    }
75
76    pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
77        self.inner.try_get(name).map_err(Error::from)
78    }
79
80    pub fn get_string(&self, name: &str) -> Result<String, Error> {
81        self.inner.try_get(name).map_err(Error::from)
82    }
83
84    pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
85        self.inner.try_get(name).map_err(Error::from)
86    }
87
88    pub fn get_datetime(&self, name: &str) -> Result<DateTime<Utc>, Error> {
89        self.inner.try_get(name).map_err(Error::from)
90    }
91
92    // Nullable variants. Each returns `None` when the column is SQL NULL;
93    // any other decode failure still surfaces as `Error::Internal`.
94
95    pub fn get_optional_i32(&self, name: &str) -> Result<Option<i32>, Error> {
96        self.inner.try_get(name).map_err(Error::from)
97    }
98
99    pub fn get_optional_i64(&self, name: &str) -> Result<Option<i64>, Error> {
100        self.inner.try_get(name).map_err(Error::from)
101    }
102
103    pub fn get_optional_string(&self, name: &str) -> Result<Option<String>, Error> {
104        self.inner.try_get(name).map_err(Error::from)
105    }
106
107    pub fn get_optional_bool(&self, name: &str) -> Result<Option<bool>, Error> {
108        self.inner.try_get(name).map_err(Error::from)
109    }
110
111    pub fn get_optional_datetime(&self, name: &str) -> Result<Option<DateTime<Utc>>, Error> {
112        self.inner.try_get(name).map_err(Error::from)
113    }
114}
115
116/// A typed value ready to bind to a SQL placeholder.
117///
118/// `#[non_exhaustive]` because we expect to add variants (`Uuid`, `Json`,
119/// `Bytes`, `Decimal`) in later releases. The `bind_value` matcher below
120/// must be updated in lockstep with additions here.
121#[non_exhaustive]
122#[derive(Debug)]
123pub enum Value {
124    I32(i32),
125    I64(i64),
126    String(String),
127    Bool(bool),
128    DateTime(DateTime<Utc>),
129    /// NULL. Produced from `None` via the `From<Option<T>>` impls.
130    Null,
131}
132
133impl From<i32> for Value {
134    fn from(v: i32) -> Self {
135        Value::I32(v)
136    }
137}
138
139impl From<i64> for Value {
140    fn from(v: i64) -> Self {
141        Value::I64(v)
142    }
143}
144
145impl From<String> for Value {
146    fn from(v: String) -> Self {
147        Value::String(v)
148    }
149}
150
151impl From<&str> for Value {
152    fn from(v: &str) -> Self {
153        Value::String(v.to_owned())
154    }
155}
156
157impl From<bool> for Value {
158    fn from(v: bool) -> Self {
159        Value::Bool(v)
160    }
161}
162
163impl From<DateTime<Utc>> for Value {
164    fn from(v: DateTime<Utc>) -> Self {
165        Value::DateTime(v)
166    }
167}
168
169// Blanket `Option<T>` support: any type that converts into `Value` can
170// also be wrapped in `Option` for nullable columns. `None` becomes
171// `Value::Null`, `Some(x)` becomes whatever `x` converts to.
172impl<T> From<Option<T>> for Value
173where
174    T: Into<Value>,
175{
176    fn from(v: Option<T>) -> Self {
177        match v {
178            Some(inner) => inner.into(),
179            None => Value::Null,
180        }
181    }
182}
183
184pub trait Model: Sized + Send + Sync + Unpin + 'static {
185    const TABLE: &'static str;
186    const COLUMNS: &'static [&'static str];
187    const INSERT_COLUMNS: &'static [&'static str];
188
189    fn id(&self) -> i64;
190    fn from_row(row: Row<'_>) -> Result<Self, Error>;
191    fn insert_values(&self) -> Vec<Value>;
192
193    fn find(
194        db: &Db,
195        id: i64,
196    ) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
197    where
198        Self: Send,
199    {
200        async move {
201            let sql = format!(
202                "SELECT {} FROM {} WHERE id = ?",
203                Self::COLUMNS.join(", "),
204                Self::TABLE,
205            );
206            let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
207            match row {
208                Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
209                None => Ok(None),
210            }
211        }
212    }
213
214    fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
215        async move {
216            let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
217            let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
218            rows.iter().map(|r| Self::from_row(Row::new(r))).collect()
219        }
220    }
221
222    fn create<'a>(
223        &'a self,
224        db: &'a Db,
225    ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
226        async move {
227            let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
228            let sql = format!(
229                "INSERT INTO {} ({}) VALUES ({})",
230                Self::TABLE,
231                Self::INSERT_COLUMNS.join(", "),
232                placeholders,
233            );
234            let mut query = sqlx::query(&sql);
235            for v in self.insert_values() {
236                query = bind_value(query, v);
237            }
238            let result = query.execute(db.pool()).await?;
239            Ok(result.last_insert_rowid())
240        }
241    }
242
243    fn update<'a>(
244        &'a self,
245        db: &'a Db,
246    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
247        async move {
248            let assignments: Vec<String> = Self::INSERT_COLUMNS
249                .iter()
250                .map(|c| format!("{c} = ?"))
251                .collect();
252            let sql = format!(
253                "UPDATE {} SET {} WHERE id = ?",
254                Self::TABLE,
255                assignments.join(", "),
256            );
257            let mut query = sqlx::query(&sql);
258            for v in self.insert_values() {
259                query = bind_value(query, v);
260            }
261            query = query.bind(self.id());
262            query.execute(db.pool()).await?;
263            Ok(())
264        }
265    }
266
267    fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
268        async move {
269            let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
270            sqlx::query(&sql).bind(id).execute(db.pool()).await?;
271            Ok(())
272        }
273    }
274}
275
276fn bind_value<'q>(
277    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
278    value: Value,
279) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
280    match value {
281        Value::I32(v) => query.bind(v),
282        Value::I64(v) => query.bind(v),
283        Value::String(v) => query.bind(v),
284        Value::Bool(v) => query.bind(v),
285        Value::DateTime(v) => query.bind(v),
286        Value::Null => query.bind(Option::<i64>::None),
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone as _;

    // --- User: a model with only non-nullable columns -------------------

    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            let id = row.get_i64("id")?;
            let name = row.get_string("name")?;
            let is_admin = row.get_bool("is_admin")?;
            Ok(Self { id, name, is_admin })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![Value::from(self.name.clone()), Value::from(self.is_admin)]
        }
    }

    /// A user that has not been stored yet; `id` is a placeholder that
    /// `create` does not write.
    fn unsaved_user(name: &str, is_admin: bool) -> User {
        User {
            id: 0,
            name: name.to_owned(),
            is_admin,
        }
    }

    /// Fresh in-memory database containing an empty `users` table.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let id = unsaved_user("Alice", false).create(&db).await.unwrap();
        assert!(id >= 1);

        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(
            back,
            User {
                id,
                name: "Alice".into(),
                is_admin: false,
            }
        );
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        let missing = User::find(&db, 42).await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        for (name, is_admin) in [("a", false), ("b", true), ("c", false)] {
            unsaved_user(name, is_admin).create(&db).await.unwrap();
        }

        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);
        let names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, ["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = unsaved_user("old", false).create(&db).await.unwrap();

        let updated = User {
            id,
            ..unsaved_user("new", true)
        };
        updated.update(&db).await.unwrap();

        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = unsaved_user("x", false).create(&db).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_some());

        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        unsaved_user("a", false).create(&db).await.unwrap();

        let raw = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&raw);
        assert!(wrapped.get_i64("id").is_ok());
        // An unknown column must come back as an error, not a panic.
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }

    // --- Option<T> NULL ↔ None round-trip ------------------------------
    //
    // Exercises the full nullable path end to end: `Value::Null`, the
    // `get_optional_*` getters on `Row`, and the blanket
    // `From<Option<T>>` impl. All three have to agree in both
    // directions; if one drifts, `Some` and `None` start leaking into
    // each other silently and the admin forms misbehave.

    #[derive(Debug, PartialEq)]
    struct Event {
        id: i64,
        title: String,
        note: Option<String>,
        priority: Option<i32>,
        starts_at: Option<DateTime<Utc>>,
    }

    impl Model for Event {
        const TABLE: &'static str = "events";
        const COLUMNS: &'static [&'static str] = &["id", "title", "note", "priority", "starts_at"];
        const INSERT_COLUMNS: &'static [&'static str] = &["title", "note", "priority", "starts_at"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            let id = row.get_i64("id")?;
            let title = row.get_string("title")?;
            let note = row.get_optional_string("note")?;
            let priority = row.get_optional_i32("priority")?;
            let starts_at = row.get_optional_datetime("starts_at")?;
            Ok(Self {
                id,
                title,
                note,
                priority,
                starts_at,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![
                Value::from(self.title.clone()),
                Value::from(self.note.clone()),
                Value::from(self.priority),
                Value::from(self.starts_at),
            ]
        }
    }

    /// An event that has not been stored yet; `id` is a placeholder that
    /// `create` does not write.
    fn unsaved_event(
        title: &str,
        note: Option<&str>,
        priority: Option<i32>,
        starts_at: Option<DateTime<Utc>>,
    ) -> Event {
        Event {
            id: 0,
            title: title.to_owned(),
            note: note.map(str::to_owned),
            priority,
            starts_at,
        }
    }

    /// Fresh in-memory database containing an empty `events` table.
    async fn setup_events() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                note TEXT NULL,
                priority INTEGER NULL,
                starts_at TEXT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn option_none_round_trips_as_null() {
        let db = setup_events().await;
        let id = unsaved_event("empty", None, None, None)
            .create(&db)
            .await
            .unwrap();

        let back = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.note, None);
        assert_eq!(back.priority, None);
        assert_eq!(back.starts_at, None);

        // Check the stored cells themselves: they have to be real NULLs.
        // An empty string that reads back as Some("") would silently
        // break the admin's "no value" semantics.
        let row = sqlx::query(
            "SELECT note IS NULL AS note_is_null,
                    priority IS NULL AS priority_is_null,
                    starts_at IS NULL AS starts_is_null
             FROM events WHERE id = ?",
        )
        .bind(id)
        .fetch_one(db.pool())
        .await
        .unwrap();
        for column in 0..3_usize {
            assert_eq!(row.get::<i64, _>(column), 1);
        }
    }

    #[tokio::test]
    async fn option_some_round_trips_without_data_loss() {
        let db = setup_events().await;
        let when = Utc.with_ymd_and_hms(2026, 4, 18, 10, 12, 33).unwrap();
        let id = unsaved_event("full", Some("hello"), Some(7), Some(when))
            .create(&db)
            .await
            .unwrap();

        let back = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.note.as_deref(), Some("hello"));
        assert_eq!(back.priority, Some(7));
        assert_eq!(back.starts_at, Some(when));
    }

    #[tokio::test]
    async fn option_update_flips_null_to_some_and_back() {
        let db = setup_events().await;
        let id = unsaved_event("t", None, None, None)
            .create(&db)
            .await
            .unwrap();

        // NULL -> Some for two of the three nullable columns.
        Event {
            id,
            ..unsaved_event("t", Some("filled"), Some(1), None)
        }
        .update(&db)
        .await
        .unwrap();
        let mid = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(mid.note.as_deref(), Some("filled"));
        assert_eq!(mid.priority, Some(1));
        assert_eq!(mid.starts_at, None);

        // ...and Some -> NULL again.
        Event {
            id,
            ..unsaved_event("t", None, None, None)
        }
        .update(&db)
        .await
        .unwrap();
        let after = Event::find(&db, id).await.unwrap().unwrap();
        assert_eq!(after.note, None);
        assert_eq!(after.priority, None);
        assert_eq!(after.starts_at, None);
    }
}