Skip to main content

rustio_admin/
orm.rs

//! PostgreSQL-backed ORM: a thin shim over sqlx (`Db`, `Model`, `Value`,
//! `Row`) plus a handful of generic CRUD helpers. A hand-written
//! `impl Model` is the contract; users keep writing SQL where it matters.
4
5use std::time::Duration;
6
7use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
8use rust_decimal::Decimal;
9use serde_json::Value as JsonValue;
10use sqlx::postgres::{PgPoolOptions, PgRow};
11use sqlx::{PgPool, Row as SqlxRow};
12use uuid::Uuid;
13
14use crate::error::{Error, Result};
15
16// public:
17/// Shared handle to the database. Cheap to clone; every handler gets
18/// its own clone.
19#[derive(Clone)]
20pub struct Db {
21    pool: PgPool,
22}
23
24impl Db {
25    // public:
26    /// Connect with sensible production defaults: 30 max connections,
27    /// 1s acquire timeout, 5min idle timeout.
28    pub async fn connect(url: &str) -> Result<Self> {
29        Self::connect_with(url, DbOptions::default()).await
30    }
31
32    // public:
33    pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
34        let pool = PgPoolOptions::new()
35            .max_connections(opts.max_connections)
36            .min_connections(opts.min_connections)
37            .acquire_timeout(opts.acquire_timeout)
38            .idle_timeout(Some(opts.idle_timeout))
39            .max_lifetime(Some(opts.max_lifetime))
40            // Suppress benign Postgres NOTICE chatter so harmless
41            // `CREATE … IF NOT EXISTS` and `ADD COLUMN … IF NOT EXISTS`
42            // re-runs don't blast 60+ "relation already exists, skipping"
43            // lines into the operator's log on every boot. WARNING +
44            // ERROR + LOG levels still surface — only the chatty NOTICE
45            // band is silenced. Set per-connection so reconnects in the
46            // pool stay quiet too.
47            .after_connect(|conn, _meta| {
48                Box::pin(async move {
49                    use sqlx::Executor;
50                    conn.execute("SET client_min_messages = warning")
51                        .await
52                        .map(|_| ())
53                })
54            })
55            .connect(url)
56            .await
57            .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
58        Ok(Self { pool })
59    }
60
61    // public:
62    pub fn pool(&self) -> &PgPool {
63        &self.pool
64    }
65
66    // public:
67    pub async fn health_check(&self) -> Result<()> {
68        sqlx::query("SELECT 1")
69            .fetch_one(&self.pool)
70            .await
71            .map(|_| ())
72            .map_err(|e| Error::Internal(format!("health check: {e}")))
73    }
74
75    /// Test-only constructor. Builds a `Db` whose pool is real but
76    /// never opens a TCP connection — `connect_lazy_with` defers that
77    /// until `.acquire()` is called. Tests that don't actually hit the
78    /// database can use this.
79    #[cfg(test)]
80    #[allow(dead_code)]
81    pub(crate) fn for_testing_no_connection() -> Self {
82        let pool = PgPoolOptions::new()
83            .max_connections(1)
84            .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
85            .expect("connect_lazy never fails on a syntactically valid URL");
86        Self { pool }
87    }
88}
89
// public:
/// Pool tuning knobs consumed by `Db::connect_with`.
#[derive(Clone, Debug)]
pub struct DbOptions {
    /// Forwarded to `PgPoolOptions::max_connections`.
    pub max_connections: u32,
    /// Forwarded to `PgPoolOptions::min_connections`.
    pub min_connections: u32,
    /// Forwarded to `PgPoolOptions::acquire_timeout`.
    pub acquire_timeout: Duration,
    /// Forwarded (wrapped in `Some`) to `PgPoolOptions::idle_timeout`.
    pub idle_timeout: Duration,
    /// Forwarded (wrapped in `Some`) to `PgPoolOptions::max_lifetime`.
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// Production defaults: 30 max / 2 min connections, 1s acquire timeout,
    /// 5 minute idle timeout, 30 minute max connection lifetime.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        DbOptions {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout: Duration::from_secs(1),
            idle_timeout: Duration::from_secs(5 * MINUTE_SECS),
            max_lifetime: Duration::from_secs(30 * MINUTE_SECS),
        }
    }
}
111
112// public:
113/// The value types the framework understands. Kept small on purpose.
114#[derive(Debug, Clone)]
115pub enum Value {
116    Null,
117    I32(i32),
118    I64(i64),
119    F64(f64),
120    Decimal(Decimal),
121    Bool(bool),
122    Text(String),
123    DateTime(DateTime<Utc>),
124    Date(NaiveDate),
125    Time(NaiveTime),
126    Uuid(Uuid),
127    Json(JsonValue),
128}
129
130impl From<i32> for Value {
131    fn from(v: i32) -> Self {
132        Value::I32(v)
133    }
134}
135impl From<i64> for Value {
136    fn from(v: i64) -> Self {
137        Value::I64(v)
138    }
139}
140impl From<f64> for Value {
141    fn from(v: f64) -> Self {
142        Value::F64(v)
143    }
144}
145impl From<Decimal> for Value {
146    fn from(v: Decimal) -> Self {
147        Value::Decimal(v)
148    }
149}
150impl From<bool> for Value {
151    fn from(v: bool) -> Self {
152        Value::Bool(v)
153    }
154}
155impl From<String> for Value {
156    fn from(v: String) -> Self {
157        Value::Text(v)
158    }
159}
160impl<'a> From<&'a str> for Value {
161    fn from(v: &'a str) -> Self {
162        Value::Text(v.to_string())
163    }
164}
165impl From<DateTime<Utc>> for Value {
166    fn from(v: DateTime<Utc>) -> Self {
167        Value::DateTime(v)
168    }
169}
170impl From<NaiveDate> for Value {
171    fn from(v: NaiveDate) -> Self {
172        Value::Date(v)
173    }
174}
175impl From<NaiveTime> for Value {
176    fn from(v: NaiveTime) -> Self {
177        Value::Time(v)
178    }
179}
180impl From<Uuid> for Value {
181    fn from(v: Uuid) -> Self {
182        Value::Uuid(v)
183    }
184}
185impl From<JsonValue> for Value {
186    fn from(v: JsonValue) -> Self {
187        Value::Json(v)
188    }
189}
190impl<T: Into<Value>> From<Option<T>> for Value {
191    fn from(v: Option<T>) -> Self {
192        match v {
193            Some(v) => v.into(),
194            None => Value::Null,
195        }
196    }
197}
198
199// public:
200pub struct Row<'a> {
201    inner: &'a PgRow,
202}
203
204impl<'a> Row<'a> {
205    // public:
206    pub fn from_pg(row: &'a PgRow) -> Self {
207        Self { inner: row }
208    }
209
210    // public:
211    pub fn get_i32(&self, col: &str) -> Result<i32> {
212        self.inner
213            .try_get::<i32, _>(col)
214            .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
215    }
216    // public:
217    pub fn get_i64(&self, col: &str) -> Result<i64> {
218        self.inner
219            .try_get::<i64, _>(col)
220            .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
221    }
222    // public:
223    pub fn get_f64(&self, col: &str) -> Result<f64> {
224        self.inner
225            .try_get::<f64, _>(col)
226            .map_err(|e| Error::Internal(format!("get_f64({col}): {e}")))
227    }
228    // public:
229    pub fn get_decimal(&self, col: &str) -> Result<Decimal> {
230        self.inner
231            .try_get::<Decimal, _>(col)
232            .map_err(|e| Error::Internal(format!("get_decimal({col}): {e}")))
233    }
234    // public:
235    pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
236        self.inner
237            .try_get::<Option<i64>, _>(col)
238            .map_err(|e| Error::Internal(format!("{col}: {e}")))
239    }
240    // public:
241    pub fn get_bool(&self, col: &str) -> Result<bool> {
242        self.inner
243            .try_get::<bool, _>(col)
244            .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
245    }
246    // public:
247    pub fn get_string(&self, col: &str) -> Result<String> {
248        self.inner
249            .try_get::<String, _>(col)
250            .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
251    }
252    // public:
253    pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
254        self.inner
255            .try_get::<Option<String>, _>(col)
256            .map_err(|e| Error::Internal(format!("{col}: {e}")))
257    }
258    // public:
259    pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
260        self.inner
261            .try_get::<DateTime<Utc>, _>(col)
262            .map_err(|e| Error::Internal(format!("{col}: {e}")))
263    }
264    // public:
265    pub fn get_date(&self, col: &str) -> Result<NaiveDate> {
266        self.inner
267            .try_get::<NaiveDate, _>(col)
268            .map_err(|e| Error::Internal(format!("get_date({col}): {e}")))
269    }
270    // public:
271    pub fn get_time(&self, col: &str) -> Result<NaiveTime> {
272        self.inner
273            .try_get::<NaiveTime, _>(col)
274            .map_err(|e| Error::Internal(format!("get_time({col}): {e}")))
275    }
276    // public:
277    pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
278        self.inner
279            .try_get::<Option<DateTime<Utc>>, _>(col)
280            .map_err(|e| Error::Internal(format!("{col}: {e}")))
281    }
282    // public:
283    pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
284        self.inner
285            .try_get::<Uuid, _>(col)
286            .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
287    }
288    // public:
289    pub fn get_json(&self, col: &str) -> Result<JsonValue> {
290        self.inner
291            .try_get::<JsonValue, _>(col)
292            .map_err(|e| Error::Internal(format!("{col}: {e}")))
293    }
294}
295
296// public:
297pub trait Model: Send + Sync + Sized + 'static {
298    const TABLE: &'static str;
299    const COLUMNS: &'static [&'static str];
300    const INSERT_COLUMNS: &'static [&'static str];
301
302    fn id(&self) -> i64;
303    fn from_row(row: Row<'_>) -> Result<Self>;
304    fn insert_values(&self) -> Vec<Value>;
305}
306
// ---- Generic CRUD helpers -----------------------------------------------
308
309// public:
310pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
311    let sql = format!(
312        "SELECT {} FROM {} ORDER BY id DESC",
313        M::COLUMNS.join(", "),
314        M::TABLE
315    );
316    let rows = sqlx::query(sqlx::AssertSqlSafe(sql))
317        .fetch_all(db.pool())
318        .await?;
319    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
320}
321
322// public:
323pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
324    let sql = format!(
325        "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
326        M::COLUMNS.join(", "),
327        M::TABLE
328    );
329    let rows = sqlx::query(sqlx::AssertSqlSafe(sql))
330        .bind(limit)
331        .bind(offset)
332        .fetch_all(db.pool())
333        .await?;
334    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
335}
336
337// public:
338pub async fn count<M: Model>(db: &Db) -> Result<i64> {
339    let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
340    let row = sqlx::query(sqlx::AssertSqlSafe(sql))
341        .fetch_one(db.pool())
342        .await?;
343    row.try_get::<i64, _>("c")
344        .map_err(|e| Error::Internal(format!("count: {e}")))
345}
346
347// public:
348pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
349    let sql = format!(
350        "SELECT {} FROM {} WHERE id = $1",
351        M::COLUMNS.join(", "),
352        M::TABLE
353    );
354    let row = sqlx::query(sqlx::AssertSqlSafe(sql))
355        .bind(id)
356        .fetch_optional(db.pool())
357        .await?;
358    match row {
359        Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
360        None => Ok(None),
361    }
362}
363
364// public:
365pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
366    let cols = M::INSERT_COLUMNS.join(", ");
367    let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
368        .map(|i| format!("${i}"))
369        .collect();
370    let sql = format!(
371        "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
372        M::TABLE,
373        cols,
374        placeholders.join(", ")
375    );
376    let mut query = sqlx::query(sqlx::AssertSqlSafe(sql));
377    for value in model.insert_values() {
378        query = bind_value(query, value);
379    }
380    let row = query.fetch_one(db.pool()).await?;
381    let id: i64 = row
382        .try_get("id")
383        .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
384    Ok(id)
385}
386
387// public:
388pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
389    let sets: Vec<String> = M::INSERT_COLUMNS
390        .iter()
391        .enumerate()
392        .map(|(i, col)| format!("{col} = ${}", i + 1))
393        .collect();
394    let sql = format!(
395        "UPDATE {} SET {} WHERE id = ${}",
396        M::TABLE,
397        sets.join(", "),
398        M::INSERT_COLUMNS.len() + 1
399    );
400    let mut query = sqlx::query(sqlx::AssertSqlSafe(sql));
401    for value in model.insert_values() {
402        query = bind_value(query, value);
403    }
404    query = query.bind(id);
405    query.execute(db.pool()).await?;
406    Ok(())
407}
408
409// public:
410pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
411    let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
412    sqlx::query(sqlx::AssertSqlSafe(sql))
413        .bind(id)
414        .execute(db.pool())
415        .await?;
416    Ok(())
417}
418
419fn bind_value<'a>(
420    q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
421    v: Value,
422) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
423    match v {
424        Value::Null => q.bind(None::<i64>),
425        Value::I32(n) => q.bind(n),
426        Value::I64(n) => q.bind(n),
427        Value::F64(n) => q.bind(n),
428        Value::Decimal(d) => q.bind(d),
429        Value::Bool(b) => q.bind(b),
430        Value::Text(s) => q.bind(s),
431        Value::DateTime(d) => q.bind(d),
432        Value::Date(d) => q.bind(d),
433        Value::Time(t) => q.bind(t),
434        Value::Uuid(u) => q.bind(u),
435        Value::Json(j) => q.bind(j),
436    }
437}
438
#[cfg(test)]
mod value_conversion_tests {
    //! Pins every `From<_> for Value` impl to the variant it must produce.
    use super::Value;
    use chrono::{NaiveDate, NaiveTime};
    use rust_decimal::Decimal;
    use std::str::FromStr;
    use uuid::Uuid;

