1use std::sync::Arc;
13use std::time::Duration;
14
15use chrono::{DateTime, Utc};
16use serde_json::Value as JsonValue;
17use sqlx::postgres::{PgPoolOptions, PgRow};
18use sqlx::{PgPool, Row as SqlxRow};
19use uuid::Uuid;
20
21use crate::cache::QueryCache;
22use crate::error::{Error, Result};
23
24#[derive(Clone)]
27pub struct Db {
28 pool: PgPool,
29 cache: Arc<QueryCache>,
30}
31
32impl Db {
33 pub async fn connect(url: &str) -> Result<Self> {
36 Self::connect_with(url, DbOptions::default()).await
37 }
38
39 pub async fn connect_with(url: &str, opts: DbOptions) -> Result<Self> {
40 let pool = PgPoolOptions::new()
41 .max_connections(opts.max_connections)
42 .min_connections(opts.min_connections)
43 .acquire_timeout(opts.acquire_timeout)
44 .idle_timeout(Some(opts.idle_timeout))
45 .max_lifetime(Some(opts.max_lifetime))
46 .connect(url)
47 .await
48 .map_err(|e| Error::Internal(format!("db connect failed: {e}")))?;
49 Ok(Self {
50 pool,
51 cache: Arc::new(QueryCache::new(opts.cache_capacity)),
52 })
53 }
54
55 pub fn pool(&self) -> &PgPool {
56 &self.pool
57 }
58
59 pub fn cache(&self) -> &QueryCache {
60 &self.cache
61 }
62
63 pub fn invalidate(&self, table: &str) {
66 self.cache.invalidate_prefix(table);
67 }
68
69 pub async fn health_check(&self) -> Result<()> {
70 sqlx::query("SELECT 1")
71 .fetch_one(&self.pool)
72 .await
73 .map(|_| ())
74 .map_err(|e| Error::Internal(format!("health check: {e}")))
75 }
76
77 #[cfg(test)]
85 pub(crate) fn for_testing_no_connection() -> Self {
86 let pool = PgPoolOptions::new()
87 .max_connections(1)
88 .connect_lazy("postgres://test:test@127.0.0.1:1/never_used")
89 .expect("connect_lazy never fails on a syntactically valid URL");
90 Self {
91 pool,
92 cache: Arc::new(QueryCache::new(8)),
93 }
94 }
95}
96
/// Tuning knobs consumed by `Db::connect_with`.
#[derive(Debug, Clone)]
pub struct DbOptions {
    /// Upper bound on pooled connections.
    pub max_connections: u32,
    /// Connections the pool tries to keep open.
    pub min_connections: u32,
    /// How long to wait for a free connection before giving up.
    pub acquire_timeout: Duration,
    /// Idle time after which a connection may be closed (always enabled).
    pub idle_timeout: Duration,
    /// Maximum age of a connection before it is retired (always enabled).
    pub max_lifetime: Duration,
    /// Capacity passed to `QueryCache::new`.
    pub cache_capacity: usize,
}

