Skip to main content

rustio_admin/
orm.rs

1//! PostgreSQL-backed ORM. A thin shim over sqlx — `Db`, `Model`,
2//! `Value`, `Row` — and a handful of generic CRUD helpers. Hand-written
3//! `impl Model` is the contract; users keep writing SQL where it matters.
4
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde_json::Value as JsonValue;
9use sqlx::postgres::{PgPoolOptions, PgRow};
10use sqlx::{PgPool, Row as SqlxRow};
11use uuid::Uuid;
12
13use crate::error::{Error, Result};
14
15/// Shared handle to the database. Cheap to clone; every handler gets
16/// its own clone.
17#[derive(Clone)]
18pub struct Db {
19    pool: PgPool,
20}
21
22impl Db {
23    /// Connect with sensible production defaults: 30 max connections,
24    /// 1s acquire timeout, 5min idle timeout.
25    pub async fn connect(url: &str) -> Result<Self> {
26        Self::connect_with(url, DbOptions::default()).await
27    }
28
29    pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
30        let pool = PgPoolOptions::new()
31            .max_connections(opts.max_connections)
32            .min_connections(opts.min_connections)
33            .acquire_timeout(opts.acquire_timeout)
34            .idle_timeout(Some(opts.idle_timeout))
35            .max_lifetime(Some(opts.max_lifetime))
36            .connect(url)
37            .await
38            .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
39        Ok(Self { pool })
40    }
41
42    pub fn pool(&self) -> &PgPool {
43        &self.pool
44    }
45
46    pub async fn health_check(&self) -> Result<()> {
47        sqlx::query("SELECT 1")
48            .fetch_one(&self.pool)
49            .await
50            .map(|_| ())
51            .map_err(|e| Error::Internal(format!("health check: {e}")))
52    }
53
54    /// Test-only constructor. Builds a `Db` whose pool is real but
55    /// never opens a TCP connection — `connect_lazy_with` defers that
56    /// until `.acquire()` is called. Tests that don't actually hit the
57    /// database can use this.
58    #[cfg(test)]
59    #[allow(dead_code)]
60    pub(crate) fn for_testing_no_connection() -> Self {
61        let pool = PgPoolOptions::new()
62            .max_connections(1)
63            .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
64            .expect("connect_lazy never fails on a syntactically valid URL");
65        Self { pool }
66    }
67}
68
/// Connection-pool settings consumed by `Db::connect_with`.
///
/// Each field is forwarded to the `PgPoolOptions` setter of the same name.
#[derive(Debug, Clone)]
pub struct DbOptions {
    /// Upper bound on pooled connections.
    pub max_connections: u32,
    /// Connections the pool tries to keep open.
    pub min_connections: u32,
    /// How long to wait when acquiring a connection from the pool.
    pub acquire_timeout: Duration,
    /// How long a connection may sit idle before being closed.
    pub idle_timeout: Duration,
    /// Maximum lifetime of any single connection.
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// Production defaults: 30/2 connections, 1s acquire, 5 min idle,
    /// 30 min lifetime.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        DbOptions {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout: Duration::from_secs(1),
            idle_timeout: Duration::from_secs(5 * MINUTE_SECS),
            max_lifetime: Duration::from_secs(30 * MINUTE_SECS),
        }
    }
}
89
90/// The value types the framework understands. Kept small on purpose.
91#[derive(Debug, Clone)]
92pub enum Value {
93    Null,
94    I32(i32),
95    I64(i64),
96    Bool(bool),
97    Text(String),
98    DateTime(DateTime<Utc>),
99    Uuid(Uuid),
100    Json(JsonValue),
101}
102
103impl From<i32> for Value {
104    fn from(v: i32) -> Self {
105        Value::I32(v)
106    }
107}
108impl From<i64> for Value {
109    fn from(v: i64) -> Self {
110        Value::I64(v)
111    }
112}
113impl From<bool> for Value {
114    fn from(v: bool) -> Self {
115        Value::Bool(v)
116    }
117}
118impl From<String> for Value {
119    fn from(v: String) -> Self {
120        Value::Text(v)
121    }
122}
123impl<'a> From<&'a str> for Value {
124    fn from(v: &'a str) -> Self {
125        Value::Text(v.to_string())
126    }
127}
128impl From<DateTime<Utc>> for Value {
129    fn from(v: DateTime<Utc>) -> Self {
130        Value::DateTime(v)
131    }
132}
133impl From<Uuid> for Value {
134    fn from(v: Uuid) -> Self {
135        Value::Uuid(v)
136    }
137}
138impl From<JsonValue> for Value {
139    fn from(v: JsonValue) -> Self {
140        Value::Json(v)
141    }
142}
143impl<T: Into<Value>> From<Option<T>> for Value {
144    fn from(v: Option<T>) -> Self {
145        match v {
146            Some(v) => v.into(),
147            None => Value::Null,
148        }
149    }
150}
151
152pub struct Row<'a> {
153    inner: &'a PgRow,
154}
155
156impl<'a> Row<'a> {
157    pub fn from_pg(row: &'a PgRow) -> Self {
158        Self { inner: row }
159    }
160
161    pub fn get_i32(&self, col: &str) -> Result<i32> {
162        self.inner
163            .try_get::<i32, _>(col)
164            .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
165    }
166    pub fn get_i64(&self, col: &str) -> Result<i64> {
167        self.inner
168            .try_get::<i64, _>(col)
169            .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
170    }
171    pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
172        self.inner
173            .try_get::<Option<i64>, _>(col)
174            .map_err(|e| Error::Internal(format!("{col}: {e}")))
175    }
176    pub fn get_bool(&self, col: &str) -> Result<bool> {
177        self.inner
178            .try_get::<bool, _>(col)
179            .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
180    }
181    pub fn get_string(&self, col: &str) -> Result<String> {
182        self.inner
183            .try_get::<String, _>(col)
184            .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
185    }
186    pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
187        self.inner
188            .try_get::<Option<String>, _>(col)
189            .map_err(|e| Error::Internal(format!("{col}: {e}")))
190    }
191    pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
192        self.inner
193            .try_get::<DateTime<Utc>, _>(col)
194            .map_err(|e| Error::Internal(format!("{col}: {e}")))
195    }
196    pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
197        self.inner
198            .try_get::<Uuid, _>(col)
199            .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
200    }
201    pub fn get_json(&self, col: &str) -> Result<JsonValue> {
202        self.inner
203            .try_get::<JsonValue, _>(col)
204            .map_err(|e| Error::Internal(format!("{col}: {e}")))
205    }
206}
207
208pub trait Model: Send + Sync + Sized + 'static {
209    const TABLE: &'static str;
210    const COLUMNS: &'static [&'static str];
211    const INSERT_COLUMNS: &'static [&'static str];
212
213    fn id(&self) -> i64;
214    fn from_row(row: Row<'_>) -> Result<Self>;
215    fn insert_values(&self) -> Vec<Value>;
216}
217
218// ---- Generic CRUD helpers -----------------------------------------------
219
220pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
221    let sql = format!(
222        "SELECT {} FROM {} ORDER BY id DESC",
223        M::COLUMNS.join(", "),
224        M::TABLE
225    );
226    let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
227    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
228}
229
230pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
231    let sql = format!(
232        "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
233        M::COLUMNS.join(", "),
234        M::TABLE
235    );
236    let rows = sqlx::query(&sql)
237        .bind(limit)
238        .bind(offset)
239        .fetch_all(db.pool())
240        .await?;
241    rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
242}
243
244pub async fn count<M: Model>(db: &Db) -> Result<i64> {
245    let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
246    let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
247    row.try_get::<i64, _>("c")
248        .map_err(|e| Error::Internal(format!("count: {e}")))
249}
250
251pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
252    let sql = format!(
253        "SELECT {} FROM {} WHERE id = $1",
254        M::COLUMNS.join(", "),
255        M::TABLE
256    );
257    let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
258    match row {
259        Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
260        None => Ok(None),
261    }
262}
263
264pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
265    let cols = M::INSERT_COLUMNS.join(", ");
266    let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
267        .map(|i| format!("${i}"))
268        .collect();
269    let sql = format!(
270        "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
271        M::TABLE,
272        cols,
273        placeholders.join(", ")
274    );
275    let mut query = sqlx::query(&sql);
276    for value in model.insert_values() {
277        query = bind_value(query, value);
278    }
279    let row = query.fetch_one(db.pool()).await?;
280    let id: i64 = row
281        .try_get("id")
282        .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
283    Ok(id)
284}
285
286pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
287    let sets: Vec<String> = M::INSERT_COLUMNS
288        .iter()
289        .enumerate()
290        .map(|(i, col)| format!("{col} = ${}", i + 1))
291        .collect();
292    let sql = format!(
293        "UPDATE {} SET {} WHERE id = ${}",
294        M::TABLE,
295        sets.join(", "),
296        M::INSERT_COLUMNS.len() + 1
297    );
298    let mut query = sqlx::query(&sql);
299    for value in model.insert_values() {
300        query = bind_value(query, value);
301    }
302    query = query.bind(id);
303    query.execute(db.pool()).await?;
304    Ok(())
305}
306
307pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
308    let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
309    sqlx::query(&sql).bind(id).execute(db.pool()).await?;
310    Ok(())
311}
312
313fn bind_value<'a>(
314    q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
315    v: Value,
316) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
317    match v {
318        Value::Null => q.bind(None::<i64>),
319        Value::I32(n) => q.bind(n),
320        Value::I64(n) => q.bind(n),
321        Value::Bool(b) => q.bind(b),
322        Value::Text(s) => q.bind(s),
323        Value::DateTime(d) => q.bind(d),
324        Value::Uuid(u) => q.bind(u),
325        Value::Json(j) => q.bind(j),
326    }
327}