Skip to main content

rustio_admin/
orm.rs

//! PostgreSQL-backed ORM. A thin shim over sqlx — `Db`, `Model`,
//! `Value`, `Row` — and a handful of generic CRUD helpers. Hand-written
//! `impl Model` is the contract; users keep writing SQL where it matters.

5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde_json::Value as JsonValue;
9use sqlx::postgres::{PgPoolOptions, PgRow};
10use sqlx::{PgPool, Row as SqlxRow};
11use uuid::Uuid;
12
13use crate::error::{Error, Result};
14
15// public:
16/// Shared handle to the database. Cheap to clone; every handler gets
17/// its own clone.
18#[derive(Clone)]
19pub struct Db {
20    pool: PgPool,
21}
22
23impl Db {
24    // public:
25    /// Connect with sensible production defaults: 30 max connections,
26    /// 1s acquire timeout, 5min idle timeout.
27    pub async fn connect(url: &str) -> Result<Self> {
28        Self::connect_with(url, DbOptions::default()).await
29    }
30
31    // public:
32    pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
33        let pool = PgPoolOptions::new()
34            .max_connections(opts.max_connections)
35            .min_connections(opts.min_connections)
36            .acquire_timeout(opts.acquire_timeout)
37            .idle_timeout(Some(opts.idle_timeout))
38            .max_lifetime(Some(opts.max_lifetime))
39            .connect(url)
40            .await
41            .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
42        Ok(Self { pool })
43    }
44
45    // public:
46    pub fn pool(&self) -> &PgPool {
47        &self.pool
48    }
49
50    // public:
51    pub async fn health_check(&self) -> Result<()> {
52        sqlx::query("SELECT 1")
53            .fetch_one(&self.pool)
54            .await
55            .map(|_| ())
56            .map_err(|e| Error::Internal(format!("health check: {e}")))
57    }
58
59    /// Test-only constructor. Builds a `Db` whose pool is real but
60    /// never opens a TCP connection — `connect_lazy_with` defers that
61    /// until `.acquire()` is called. Tests that don't actually hit the
62    /// database can use this.
63    #[cfg(test)]
64    #[allow(dead_code)]
65    pub(crate) fn for_testing_no_connection() -> Self {
66        let pool = PgPoolOptions::new()
67            .max_connections(1)
68            .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
69            .expect("connect_lazy never fails on a syntactically valid URL");
70        Self { pool }
71    }
72}
73
// public:
/// Connection-pool settings passed to [`Db::connect_with`].
///
/// Start from [`DbOptions::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `DbOptions { max_connections: 10, ..Default::default() }`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DbOptions {
    /// Upper bound on open connections (default 30).
    pub max_connections: u32,
    /// Connections the pool tries to keep open at all times (default 2).
    pub min_connections: u32,
    /// How long acquiring a connection may wait before failing (default 1s).
    pub acquire_timeout: Duration,
    /// Idle connections are closed after this long (default 5min).
    pub idle_timeout: Duration,
    /// Connections are recycled after this total lifetime (default 30min).
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    fn default() -> Self {
        Self {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout: Duration::from_secs(1),
            idle_timeout: Duration::from_secs(5 * 60),
            max_lifetime: Duration::from_secs(30 * 60),
        }
    }
}