impl Default for DbOptions {
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        DbOptions {
            max_connections: 30,
            min_connections: 2,
            acquire_timeout: Duration::from_millis(1_000),
            idle_timeout: Duration::from_secs(5 * MINUTE_SECS),
            max_lifetime: Duration::from_secs(30 * MINUTE_SECS),
            cache_capacity: 2_048,
        }
    }
}
119
120#[derive(Debug, Clone)]
122pub enum Value {
123 Null,
124 I32(i32),
125 I64(i64),
126 Bool(bool),
127 Text(String),
128 DateTime(DateTime<Utc>),
129 Uuid(Uuid),
130 Json(JsonValue),
131}
132
133impl From<i32> for Value { fn from(v: i32) -> Self { Value::I32(v) } }
134impl From<i64> for Value { fn from(v: i64) -> Self { Value::I64(v) } }
135impl From<bool> for Value { fn from(v: bool) -> Self { Value::Bool(v) } }
136impl From<String> for Value { fn from(v: String) -> Self { Value::Text(v) } }
137impl<'a> From<&'a str> for Value { fn from(v: &'a str) -> Self { Value::Text(v.to_string()) } }
138impl From<DateTime<Utc>> for Value { fn from(v: DateTime<Utc>) -> Self { Value::DateTime(v) } }
139impl From<Uuid> for Value { fn from(v: Uuid) -> Self { Value::Uuid(v) } }
140impl From<JsonValue> for Value { fn from(v: JsonValue) -> Self { Value::Json(v) } }
141impl<T: Into<Value>> From<Option<T>> for Value {
142 fn from(v: Option<T>) -> Self {
143 match v {
144 Some(v) => v.into(),
145 None => Value::Null,
146 }
147 }
148}
149
150pub struct Row<'a> {
151 inner: &'a PgRow,
152}
153
154impl<'a> Row<'a> {
155 pub fn from_pg(row: &'a PgRow) -> Self {
156 Self { inner: row }
157 }
158
159 pub fn get_i32(&self, col: &str) -> Result<i32> {
160 self.inner.try_get::<i32, _>(col).map_err(|e| Error::Internal(format!("get_i32({col}): {e}")))
161 }
162 pub fn get_i64(&self, col: &str) -> Result<i64> {
163 self.inner.try_get::<i64, _>(col).map_err(|e| Error::Internal(format!("get_i64({col}): {e}")))
164 }
165 pub fn get_optional_i64(&self, col: &str) -> Result<Option<i64>> {
166 self.inner.try_get::<Option<i64>, _>(col).map_err(|e| Error::Internal(format!("{col}: {e}")))
167 }
168 pub fn get_bool(&self, col: &str) -> Result<bool> {
169 self.inner.try_get::<bool, _>(col).map_err(|e| Error::Internal(format!("get_bool({col}): {e}")))
170 }
171 pub fn get_string(&self, col: &str) -> Result<String> {
172 self.inner.try_get::<String, _>(col).map_err(|e| Error::Internal(format!("get_string({col}): {e}")))
173 }
174 pub fn get_optional_string(&self, col: &str) -> Result<Option<String>> {
175 self.inner.try_get::<Option<String>, _>(col).map_err(|e| Error::Internal(format!("{col}: {e}")))
176 }
177 pub fn get_datetime(&self, col: &str) -> Result<DateTime<Utc>> {
178 self.inner.try_get::<DateTime<Utc>, _>(col).map_err(|e| Error::Internal(format!("{col}: {e}")))
179 }
180 pub fn get_uuid(&self, col: &str) -> Result<Uuid> {
181 self.inner.try_get::<Uuid, _>(col).map_err(|e| Error::Internal(format!("get_uuid({col}): {e}")))
182 }
183 pub fn get_json(&self, col: &str) -> Result<JsonValue> {
184 self.inner.try_get::<JsonValue, _>(col).map_err(|e| Error::Internal(format!("{col}: {e}")))
185 }
186}
187
188pub trait Model: Send + Sync + Sized + 'static {
189 const TABLE: &'static str;
190 const COLUMNS: &'static [&'static str];
191 const INSERT_COLUMNS: &'static [&'static str];
192
193 fn id(&self) -> i64;
194 fn from_row(row: Row<'_>) -> Result<Self>;
195 fn insert_values(&self) -> Vec<Value>;
196}
197
198pub async fn all<M: Model>(db: &Db) -> Result<Vec<M>> {
201 let sql = format!(
202 "SELECT {} FROM {} ORDER BY id DESC",
203 M::COLUMNS.join(", "),
204 M::TABLE
205 );
206 let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
207 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
208}
209
210pub async fn page<M: Model>(db: &Db, limit: i64, offset: i64) -> Result<Vec<M>> {
211 let sql = format!(
212 "SELECT {} FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
213 M::COLUMNS.join(", "),
214 M::TABLE
215 );
216 let rows = sqlx::query(&sql)
217 .bind(limit)
218 .bind(offset)
219 .fetch_all(db.pool())
220 .await?;
221 rows.iter().map(|r| M::from_row(Row::from_pg(r))).collect()
222}
223
224pub async fn count<M: Model>(db: &Db) -> Result<i64> {
225 let sql = format!("SELECT COUNT(*) AS c FROM {}", M::TABLE);
226 let row = sqlx::query(&sql).fetch_one(db.pool()).await?;
227 row.try_get::<i64, _>("c")
228 .map_err(|e| Error::Internal(format!("count: {e}")))
229}
230
231pub async fn find<M: Model>(db: &Db, id: i64) -> Result<Option<M>> {
232 let sql = format!(
233 "SELECT {} FROM {} WHERE id = $1",
234 M::COLUMNS.join(", "),
235 M::TABLE
236 );
237 let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
238 match row {
239 Some(r) => Ok(Some(M::from_row(Row::from_pg(&r))?)),
240 None => Ok(None),
241 }
242}
243
244pub async fn create<M: Model>(db: &Db, model: &M) -> Result<i64> {
245 let cols = M::INSERT_COLUMNS.join(", ");
246 let placeholders: Vec<String> = (1..=M::INSERT_COLUMNS.len())
247 .map(|i| format!("${i}"))
248 .collect();
249 let sql = format!(
250 "INSERT INTO {} ({}) VALUES ({}) RETURNING id",
251 M::TABLE,
252 cols,
253 placeholders.join(", ")
254 );
255 let mut query = sqlx::query(&sql);
256 for value in model.insert_values() {
257 query = bind_value(query, value);
258 }
259 let row = query.fetch_one(db.pool()).await?;
260 let id: i64 = row
261 .try_get("id")
262 .map_err(|e| Error::Internal(format!("returning id: {e}")))?;
263 db.invalidate(M::TABLE);
264 Ok(id)
265}
266
267pub async fn update<M: Model>(db: &Db, id: i64, model: &M) -> Result<()> {
268 let sets: Vec<String> = M::INSERT_COLUMNS
269 .iter()
270 .enumerate()
271 .map(|(i, col)| format!("{col} = ${}", i + 1))
272 .collect();
273 let sql = format!(
274 "UPDATE {} SET {} WHERE id = ${}",
275 M::TABLE,
276 sets.join(", "),
277 M::INSERT_COLUMNS.len() + 1
278 );
279 let mut query = sqlx::query(&sql);
280 for value in model.insert_values() {
281 query = bind_value(query, value);
282 }
283 query = query.bind(id);
284 query.execute(db.pool()).await?;
285 db.invalidate(M::TABLE);
286 Ok(())
287}
288
289pub async fn delete<M: Model>(db: &Db, id: i64) -> Result<()> {
290 let sql = format!("DELETE FROM {} WHERE id = $1", M::TABLE);
291 sqlx::query(&sql).bind(id).execute(db.pool()).await?;
292 db.invalidate(M::TABLE);
293 Ok(())
294}
295
296fn bind_value<'a>(
297 q: sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments>,
298 v: Value,
299) -> sqlx::query::Query<'a, sqlx::Postgres, sqlx::postgres::PgArguments> {
300 match v {
301 Value::Null => q.bind(None::<i64>),
302 Value::I32(n) => q.bind(n),
303 Value::I64(n) => q.bind(n),
304 Value::Bool(b) => q.bind(b),
305 Value::Text(s) => q.bind(s),
306 Value::DateTime(d) => q.bind(d),
307 Value::Uuid(u) => q.bind(u),
308 Value::Json(j) => q.bind(j),
309 }
310}