1use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde_json::Value as JsonValue;
9use sqlx::postgres::{PgPoolOptions, PgRow};
10use sqlx::{PgPool, Row as SqlxRow};
11use uuid::Uuid;
12
13use crate::error::{Error, Result};
14
15#[derive(Clone)]
18pub struct Db {
19 pool: PgPool,
20}
21
22impl Db {
23 pub async fn connect(url: &str) -> Result<Self> {
26 Self::connect_with(url, DbOptions::default()).await
27 }
28
29 pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
30 let pool = PgPoolOptions::new()
31 .max_connections(opts.max_connections)
32 .min_connections(opts.min_connections)
33 .acquire_timeout(opts.acquire_timeout)
34 .idle_timeout(Some(opts.idle_timeout))
35 .max_lifetime(Some(opts.max_lifetime))
36 .connect(url)
37 .await
38 .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
39 Ok(Self { pool })
40 }
41
42 pub fn pool(&self) -> &PgPool {
43 &self.pool
44 }
45
46 pub async fn health_check(&self) -> Result<()> {
47 sqlx::query("SELECT 1")
48 .fetch_one(&self.pool)
49 .await
50 .map(|_| ())
51 .map_err(|e| Error::Internal(format!("health check: {e}")))
52 }
53
54 #[cfg(test)]
59 #[allow(dead_code)]
60 pub(crate) fn for_testing_no_connection() -> Self {
61 let pool = PgPoolOptions::new()
62 .max_connections(1)
63 .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
64 .expect("connect_lazy never fails on a syntactically valid URL");
65 Self { pool }
66 }
67}
68
/// Pool tuning knobs consumed by `Db::connect_with`; each field is passed
/// to the `PgPoolOptions` builder method of the same name.
#[derive(Debug, Clone)]
pub struct DbOptions {
    /// Upper bound handed to `max_connections`.
    pub max_connections: u32,
    /// Lower bound handed to `min_connections`.
    pub min_connections: u32,
    /// Value handed to `acquire_timeout`.
    pub acquire_timeout: Duration,
    /// Value handed to `idle_timeout` (wrapped in `Some`).
    pub idle_timeout: Duration,
    /// Value handed to `max_lifetime` (wrapped in `Some`).
    pub max_lifetime: Duration,
}

