Skip to main content

rustio_admin/
orm.rs

1//! PostgreSQL-backed ORM. A thin shim over sqlx — `Db`, `Model`,
2//! `Value`, `Row` — and a handful of generic CRUD helpers. Hand-written
3//! `impl Model` is the contract; users keep writing SQL where it matters.
4
5use std::time::Duration;
6
7use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
8use rust_decimal::Decimal;
9use serde_json::Value as JsonValue;
10use sqlx::postgres::{PgPoolOptions, PgRow};
11use sqlx::{PgPool, Row as SqlxRow};
12use uuid::Uuid;
13
14use crate::error::{Error, Result};
15
16// public:
17/// Shared handle to the database. Cheap to clone; every handler gets
18/// its own clone.
19#[derive(Clone)]
20pub struct Db {
21    pool: PgPool,
22}
23
24impl Db {
25    // public:
26    /// Connect with sensible production defaults: 30 max connections,
27    /// 1s acquire timeout, 5min idle timeout.
28    pub async fn connect(url: &str) -> Result<Self> {
29        Self::connect_with(url, DbOptions::default()).await
30    }
31
32    // public:
33    pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
34        let pool = PgPoolOptions::new()
35            .max_connections(opts.max_connections)
36            .min_connections(opts.min_connections)
37            .acquire_timeout(opts.acquire_timeout)
38            .idle_timeout(Some(opts.idle_timeout))
39            .max_lifetime(Some(opts.max_lifetime))
40            // Suppress benign Postgres NOTICE chatter so harmless
41            // `CREATE … IF NOT EXISTS` and `ADD COLUMN … IF NOT EXISTS`
42            // re-runs don't blast 60+ "relation already exists, skipping"
43            // lines into the operator's log on every boot. WARNING +
44            // ERROR + LOG levels still surface — only the chatty NOTICE
45            // band is silenced. Set per-connection so reconnects in the
46            // pool stay quiet too.
47            .after_connect(|conn, _meta| {
48                Box::pin(async move {
49                    use sqlx::Executor;
50                    conn.execute("SET client_min_messages = warning")
51                        .await
52                        .map(|_| ())
53                })
54            })
55            .connect(url)
56            .await
57            .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
58        Ok(Self { pool })
59    }
60
61    // public:
62    pub fn pool(&self) -> &PgPool {
63        &self.pool
64    }
65
66    // public:
67    pub async fn health_check(&self) -> Result<()> {
68        sqlx::query("SELECT 1")
69            .fetch_one(&self.pool)
70            .await
71            .map(|_| ())
72            .map_err(|e| Error::Internal(format!("health check: {e}")))
73    }
74
75    /// Test-only constructor. Builds a `Db` whose pool is real but
76    /// never opens a TCP connection — `connect_lazy_with` defers that
77    /// until `.acquire()` is called. Tests that don't actually hit the
78    /// database can use this.
79    #[cfg(test)]
80    #[allow(dead_code)]
81    pub(crate) fn for_testing_no_connection() -> Self {
82        let pool = PgPoolOptions::new()
83            .max_connections(1)
84            .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
85            .expect("connect_lazy never fails on a syntactically valid URL");
86        Self { pool }
87    }
88}
89
90// public:
/// Pool sizing and timeout knobs consumed by `Db::connect_with`.
///
/// Each field is forwarded to the `PgPoolOptions` builder method of the
/// same name.
#[derive(Clone, Debug)]
pub struct DbOptions {
    /// Upper bound on the number of pooled connections.
    pub max_connections: u32,
    /// Lower bound on the number of pooled connections.
    pub min_connections: u32,
    /// Timeout passed to the pool's `acquire_timeout` setting.
    pub acquire_timeout: Duration,
    /// Timeout passed to the pool's `idle_timeout` setting.
    pub idle_timeout: Duration,
    /// Duration passed to the pool's `max_lifetime` setting.
    pub max_lifetime: Duration,
}
99
100impl Default for DbOptions {
101    fn default() -> Self {
102        Self {
103            max_connections: 30,
104            min_connections: 2,
105            acquire_timeout: Duration::from_secs(1),
106            idle_timeout: Duration::from_secs(300),
107            max_lifetime: Duration::from_secs(1800),
108        }
109    }
110}
111
112// public:
113/// The value types the framework understands. Kept small on purpose.
114#[derive(Debug, Clone)]
115pub enum Value {
116    Null,
117    I32(i32),
118    I64(i64),
119    F64(f64),
120    Decimal(Decimal),
121    Bool(bool),
122    Text(String),
123    DateTime(DateTime<Utc>),
124    Date(NaiveDate),
125    Time(NaiveTime),
126    Uuid(Uuid),
127    Json(JsonValue),
128}
129
130impl From<i32> for Value {
131    fn from(v: i32) -> Self {
132        Value::I32(v)
133    }
134}
135impl From<i64> for Value {
136    fn from(v: i64) -> Self {
137        Value::I64(v)
138    }
139}
140impl From<f64> for Value {
141    fn from(v: f64) -> Self {
142        Value::F64(v)
143    }
144}
145impl From<Decimal> for Value {
146    fn from(v: Decimal) -> Self {
147        Value::Decimal(v)
148    }
149}
150impl From<bool> for Value {
151    fn from(v: bool) -> Self {
152        Value::Bool(v)
153    }
154}
155impl From<String> for Value {
156    fn from(v: String) -> Self {
157        Value::Text(v)
158    }
159}
160impl<'a> From<&'a str> for Value {
161    fn from(v: &'a str) -> Self {
162        Value::Text(v.to_string())
163    }
164}
165impl From<DateTime<Utc>> for Value {
166    fn from(v: DateTime<Utc>) -> Self {
167        Value::DateTime(v)
168    }
169}
170impl From<NaiveDate> for Value {
171    fn from(v: NaiveDate) -> Self {
172        Value::Date(v)
173    }
174}
175impl From<NaiveTime> for Value {
176    fn from(v: NaiveTime) -> Self {
177        Value::Time(v)
178    }
179}
180impl From<Uuid> for Value {
181    fn from(v: Uuid) -> Self {
182        Value::Uuid(v)
183    }
184}
185impl From<JsonValue> for Value {
186    fn from(v: JsonValue) -> Self {
187        Value::Json(v)
188    }
189}
190impl<T: Into<Value>> From<Option<T>> for Value {
191    fn from(v: Option<T>) -> Self {
192        match v {
193            Some(v) => v.into(),
194            None => Value::Null,
195        }
196    }
197}
198
199// public:
200pub struct Row<'a> {
201    inner: &'a PgRow,
202}
203
204impl<'a> Row<'a> {
205    // public:
206    pub fn from_pg(row: &'a PgRow) -> Self {
207        Self { inner: row }
208    }
209
210    // public:
211    pub fn get_i32(&self, col: &str) -> Result<i32> {
212        self.inner
213            .try_get::<i32, _>(col)
214            .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
215    }
216    // public:
217    pub fn get_i64(&self, col: &str) -> Result<i64> {
218        self.inner
219            .try_get::<i64, _>(col)
220            .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
221    }
222    // public:
223    pub fn get_f64(&self, col: &str) -> Result<f64> {
224        self.inner
225            .try_get::<f64, _>(col)
226            .map_err(|e| Error::Internal(format!("get_f64({col}): {e}")))
227    }
228    // public:
229    pub fn get_decimal(&self, col: &str) -> Result<Decimal> {
230        self.inner
231            .try_get::<Decimal, _>(col)
232            .map_err(|e| Error::Internal(format!("get_decimal({col}): {e}")))
233    }
234    // public:
235    pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
236        self.inner
237            .try_get::<Option<i64>, _>(col)
238            .map_err(|e| Error::Internal(format!("{col}: {e}")))
239    }
240    // public:
241    pub fn get_bool(&self, col: &str) -> Result<bool> {
242        self.inner
243            .try_get::<bool, _>(col)
244            .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
245    }
246    // public:
247    pub fn get_string(&self, col: &str) -> Result<String> {
248        self.inner
249            .try_get::<String, _>(col)
250            .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
251    }
252    // public:
253    pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
254        self.inner
255            .try_get::<Option<String>, _>(col)
256            .map_err(|e| Error::Internal(format!("{col}: {e}")))
257    }
258    // public:
259    pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
260        self.inner
261            .try_get::<DateTime<Utc>, _>(col)
262            .map_err(|e| Error::Internal(format!("{col}: {e}")))
263    }
264    // public:
265    pub fn get_date(&self, col: &str) -> Result<NaiveDate> {
266        self.inner
267            .try_get::<NaiveDate, _>(col)
268            .map_err(|e| Error::Internal(format!("get_date({col}): {e}")))
269    }
270    // public:
271    pub fn get_time(&self, col: &str) -> Result<NaiveTime> {
272        self.inner
273            .try_get::<NaiveTime, _>(col)
274            .map_err(|e| Error::Internal(format!("get_time({col}): {e}")))
275    }
276    // public:
277    pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
278        self.inner
279            .try_get::<Option<DateTime<Utc>>, _>(col)
280            .map_err(|e| Error::Internal(format!("{col}: {e}")))
281    }
282    // public:
283    pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
284        self.inner
285            .try_get::<Uuid, _>(col)
286            .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
287    }
288    // public:
289    pub fn get_json(&self, col: &str) -> Result<JsonValue> {
290        self.inner
291            .try_get::<JsonValue, _>(col)
292            .map_err(|e| Error::Internal(format!("{col}: {e}")))
293    }
294}
295
296// public:
297pub trait Model: Send + Sync + Sized + 'static {
298    const TABLE: &'static str;
299    const COLUMNS: &'static [&'static str];
300    const INSERT_COLUMNS: &'static [&'static str];
301
302    fn id(&self) -> i64;
303    fn from_row(row: Row<'_>) -> Result<Self>;
304    fn insert_values(&self) -> Vec<Value>;
305}
306
307// ---- Generic CRUD helpers -----------------------------------------------
308
309// public:
310pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
311    let sql = format!(
312        "SELECT {} FROM {} ORDER BY id DESC",
313        M::COLUMNS.join(", "),
314        M::TABLE
315    );
316    let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
317    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
318}
319
320// public:
321pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
322    let sql = format!(
323        "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
324        M::COLUMNS.join(", "),
325        M::TABLE
326    );
327    let rows = sqlx::query(&sql)
328        .bind(limit)
329        .bind(offset)
330        .fetch_all(db.pool())
331        .await?;
332    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
333}
334
335// public:
336pub async fn count<M: Model>(db: &Db) -> Result<i64> {
337    let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
338    let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
339    row.try_get::<i64, _>("c")
340        .map_err(|e| Error::Internal(format!("count: {e}")))
341}
342
343// public:
344pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
345    let sql = format!(
346        "SELECT {} FROM {} WHERE id = $1",
347        M::COLUMNS.join(", "),
348        M::TABLE
349    );
350    let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
351    match row {
352        Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
353        None => Ok(None),
354    }
355}
356
357// public:
358pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
359    let cols = M::INSERT_COLUMNS.join(", ");
360    let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
361        .map(|i| format!("${i}"))
362        .collect();
363    let sql = format!(
364        "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
365        M::TABLE,
366        cols,
367        placeholders.join(", ")
368    );
369    let mut query = sqlx::query(&sql);
370    for value in model.insert_values() {
371        query = bind_value(query, value);
372    }
373    let row = query.fetch_one(db.pool()).await?;
374    let id: i64 = row
375        .try_get("id")
376        .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
377    Ok(id)
378}
379
380// public:
381pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
382    let sets: Vec<String> = M::INSERT_COLUMNS
383        .iter()
384        .enumerate()
385        .map(|(i, col)| format!("{col} = ${}", i + 1))
386        .collect();
387    let sql = format!(
388        "UPDATE {} SET {} WHERE id = ${}",
389        M::TABLE,
390        sets.join(", "),
391        M::INSERT_COLUMNS.len() + 1
392    );
393    let mut query = sqlx::query(&sql);
394    for value in model.insert_values() {
395        query = bind_value(query, value);
396    }
397    query = query.bind(id);
398    query.execute(db.pool()).await?;
399    Ok(())
400}
401
402// public:
403pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
404    let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
405    sqlx::query(&sql).bind(id).execute(db.pool()).await?;
406    Ok(())
407}
408
409fn bind_value<'a>(
410    q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
411    v: Value,
412) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
413    match v {
414        Value::Null => q.bind(None::<i64>),
415        Value::I32(n) => q.bind(n),
416        Value::I64(n) => q.bind(n),
417        Value::F64(n) => q.bind(n),
418        Value::Decimal(d) => q.bind(d),
419        Value::Bool(b) => q.bind(b),
420        Value::Text(s) => q.bind(s),
421        Value::DateTime(d) => q.bind(d),
422        Value::Date(d) => q.bind(d),
423        Value::Time(t) => q.bind(t),
424        Value::Uuid(u) => q.bind(u),
425        Value::Json(j) => q.bind(j),
426    }
427}
428
#[cfg(test)]
mod value_conversion_tests {
    use super::Value;
    use chrono::{NaiveDate, NaiveTime};
    use rust_decimal::Decimal;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Each scalar `From` impl must land in its own variant and keep the
    /// payload intact. `Value` has no `PartialEq`, hence `matches!` with
    /// a guard rather than `assert_eq!`.
    #[test]
    fn scalar_from_impls_map_to_their_variants() {
        let float: Value = 3.5_f64.into();
        assert!(matches!(float, Value::F64(got) if got == 3.5));

        let price = Decimal::from_str("19.99").unwrap();
        let decimal: Value = price.into();
        assert!(matches!(decimal, Value::Decimal(got) if got == price));

        let day = NaiveDate::from_ymd_opt(2026, 6, 2).unwrap();
        let date: Value = day.into();
        assert!(matches!(date, Value::Date(got) if got == day));

        let half_past_nine = NaiveTime::from_hms_opt(9, 30, 0).unwrap();
        let time: Value = half_past_nine.into();
        assert!(matches!(time, Value::Time(got) if got == half_past_nine));

        let id = Uuid::from_u128(0x550e8400_e29b_41d4_a716_446655440000);
        let uuid: Value = id.into();
        assert!(matches!(uuid, Value::Uuid(got) if got == id));
    }
}