1use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde_json::Value as JsonValue;
9use sqlx::postgres::{PgPoolOptions, PgRow};
10use sqlx::{PgPool, Row as SqlxRow};
11use uuid::Uuid;
12
13use crate::error::{Error, Result};
14
/// Handle to the Postgres connection pool used by every query helper in this
/// module (`all`, `page`, `count`, `find`, `create`, `update`, `delete`).
///
/// Construct it with `Db::connect` / `Db::connect_with`.
#[derive(Clone)]
pub struct Db {
    // The sqlx pool; exposed read-only through `Db::pool`.
    pool: PgPool,
}
22
23impl Db {
24 pub async fn connect(url: &str) -> Result<Self> {
28 Self::connect_with(url, DbOptions::default()).await
29 }
30
31 pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
33 let pool = PgPoolOptions::new()
34 .max_connections(opts.max_connections)
35 .min_connections(opts.min_connections)
36 .acquire_timeout(opts.acquire_timeout)
37 .idle_timeout(Some(opts.idle_timeout))
38 .max_lifetime(Some(opts.max_lifetime))
39 .after_connect(|conn, _meta| {
47 Box::pin(async move {
48 use sqlx::Executor;
49 conn.execute("SET client_min_messages = warning")
50 .await
51 .map(|_| ())
52 })
53 })
54 .connect(url)
55 .await
56 .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
57 Ok(Self { pool })
58 }
59
60 pub fn pool(&self) -> &PgPool {
62 &self.pool
63 }
64
65 pub async fn health_check(&self) -> Result<()> {
67 sqlx::query("SELECT 1")
68 .fetch_one(&self.pool)
69 .await
70 .map(|_| ())
71 .map_err(|e| Error::Internal(format!("health check: {e}")))
72 }
73
74 #[cfg(test)]
79 #[allow(dead_code)]
80 pub(crate) fn for_testing_no_connection() -> Self {
81 let pool = PgPoolOptions::new()
82 .max_connections(1)
83 .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
84 .expect("connect_lazy never fails on a syntactically valid URL");
85 Self { pool }
86 }
87}
88
/// Connection-pool tuning passed to `Db::connect_with`.
///
/// Each field is forwarded to the `PgPoolOptions` setter of the same name.
/// `PartialEq`/`Eq` are derived so configurations can be compared, e.g. when
/// asserting on parsed config in tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DbOptions {
    /// Upper bound on open connections in the pool.
    pub max_connections: u32,
    /// Number of connections the pool tries to keep open.
    pub min_connections: u32,
    /// How long to wait for a free connection before giving up.
    pub acquire_timeout: Duration,
    /// Idle time after which a connection may be closed.
    pub idle_timeout: Duration,
    /// Maximum age of a connection before it is recycled.
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// 30 max / 2 min connections, 1 s acquire timeout, 5 min idle timeout,
    /// 30 min max lifetime.
    fn default() -> Self {
        Self {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout: Duration::from_secs(1),
            idle_timeout: Duration::from_secs(300),
            max_lifetime: Duration::from_secs(1800),
        }
    }
}
110
111#[derive(Debug, Clone)]
114pub enum Value {
115 Null,
116 I32(i32),
117 I64(i64),
118 Bool(bool),
119 Text(String),
120 DateTime(DateTime<Utc>),
121 Uuid(Uuid),
122 Json(JsonValue),
123}
124
125impl From<i32> for Value {
126 fn from(v: i32) -> Self {
127 Value::I32(v)
128 }
129}
130impl From<i64> for Value {
131 fn from(v: i64) -> Self {
132 Value::I64(v)
133 }
134}
135impl From<bool> for Value {
136 fn from(v: bool) -> Self {
137 Value::Bool(v)
138 }
139}
140impl From<String> for Value {
141 fn from(v: String) -> Self {
142 Value::Text(v)
143 }
144}
145impl<'a> From<&'a str> for Value {
146 fn from(v: &'a str) -> Self {
147 Value::Text(v.to_string())
148 }
149}
150impl From<DateTime<Utc>> for Value {
151 fn from(v: DateTime<Utc>) -> Self {
152 Value::DateTime(v)
153 }
154}
155impl From<Uuid> for Value {
156 fn from(v: Uuid) -> Self {
157 Value::Uuid(v)
158 }
159}
160impl From<JsonValue> for Value {
161 fn from(v: JsonValue) -> Self {
162 Value::Json(v)
163 }
164}
165impl<T: Into<Value>> From<Option<T>> for Value {
166 fn from(v: Option<T>) -> Self {
167 match v {
168 Some(v) => v.into(),
169 None => Value::Null,
170 }
171 }
172}
173
/// Borrowed view over a sqlx `PgRow` with typed getters that turn decode
/// failures into `Error::Internal`. Passed to `Model::from_row`.
pub struct Row<'a> {
    // The row being read; borrowed, so a `Row` cannot outlive the fetched `PgRow`.
    inner: &'a PgRow,
}
178
179impl<'a> Row<'a> {
180 pub fn from_pg(row: &'a PgRow) -> Self {
182 Self { inner: row }
183 }
184
185 pub fn get_i32(&self, col: &str) -> Result<i32> {
187 self.inner
188 .try_get::<i32, _>(col)
189 .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
190 }
191 pub fn get_i64(&self, col: &str) -> Result<i64> {
193 self.inner
194 .try_get::<i64, _>(col)
195 .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
196 }
197 pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
199 self.inner
200 .try_get::<Option<i64>, _>(col)
201 .map_err(|e| Error::Internal(format!("{col}: {e}")))
202 }
203 pub fn get_bool(&self, col: &str) -> Result<bool> {
205 self.inner
206 .try_get::<bool, _>(col)
207 .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
208 }
209 pub fn get_string(&self, col: &str) -> Result<String> {
211 self.inner
212 .try_get::<String, _>(col)
213 .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
214 }
215 pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
217 self.inner
218 .try_get::<Option<String>, _>(col)
219 .map_err(|e| Error::Internal(format!("{col}: {e}")))
220 }
221 pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
223 self.inner
224 .try_get::<DateTime<Utc>, _>(col)
225 .map_err(|e| Error::Internal(format!("{col}: {e}")))
226 }
227 pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
229 self.inner
230 .try_get::<Option<DateTime<Utc>>, _>(col)
231 .map_err(|e| Error::Internal(format!("{col}: {e}")))
232 }
233 pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
235 self.inner
236 .try_get::<Uuid, _>(col)
237 .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
238 }
239 pub fn get_json(&self, col: &str) -> Result<JsonValue> {
241 self.inner
242 .try_get::<JsonValue, _>(col)
243 .map_err(|e| Error::Internal(format!("{col}: {e}")))
244 }
245}
246
/// A type persisted as one row of `TABLE`, usable with the generic helpers
/// `all`, `page`, `count`, `find`, `create`, `update` and `delete`.
///
/// Those helpers interpolate `TABLE`, `COLUMNS` and `INSERT_COLUMNS` directly
/// into SQL text (they are not escaped), so they must be trusted, valid SQL
/// identifiers. The helpers also assume the table has an `id` column: they
/// filter on `id = $n`, order by `id DESC`, and read `RETURNING id` as `i64`.
pub trait Model: Send + Sync + Sized + 'static {
    /// Table name used in every generated statement.
    const TABLE: &'static str;
    /// Columns selected by the read helpers; `from_row` reads from these.
    const COLUMNS: &'static [&'static str];
    /// Columns written by `create` (INSERT) and `update` (SET), in the same
    /// order as the values returned by `insert_values`.
    const INSERT_COLUMNS: &'static [&'static str];

    /// The model's id (presumably its primary key; not used by the helpers
    /// visible here — confirm against callers).
    fn id(&self) -> i64;
    /// Builds a model from a fetched row whose columns are `COLUMNS`.
    fn from_row(row: Row<'_>) -> Result<Self>;
    /// Values to bind for `INSERT_COLUMNS`, positionally: the i-th value is
    /// written to the i-th column, so the lengths should match.
    fn insert_values(&self) -> Vec<Value>;
}
257
258pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
262 let sql = format!(
263 "SELECT {} FROM {} ORDER BY id DESC",
264 M::COLUMNS.join(", "),
265 M::TABLE
266 );
267 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
268 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
269}
270
271pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
273 let sql = format!(
274 "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
275 M::COLUMNS.join(", "),
276 M::TABLE
277 );
278 let rows = sqlx::query(&sql)
279 .bind(limit)
280 .bind(offset)
281 .fetch_all(db.pool())
282 .await?;
283 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
284}
285
286pub async fn count<M: Model>(db: &Db) -> Result<i64> {
288 let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
289 let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
290 row.try_get::<i64, _>("c")
291 .map_err(|e| Error::Internal(format!("count: {e}")))
292}
293
294pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
296 let sql = format!(
297 "SELECT {} FROM {} WHERE id = $1",
298 M::COLUMNS.join(", "),
299 M::TABLE
300 );
301 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
302 match row {
303 Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
304 None => Ok(None),
305 }
306}
307
308pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
310 let cols = M::INSERT_COLUMNS.join(", ");
311 let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
312 .map(|i| format!("${i}"))
313 .collect();
314 let sql = format!(
315 "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
316 M::TABLE,
317 cols,
318 placeholders.join(", ")
319 );
320 let mut query = sqlx::query(&sql);
321 for value in model.insert_values() {
322 query = bind_value(query, value);
323 }
324 let row = query.fetch_one(db.pool()).await?;
325 let id: i64 = row
326 .try_get("id")
327 .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
328 Ok(id)
329}
330
331pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
333 let sets: Vec<String> = M::INSERT_COLUMNS
334 .iter()
335 .enumerate()
336 .map(|(i, col)| format!("{col} = ${}", i + 1))
337 .collect();
338 let sql = format!(
339 "UPDATE {} SET {} WHERE id = ${}",
340 M::TABLE,
341 sets.join(", "),
342 M::INSERT_COLUMNS.len() + 1
343 );
344 let mut query = sqlx::query(&sql);
345 for value in model.insert_values() {
346 query = bind_value(query, value);
347 }
348 query = query.bind(id);
349 query.execute(db.pool()).await?;
350 Ok(())
351}
352
353pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
355 let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
356 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
357 Ok(())
358}
359
360fn bind_value<'a>(
361 q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
362 v: Value,
363) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
364 match v {
365 Value::Null => q.bind(None::<i64>),
366 Value::I32(n) => q.bind(n),
367 Value::I64(n) => q.bind(n),
368 Value::Bool(b) => q.bind(b),
369 Value::Text(s) => q.bind(s),
370 Value::DateTime(d) => q.bind(d),
371 Value::Uuid(u) => q.bind(u),
372 Value::Json(j) => q.bind(j),
373 }
374}