1use std::time::Duration;
6
7use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
8use rust_decimal::Decimal;
9use serde_json::Value as JsonValue;
10use sqlx::postgres::{PgPoolOptions, PgRow};
11use sqlx::{PgPool, Row as SqlxRow};
12use uuid::Uuid;
13
14use crate::error::{Error, Result};
15
16#[derive(Clone)]
20pub struct Db {
21 pool: PgPool,
22}
23
24impl Db {
25 pub async fn connect(url: &str) -> Result<Self> {
29 Self::connect_with(url, DbOptions::default()).await
30 }
31
32 pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
34 let pool = PgPoolOptions::new()
35 .max_connections(opts.max_connections)
36 .min_connections(opts.min_connections)
37 .acquire_timeout(opts.acquire_timeout)
38 .idle_timeout(Some(opts.idle_timeout))
39 .max_lifetime(Some(opts.max_lifetime))
40 .after_connect(|conn, _meta| {
48 Box::pin(async move {
49 use sqlx::Executor;
50 conn.execute("SET client_min_messages = warning")
51 .await
52 .map(|_| ())
53 })
54 })
55 .connect(url)
56 .await
57 .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
58 Ok(Self { pool })
59 }
60
61 pub fn pool(&self) -> &PgPool {
63 &self.pool
64 }
65
66 pub async fn health_check(&self) -> Result<()> {
68 sqlx::query("SELECT 1")
69 .fetch_one(&self.pool)
70 .await
71 .map(|_| ())
72 .map_err(|e| Error::Internal(format!("health check: {e}")))
73 }
74
75 #[cfg(test)]
80 #[allow(dead_code)]
81 pub(crate) fn for_testing_no_connection() -> Self {
82 let pool = PgPoolOptions::new()
83 .max_connections(1)
84 .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
85 .expect("connect_lazy never fails on a syntactically valid URL");
86 Self { pool }
87 }
88}
89
/// Pool tuning knobs consumed by `Db::connect_with`.
///
/// Each field is forwarded to the `PgPoolOptions` setter of the same name.
#[derive(Debug, Clone)]
pub struct DbOptions {
    /// Upper bound on the number of pooled connections.
    pub max_connections: u32,
    /// Lower bound on the number of pooled connections.
    pub min_connections: u32,
    /// Passed to the pool's `acquire_timeout` setter.
    pub acquire_timeout: Duration,
    /// Passed (wrapped in `Some`) to the pool's `idle_timeout` setter.
    pub idle_timeout: Duration,
    /// Passed (wrapped in `Some`) to the pool's `max_lifetime` setter.
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// Defaults: 30 max / 2 min connections, 1 s acquire timeout,
    /// 5 min idle timeout, 30 min max lifetime.
    fn default() -> Self {
        let acquire_timeout = Duration::from_secs(1);
        let idle_timeout = Duration::from_secs(5 * 60);
        let max_lifetime = Duration::from_secs(30 * 60);
        DbOptions {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout,
            idle_timeout,
            max_lifetime,
        }
    }
}
111
112#[derive(Debug, Clone)]
115pub enum Value {
116 Null,
117 I32(i32),
118 I64(i64),
119 F64(f64),
120 Decimal(Decimal),
121 Bool(bool),
122 Text(String),
123 DateTime(DateTime<Utc>),
124 Date(NaiveDate),
125 Time(NaiveTime),
126 Uuid(Uuid),
127 Json(JsonValue),
128}
129
130impl From<i32> for Value {
131 fn from(v: i32) -> Self {
132 Value::I32(v)
133 }
134}
135impl From<i64> for Value {
136 fn from(v: i64) -> Self {
137 Value::I64(v)
138 }
139}
140impl From<f64> for Value {
141 fn from(v: f64) -> Self {
142 Value::F64(v)
143 }
144}
145impl From<Decimal> for Value {
146 fn from(v: Decimal) -> Self {
147 Value::Decimal(v)
148 }
149}
150impl From<bool> for Value {
151 fn from(v: bool) -> Self {
152 Value::Bool(v)
153 }
154}
155impl From<String> for Value {
156 fn from(v: String) -> Self {
157 Value::Text(v)
158 }
159}
160impl<'a> From<&'a str> for Value {
161 fn from(v: &'a str) -> Self {
162 Value::Text(v.to_string())
163 }
164}
165impl From<DateTime<Utc>> for Value {
166 fn from(v: DateTime<Utc>) -> Self {
167 Value::DateTime(v)
168 }
169}
170impl From<NaiveDate> for Value {
171 fn from(v: NaiveDate) -> Self {
172 Value::Date(v)
173 }
174}
175impl From<NaiveTime> for Value {
176 fn from(v: NaiveTime) -> Self {
177 Value::Time(v)
178 }
179}
180impl From<Uuid> for Value {
181 fn from(v: Uuid) -> Self {
182 Value::Uuid(v)
183 }
184}
185impl From<JsonValue> for Value {
186 fn from(v: JsonValue) -> Self {
187 Value::Json(v)
188 }
189}
190impl<T: Into<Value>> From<Option<T>> for Value {
191 fn from(v: Option<T>) -> Self {
192 match v {
193 Some(v) => v.into(),
194 None => Value::Null,
195 }
196 }
197}
198
199pub struct Row<'a> {
201 inner: &'a PgRow,
202}
203
204impl<'a> Row<'a> {
205 pub fn from_pg(row: &'a PgRow) -> Self {
207 Self { inner: row }
208 }
209
210 pub fn get_i32(&self, col: &str) -> Result<i32> {
212 self.inner
213 .try_get::<i32, _>(col)
214 .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
215 }
216 pub fn get_i64(&self, col: &str) -> Result<i64> {
218 self.inner
219 .try_get::<i64, _>(col)
220 .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
221 }
222 pub fn get_f64(&self, col: &str) -> Result<f64> {
224 self.inner
225 .try_get::<f64, _>(col)
226 .map_err(|e| Error::Internal(format!("get_f64({col}): {e}")))
227 }
228 pub fn get_decimal(&self, col: &str) -> Result<Decimal> {
230 self.inner
231 .try_get::<Decimal, _>(col)
232 .map_err(|e| Error::Internal(format!("get_decimal({col}): {e}")))
233 }
234 pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
236 self.inner
237 .try_get::<Option<i64>, _>(col)
238 .map_err(|e| Error::Internal(format!("{col}: {e}")))
239 }
240 pub fn get_bool(&self, col: &str) -> Result<bool> {
242 self.inner
243 .try_get::<bool, _>(col)
244 .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
245 }
246 pub fn get_string(&self, col: &str) -> Result<String> {
248 self.inner
249 .try_get::<String, _>(col)
250 .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
251 }
252 pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
254 self.inner
255 .try_get::<Option<String>, _>(col)
256 .map_err(|e| Error::Internal(format!("{col}: {e}")))
257 }
258 pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
260 self.inner
261 .try_get::<DateTime<Utc>, _>(col)
262 .map_err(|e| Error::Internal(format!("{col}: {e}")))
263 }
264 pub fn get_date(&self, col: &str) -> Result<NaiveDate> {
266 self.inner
267 .try_get::<NaiveDate, _>(col)
268 .map_err(|e| Error::Internal(format!("get_date({col}): {e}")))
269 }
270 pub fn get_time(&self, col: &str) -> Result<NaiveTime> {
272 self.inner
273 .try_get::<NaiveTime, _>(col)
274 .map_err(|e| Error::Internal(format!("get_time({col}): {e}")))
275 }
276 pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
278 self.inner
279 .try_get::<Option<DateTime<Utc>>, _>(col)
280 .map_err(|e| Error::Internal(format!("{col}: {e}")))
281 }
282 pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
284 self.inner
285 .try_get::<Uuid, _>(col)
286 .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
287 }
288 pub fn get_json(&self, col: &str) -> Result<JsonValue> {
290 self.inner
291 .try_get::<JsonValue, _>(col)
292 .map_err(|e| Error::Internal(format!("{col}: {e}")))
293 }
294}
295
296pub trait Model: Send + Sync + Sized + 'static {
298 const TABLE: &'static str;
299 const COLUMNS: &'static [&'static str];
300 const INSERT_COLUMNS: &'static [&'static str];
301
302 fn id(&self) -> i64;
303 fn from_row(row: Row<'_>) -> Result<Self>;
304 fn insert_values(&self) -> Vec<Value>;
305}
306
307pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
311 let sql = format!(
312 "SELECT {} FROM {} ORDER BY id DESC",
313 M::COLUMNS.join(", "),
314 M::TABLE
315 );
316 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
317 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
318}
319
320pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
322 let sql = format!(
323 "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
324 M::COLUMNS.join(", "),
325 M::TABLE
326 );
327 let rows = sqlx::query(&sql)
328 .bind(limit)
329 .bind(offset)
330 .fetch_all(db.pool())
331 .await?;
332 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
333}
334
335pub async fn count<M: Model>(db: &Db) -> Result<i64> {
337 let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
338 let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
339 row.try_get::<i64, _>("c")
340 .map_err(|e| Error::Internal(format!("count: {e}")))
341}
342
343pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
345 let sql = format!(
346 "SELECT {} FROM {} WHERE id = $1",
347 M::COLUMNS.join(", "),
348 M::TABLE
349 );
350 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
351 match row {
352 Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
353 None => Ok(None),
354 }
355}
356
357pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
359 let cols = M::INSERT_COLUMNS.join(", ");
360 let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
361 .map(|i| format!("${i}"))
362 .collect();
363 let sql = format!(
364 "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
365 M::TABLE,
366 cols,
367 placeholders.join(", ")
368 );
369 let mut query = sqlx::query(&sql);
370 for value in model.insert_values() {
371 query = bind_value(query, value);
372 }
373 let row = query.fetch_one(db.pool()).await?;
374 let id: i64 = row
375 .try_get("id")
376 .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
377 Ok(id)
378}
379
380pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
382 let sets: Vec<String> = M::INSERT_COLUMNS
383 .iter()
384 .enumerate()
385 .map(|(i, col)| format!("{col} = ${}", i + 1))
386 .collect();
387 let sql = format!(
388 "UPDATE {} SET {} WHERE id = ${}",
389 M::TABLE,
390 sets.join(", "),
391 M::INSERT_COLUMNS.len() + 1
392 );
393 let mut query = sqlx::query(&sql);
394 for value in model.insert_values() {
395 query = bind_value(query, value);
396 }
397 query = query.bind(id);
398 query.execute(db.pool()).await?;
399 Ok(())
400}
401
402pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
404 let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
405 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
406 Ok(())
407}
408
409fn bind_value<'a>(
410 q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
411 v: Value,
412) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
413 match v {
414 Value::Null => q.bind(None::<i64>),
415 Value::I32(n) => q.bind(n),
416 Value::I64(n) => q.bind(n),
417 Value::F64(n) => q.bind(n),
418 Value::Decimal(d) => q.bind(d),
419 Value::Bool(b) => q.bind(b),
420 Value::Text(s) => q.bind(s),
421 Value::DateTime(d) => q.bind(d),
422 Value::Date(d) => q.bind(d),
423 Value::Time(t) => q.bind(t),
424 Value::Uuid(u) => q.bind(u),
425 Value::Json(j) => q.bind(j),
426 }
427}
428
#[cfg(test)]
mod value_conversion_tests {
    use super::Value;
    use chrono::{NaiveDate, NaiveTime};
    use rust_decimal::Decimal;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Each scalar `From` impl must land in its matching `Value` variant and
    /// carry the input through unchanged.
    #[test]
    fn scalar_from_impls_map_to_their_variants() {
        // f64 -> F64
        let float = Value::from(3.5_f64);
        assert!(matches!(float, Value::F64(got) if got == 3.5));

        // Decimal -> Decimal
        let price = Decimal::from_str("19.99").unwrap();
        let converted = Value::from(price);
        assert!(matches!(converted, Value::Decimal(got) if got == price));

        // NaiveDate -> Date
        let day = NaiveDate::from_ymd_opt(2026, 6, 2).unwrap();
        let converted = Value::from(day);
        assert!(matches!(converted, Value::Date(got) if got == day));

        // NaiveTime -> Time
        let clock = NaiveTime::from_hms_opt(9, 30, 0).unwrap();
        let converted = Value::from(clock);
        assert!(matches!(converted, Value::Time(got) if got == clock));

        // Uuid -> Uuid
        let ident = Uuid::from_u128(0x550e8400_e29b_41d4_a716_446655440000);
        let converted = Value::from(ident);
        assert!(matches!(converted, Value::Uuid(got) if got == ident));
    }
}