impl Default for DbOptions {
    /// 30 max / 2 min connections, 1 s acquire timeout, 5 min idle timeout,
    /// 30 min max lifetime.
    fn default() -> Self {
        DbOptions {
            max_lifetime: Duration::from_secs(30 * 60),
            idle_timeout: Duration::from_secs(5 * 60),
            acquire_timeout: Duration::from_secs(1),
            min_connections: 2,
            max_connections: 30,
        }
    }
}
89
90#[derive(Debug, Clone)]
92pub enum Value {
93 Null,
94 I32(i32),
95 I64(i64),
96 Bool(bool),
97 Text(String),
98 DateTime(DateTime<Utc>),
99 Uuid(Uuid),
100 Json(JsonValue),
101}
102
103impl From<i32> for Value {
104 fn from(v: i32) -> Self {
105 Value::I32(v)
106 }
107}
108impl From<i64> for Value {
109 fn from(v: i64) -> Self {
110 Value::I64(v)
111 }
112}
113impl From<bool> for Value {
114 fn from(v: bool) -> Self {
115 Value::Bool(v)
116 }
117}
118impl From<String> for Value {
119 fn from(v: String) -> Self {
120 Value::Text(v)
121 }
122}
123impl<'a> From<&'a str> for Value {
124 fn from(v: &'a str) -> Self {
125 Value::Text(v.to_string())
126 }
127}
128impl From<DateTime<Utc>> for Value {
129 fn from(v: DateTime<Utc>) -> Self {
130 Value::DateTime(v)
131 }
132}
133impl From<Uuid> for Value {
134 fn from(v: Uuid) -> Self {
135 Value::Uuid(v)
136 }
137}
138impl From<JsonValue> for Value {
139 fn from(v: JsonValue) -> Self {
140 Value::Json(v)
141 }
142}
143impl<T: Into<Value>> From<Option<T>> for Value {
144 fn from(v: Option<T>) -> Self {
145 match v {
146 Some(v) => v.into(),
147 None => Value::Null,
148 }
149 }
150}
151
152pub struct Row<'a> {
153 inner: &'a PgRow,
154}
155
156impl<'a> Row<'a> {
157 pub fn from_pg(row: &'a PgRow) -> Self {
158 Self { inner: row }
159 }
160
161 pub fn get_i32(&self, col: &str) -> Result<i32> {
162 self.inner
163 .try_get::<i32, _>(col)
164 .map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
165 }
166 pub fn get_i64(&self, col: &str) -> Result<i64> {
167 self.inner
168 .try_get::<i64, _>(col)
169 .map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
170 }
171 pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
172 self.inner
173 .try_get::<Option<i64>, _>(col)
174 .map_err(|e| Error::Internal(format!("{col}: {e}")))
175 }
176 pub fn get_bool(&self, col: &str) -> Result<bool> {
177 self.inner
178 .try_get::<bool, _>(col)
179 .map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
180 }
181 pub fn get_string(&self, col: &str) -> Result<String> {
182 self.inner
183 .try_get::<String, _>(col)
184 .map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
185 }
186 pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
187 self.inner
188 .try_get::<Option<String>, _>(col)
189 .map_err(|e| Error::Internal(format!("{col}: {e}")))
190 }
191 pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
192 self.inner
193 .try_get::<DateTime<Utc>, _>(col)
194 .map_err(|e| Error::Internal(format!("{col}: {e}")))
195 }
196 pub fn get_optional_datetime(&self, col: &str) -> Result<Option<DateTime<Utc>>> {
197 self.inner
198 .try_get::<Option<DateTime<Utc>>, _>(col)
199 .map_err(|e| Error::Internal(format!("{col}: {e}")))
200 }
201 pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
202 self.inner
203 .try_get::<Uuid, _>(col)
204 .map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
205 }
206 pub fn get_json(&self, col: &str) -> Result<JsonValue> {
207 self.inner
208 .try_get::<JsonValue, _>(col)
209 .map_err(|e| Error::Internal(format!("{col}: {e}")))
210 }
211}
212
213pub trait Model: Send + Sync + Sized + 'static {
214 const TABLE: &'static str;
215 const COLUMNS: &'static [&'static str];
216 const INSERT_COLUMNS: &'static [&'static str];
217
218 fn id(&self) -> i64;
219 fn from_row(row: Row<'_>) -> Result<Self>;
220 fn insert_values(&self) -> Vec<Value>;
221}
222
223pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
226 let sql = format!(
227 "SELECT {} FROM {} ORDER BY id DESC",
228 M::COLUMNS.join(", "),
229 M::TABLE
230 );
231 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
232 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
233}
234
235pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
236 let sql = format!(
237 "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
238 M::COLUMNS.join(", "),
239 M::TABLE
240 );
241 let rows = sqlx::query(&sql)
242 .bind(limit)
243 .bind(offset)
244 .fetch_all(db.pool())
245 .await?;
246 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
247}
248
249pub async fn count<M: Model>(db: &Db) -> Result<i64> {
250 let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
251 let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
252 row.try_get::<i64, _>("c")
253 .map_err(|e| Error::Internal(format!("count: {e}")))
254}
255
256pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
257 let sql = format!(
258 "SELECT {} FROM {} WHERE id = $1",
259 M::COLUMNS.join(", "),
260 M::TABLE
261 );
262 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
263 match row {
264 Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
265 None => Ok(None),
266 }
267}
268
269pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
270 let cols = M::INSERT_COLUMNS.join(", ");
271 let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
272 .map(|i| format!("${i}"))
273 .collect();
274 let sql = format!(
275 "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
276 M::TABLE,
277 cols,
278 placeholders.join(", ")
279 );
280 let mut query = sqlx::query(&sql);
281 for value in model.insert_values() {
282 query = bind_value(query, value);
283 }
284 let row = query.fetch_one(db.pool()).await?;
285 let id: i64 = row
286 .try_get("id")
287 .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
288 Ok(id)
289}
290
291pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
292 let sets: Vec<String> = M::INSERT_COLUMNS
293 .iter()
294 .enumerate()
295 .map(|(i, col)| format!("{col} = ${}", i + 1))
296 .collect();
297 let sql = format!(
298 "UPDATE {} SET {} WHERE id = ${}",
299 M::TABLE,
300 sets.join(", "),
301 M::INSERT_COLUMNS.len() + 1
302 );
303 let mut query = sqlx::query(&sql);
304 for value in model.insert_values() {
305 query = bind_value(query, value);
306 }
307 query = query.bind(id);
308 query.execute(db.pool()).await?;
309 Ok(())
310}
311
312pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
313 let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
314 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
315 Ok(())
316}
317
318fn bind_value<'a>(
319 q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
320 v: Value,
321) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
322 match v {
323 Value::Null => q.bind(None::<i64>),
324 Value::I32(n) => q.bind(n),
325 Value::I64(n) => q.bind(n),
326 Value::Bool(b) => q.bind(b),
327 Value::Text(s) => q.bind(s),
328 Value::DateTime(d) => q.bind(d),
329 Value::Uuid(u) => q.bind(u),
330 Value::Json(j) => q.bind(j),
331 }
332}