96// public:
97/// The value types the framework understands. Kept small on purpose.
98#[derive(Debug, Clone)]
99pub enum Value {
100    Null,
101    I32(i32),
102    I64(i64),
103    Bool(bool),
104    Text(String),
105    DateTime(DateTime<Utc>),
106    Uuid(Uuid),
107    Json(JsonValue),
108}
109
110impl From<i32> for Value {
111    fn from(v: i32) -> Self {
112        Value::I32(v)
113    }
114}
115impl From<i64> for Value {
116    fn from(v: i64) -> Self {
117        Value::I64(v)
118    }
119}
120impl From<bool> for Value {
121    fn from(v: bool) -> Self {
122        Value::Bool(v)
123    }
124}
125impl From<String> for Value {
126    fn from(v: String) -> Self {
127        Value::Text(v)
128    }
129}
130impl<'a> From<&'a str> for Value {
131    fn from(v: &'a str) -> Self {
132        Value::Text(v.to_string())
133    }
134}
135impl From<DateTime<Utc>> for Value {
136    fn from(v: DateTime<Utc>) -> Self {
137        Value::DateTime(v)
138    }
139}
140impl From<Uuid> for Value {
141    fn from(v: Uuid) -> Self {
142        Value::Uuid(v)
143    }
144}
145impl From<JsonValue> for Value {
146    fn from(v: JsonValue) -> Self {
147        Value::Json(v)
148    }
149}
150impl<T: Into<Value>> From<Option<T>> for Value {
151    fn from(v: Option<T>) -> Self {
152        match v {
153            Some(v) => v.into(),
154            None => Value::Null,
155        }
156    }
157}
158
159// public:
160pub struct Row<'a> {
161    inner: &'a PgRow,
162}
163
164impl<'a> Row<'a> {
165    // public:
166    pub fn from_pg(row: &'a PgRow) -> Self {
167        Self { inner: row }
168    }
169
170    // public:
171    pub fn get_i32(&self, col: &str) -> Result<i32> {
172        self.inner
173            .try_get::<i32, _>(col)
174            .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
175    }
176    // public:
177    pub fn get_i64(&self, col: &str) -> Result<i64> {
178        self.inner
179            .try_get::<i64, _>(col)
180            .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
181    }
182    // public:
183    pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
184        self.inner
185            .try_get::<Option<i64>, _>(col)
186            .map_err(|e| Error::Internal(format!("{col}: {e}")))
187    }
188    // public:
189    pub fn get_bool(&self, col: &str) -> Result<bool> {
190        self.inner
191            .try_get::<bool, _>(col)
192            .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
193    }
194    // public:
195    pub fn get_string(&self, col: &str) -> Result<String> {
196        self.inner
197            .try_get::<String, _>(col)
198            .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
199    }
200    // public:
201    pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
202        self.inner
203            .try_get::<Option<String>, _>(col)
204            .map_err(|e| Error::Internal(format!("{col}: {e}")))
205    }
206    // public:
207    pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
208        self.inner
209            .try_get::<DateTime<Utc>, _>(col)
210            .map_err(|e| Error::Internal(format!("{col}: {e}")))
211    }
212    // public:
213    pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
214        self.inner
215            .try_get::<Option<DateTime<Utc>>, _>(col)
216            .map_err(|e| Error::Internal(format!("{col}: {e}")))
217    }
218    // public:
219    pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
220        self.inner
221            .try_get::<Uuid, _>(col)
222            .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
223    }
224    // public:
225    pub fn get_json(&self, col: &str) -> Result<JsonValue> {
226        self.inner
227            .try_get::<JsonValue, _>(col)
228            .map_err(|e| Error::Internal(format!("{col}: {e}")))
229    }
230}
231
232// public:
233pub trait Model: Send + Sync + Sized + 'static {
234    const TABLE: &'static str;
235    const COLUMNS: &'static [&'static str];
236    const INSERT_COLUMNS: &'static [&'static str];
237
238    fn id(&self) -> i64;
239    fn from_row(row: Row<'_>) -> Result<Self>;
240    fn insert_values(&self) -> Vec<Value>;
241}
242
// ---- Generic CRUD helpers -----------------------------------------------

245// public:
246pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
247    let sql = format!(
248        "SELECT {} FROM {} ORDER BY id DESC",
249        M::COLUMNS.join(", "),
250        M::TABLE
251    );
252    let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
253    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
254}
255
256// public:
257pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
258    let sql = format!(
259        "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
260        M::COLUMNS.join(", "),
261        M::TABLE
262    );
263    let rows = sqlx::query(&sql)
264        .bind(limit)
265        .bind(offset)
266        .fetch_all(db.pool())
267        .await?;
268    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
269}
270
271// public:
272pub async fn count<M: Model>(db: &Db) -> Result<i64> {
273    let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
274    let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
275    row.try_get::<i64, _>("c")
276        .map_err(|e| Error::Internal(format!("count: {e}")))
277}
278
279// public:
280pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
281    let sql = format!(
282        "SELECT {} FROM {} WHERE id = $1",
283        M::COLUMNS.join(", "),
284        M::TABLE
285    );
286    let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
287    match row {
288        Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
289        None => Ok(None),
290    }
291}
292
293// public:
294pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
295    let cols = M::INSERT_COLUMNS.join(", ");
296    let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
297        .map(|i| format!("${i}"))
298        .collect();
299    let sql = format!(
300        "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
301        M::TABLE,
302        cols,
303        placeholders.join(", ")
304    );
305    let mut query = sqlx::query(&sql);
306    for value in model.insert_values() {
307        query = bind_value(query, value);
308    }
309    let row = query.fetch_one(db.pool()).await?;
310    let id: i64 = row
311        .try_get("id")
312        .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
313    Ok(id)
314}
315
316// public:
317pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
318    let sets: Vec<String> = M::INSERT_COLUMNS
319        .iter()
320        .enumerate()
321        .map(|(i, col)| format!("{col} = ${}", i + 1))
322        .collect();
323    let sql = format!(
324        "UPDATE {} SET {} WHERE id = ${}",
325        M::TABLE,
326        sets.join(", "),
327        M::INSERT_COLUMNS.len() + 1
328    );
329    let mut query = sqlx::query(&sql);
330    for value in model.insert_values() {
331        query = bind_value(query, value);
332    }
333    query = query.bind(id);
334    query.execute(db.pool()).await?;
335    Ok(())
336}
337
338// public:
339pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
340    let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
341    sqlx::query(&sql).bind(id).execute(db.pool()).await?;
342    Ok(())
343}
344
345fn bind_value<'a>(
346    q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
347    v: Value,
348) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
349    match v {
350        Value::Null => q.bind(None::<i64>),
351        Value::I32(n) => q.bind(n),
352        Value::I64(n) => q.bind(n),
353        Value::Bool(b) => q.bind(b),
354        Value::Text(s) => q.bind(s),
355        Value::DateTime(d) => q.bind(d),
356        Value::Uuid(u) => q.bind(u),
357        Value::Json(j) => q.bind(j),
358    }
359}