1use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde_json::Value as JsonValue;
9use sqlx::postgres::{PgPoolOptions, PgRow};
10use sqlx::{PgPool, Row as SqlxRow};
11use uuid::Uuid;
12
13use crate::error::{Error, Result};
14
15#[derive(Clone)]
19pub struct Db {
20 pool: PgPool,
21}
22
23impl Db {
24 pub async fn connect(url: &str) -> Result<Self> {
28 Self::connect_with(url, DbOptions::default()).await
29 }
30
31 pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
33 let pool = PgPoolOptions::new()
34 .max_connections(opts.max_connections)
35 .min_connections(opts.min_connections)
36 .acquire_timeout(opts.acquire_timeout)
37 .idle_timeout(Some(opts.idle_timeout))
38 .max_lifetime(Some(opts.max_lifetime))
39 .connect(url)
40 .await
41 .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
42 Ok(Self { pool })
43 }
44
45 pub fn pool(&self) -> &PgPool {
47 &self.pool
48 }
49
50 pub async fn health_check(&self) -> Result<()> {
52 sqlx::query("SELECT 1")
53 .fetch_one(&self.pool)
54 .await
55 .map(|_| ())
56 .map_err(|e| Error::Internal(format!("health check: {e}")))
57 }
58
59 #[cfg(test)]
64 #[allow(dead_code)]
65 pub(crate) fn for_testing_no_connection() -> Self {
66 let pool = PgPoolOptions::new()
67 .max_connections(1)
68 .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
69 .expect("connect_lazy never fails on a syntactically valid URL");
70 Self { pool }
71 }
72}
73
/// Connection-pool settings consumed by [`Db::connect_with`].
///
/// Each field is forwarded to the `PgPoolOptions` setter of the same name. Start
/// from [`DbOptions::default`] and override individual fields with struct-update
/// syntax, e.g. `DbOptions { max_connections: 5, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct DbOptions {
    /// Forwarded to `PgPoolOptions::max_connections`.
    pub max_connections: u32,
    /// Forwarded to `PgPoolOptions::min_connections`.
    pub min_connections: u32,
    /// Forwarded to `PgPoolOptions::acquire_timeout`.
    pub acquire_timeout: Duration,
    /// Forwarded (wrapped in `Some`) to `PgPoolOptions::idle_timeout`.
    pub idle_timeout: Duration,
    /// Forwarded (wrapped in `Some`) to `PgPoolOptions::max_lifetime`.
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// Defaults: 30 max connections, 2 min connections, 1 s acquire timeout,
    /// 5 min idle timeout, 30 min max lifetime.
    fn default() -> Self {
        const SECS_PER_MIN: u64 = 60;
        DbOptions {
            max_lifetime: Duration::from_secs(30 * SECS_PER_MIN),
            idle_timeout: Duration::from_secs(5 * SECS_PER_MIN),
            acquire_timeout: Duration::from_secs(1),
            min_connections: 2,
            max_connections: 30,
        }
    }
}
95
96#[derive(Debug, Clone)]
99pub enum Value {
100 Null,
101 I32(i32),
102 I64(i64),
103 Bool(bool),
104 Text(String),
105 DateTime(DateTime<Utc>),
106 Uuid(Uuid),
107 Json(JsonValue),
108}
109
110impl From<i32> for Value {
111 fn from(v: i32) -> Self {
112 Value::I32(v)
113 }
114}
115impl From<i64> for Value {
116 fn from(v: i64) -> Self {
117 Value::I64(v)
118 }
119}
120impl From<bool> for Value {
121 fn from(v: bool) -> Self {
122 Value::Bool(v)
123 }
124}
125impl From<String> for Value {
126 fn from(v: String) -> Self {
127 Value::Text(v)
128 }
129}
130impl<'a> From<&'a str> for Value {
131 fn from(v: &'a str) -> Self {
132 Value::Text(v.to_string())
133 }
134}
135impl From<DateTime<Utc>> for Value {
136 fn from(v: DateTime<Utc>) -> Self {
137 Value::DateTime(v)
138 }
139}
140impl From<Uuid> for Value {
141 fn from(v: Uuid) -> Self {
142 Value::Uuid(v)
143 }
144}
145impl From<JsonValue> for Value {
146 fn from(v: JsonValue) -> Self {
147 Value::Json(v)
148 }
149}
150impl<T: Into<Value>> From<Option<T>> for Value {
151 fn from(v: Option<T>) -> Self {
152 match v {
153 Some(v) => v.into(),
154 None => Value::Null,
155 }
156 }
157}
158
159pub struct Row<'a> {
161 inner: &'a PgRow,
162}
163
164impl<'a> Row<'a> {
165 pub fn from_pg(row: &'a PgRow) -> Self {
167 Self { inner: row }
168 }
169
170 pub fn get_i32(&self, col: &str) -> Result<i32> {
172 self.inner
173 .try_get::<i32, _>(col)
174 .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
175 }
176 pub fn get_i64(&self, col: &str) -> Result<i64> {
178 self.inner
179 .try_get::<i64, _>(col)
180 .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
181 }
182 pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
184 self.inner
185 .try_get::<Option<i64>, _>(col)
186 .map_err(|e| Error::Internal(format!("{col}: {e}")))
187 }
188 pub fn get_bool(&self, col: &str) -> Result<bool> {
190 self.inner
191 .try_get::<bool, _>(col)
192 .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
193 }
194 pub fn get_string(&self, col: &str) -> Result<String> {
196 self.inner
197 .try_get::<String, _>(col)
198 .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
199 }
200 pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
202 self.inner
203 .try_get::<Option<String>, _>(col)
204 .map_err(|e| Error::Internal(format!("{col}: {e}")))
205 }
206 pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
208 self.inner
209 .try_get::<DateTime<Utc>, _>(col)
210 .map_err(|e| Error::Internal(format!("{col}: {e}")))
211 }
212 pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
214 self.inner
215 .try_get::<Option<DateTime<Utc>>, _>(col)
216 .map_err(|e| Error::Internal(format!("{col}: {e}")))
217 }
218 pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
220 self.inner
221 .try_get::<Uuid, _>(col)
222 .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
223 }
224 pub fn get_json(&self, col: &str) -> Result<JsonValue> {
226 self.inner
227 .try_get::<JsonValue, _>(col)
228 .map_err(|e| Error::Internal(format!("{col}: {e}")))
229 }
230}
231
232pub trait Model: Send + Sync + Sized + 'static {
234 const TABLE: &'static str;
235 const COLUMNS: &'static [&'static str];
236 const INSERT_COLUMNS: &'static [&'static str];
237
238 fn id(&self) -> i64;
239 fn from_row(row: Row<'_>) -> Result<Self>;
240 fn insert_values(&self) -> Vec<Value>;
241}
242
243pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
247 let sql = format!(
248 "SELECT {} FROM {} ORDER BY id DESC",
249 M::COLUMNS.join(", "),
250 M::TABLE
251 );
252 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
253 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
254}
255
256pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
258 let sql = format!(
259 "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
260 M::COLUMNS.join(", "),
261 M::TABLE
262 );
263 let rows = sqlx::query(&sql)
264 .bind(limit)
265 .bind(offset)
266 .fetch_all(db.pool())
267 .await?;
268 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
269}
270
271pub async fn count<M: Model>(db: &Db) -> Result<i64> {
273 let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
274 let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
275 row.try_get::<i64, _>("c")
276 .map_err(|e| Error::Internal(format!("count: {e}")))
277}
278
279pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
281 let sql = format!(
282 "SELECT {} FROM {} WHERE id = $1",
283 M::COLUMNS.join(", "),
284 M::TABLE
285 );
286 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
287 match row {
288 Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
289 None => Ok(None),
290 }
291}
292
293pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
295 let cols = M::INSERT_COLUMNS.join(", ");
296 let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
297 .map(|i| format!("${i}"))
298 .collect();
299 let sql = format!(
300 "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
301 M::TABLE,
302 cols,
303 placeholders.join(", ")
304 );
305 let mut query = sqlx::query(&sql);
306 for value in model.insert_values() {
307 query = bind_value(query, value);
308 }
309 let row = query.fetch_one(db.pool()).await?;
310 let id: i64 = row
311 .try_get("id")
312 .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
313 Ok(id)
314}
315
316pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
318 let sets: Vec<String> = M::INSERT_COLUMNS
319 .iter()
320 .enumerate()
321 .map(|(i, col)| format!("{col} = ${}", i + 1))
322 .collect();
323 let sql = format!(
324 "UPDATE {} SET {} WHERE id = ${}",
325 M::TABLE,
326 sets.join(", "),
327 M::INSERT_COLUMNS.len() + 1
328 );
329 let mut query = sqlx::query(&sql);
330 for value in model.insert_values() {
331 query = bind_value(query, value);
332 }
333 query = query.bind(id);
334 query.execute(db.pool()).await?;
335 Ok(())
336}
337
338pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
340 let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
341 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
342 Ok(())
343}
344
345fn bind_value<'a>(
346 q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
347 v: Value,
348) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
349 match v {
350 Value::Null => q.bind(None::<i64>),
351 Value::I32(n) => q.bind(n),
352 Value::I64(n) => q.bind(n),
353 Value::Bool(b) => q.bind(b),
354 Value::Text(s) => q.bind(s),
355 Value::DateTime(d) => q.bind(d),
356 Value::Uuid(u) => q.bind(u),
357 Value::Json(j) => q.bind(j),
358 }
359}