    #[test]
    fn scalar_from_impls_map_to_their_variants() {
        assert!(matches!(Value::from(3.5_f64), Value::F64(v) if v == 3.5));
        let dec = Decimal::from_str("19.99").unwrap();
        assert!(matches!(Value::from(dec), Value::Decimal(v) if v == dec));
        let d = NaiveDate::from_ymd_opt(2026, 6, 2).unwrap();
        assert!(matches!(Value::from(d), Value::Date(v) if v == d));
        let t = NaiveTime::from_hms_opt(9, 30, 0).unwrap();
        assert!(matches!(Value::from(t), Value::Time(v) if v == t));
        let u = Uuid::from_u128(0x550e8400_e29b_41d4_a716_446655440000);
        assert!(matches!(Value::from(u), Value::Uuid(v) if v == u));
    }

    #[test]
    fn integer_bool_and_text_from_impls_map_to_their_variants() {
        assert!(matches!(Value::from(7_i32), Value::I32(7)));
        assert!(matches!(Value::from(-7_i64), Value::I64(-7)));
        assert!(matches!(Value::from(true), Value::Bool(true)));
        assert!(matches!(Value::from("hi"), Value::Text(s) if s == "hi"));
        assert!(matches!(Value::from(String::from("hi")), Value::Text(s) if s == "hi"));
    }

    #[test]
    fn option_from_impl_maps_none_to_null_and_some_to_inner_variant() {
        assert!(matches!(Value::from(None::<i64>), Value::Null));
        assert!(matches!(Value::from(None::<&str>), Value::Null));
        assert!(matches!(Value::from(Some(5_i32)), Value::I32(5)));
        assert!(matches!(Value::from(Some("x")), Value::Text(s) if s == "x"));
    }

    #[test]
    fn json_from_impl_keeps_the_document() {
        let doc = serde_json::json!({ "k": [1, 2] });
        assert!(matches!(Value::from(doc.clone()), Value::Json(v) if v == doc));
    }
}