tusk_rs/
database.rs

1use deadpool_postgres::{Object, Pool};
2use openssl::ssl::{SslConnector, SslMethod};
3use postgres_openssl::MakeTlsConnector;
4use tokio_postgres::{types::ToSql, NoTls, Row};
5
6use crate::{
7    config::DatabaseConfig,
8    query::{IntoSyntax, PostgresReadable},
9    FromPostgres, PostgresReadFields, PostgresTable, PostgresWrite, RouteError,
10};
11
/// Convenience macro used when fetching a single record from the database.
///
/// Unwraps an `Option`: `Some(value)` evaluates to `value`, while `None`
/// makes the enclosing function return early (through `?`) with a
/// [`RouteError::not_found`] error. The enclosing function must therefore
/// return a `Result` whose error type can be built from `RouteError`.
/// This allows easily bubbling a not found error when using
/// [`QueryBuilder::get`] or custom queries.
///
/// Two forms are accepted:
/// - `expect!(query)` uses a generic "not found" message.
/// - `expect!("message", query)` uses the supplied literal as the message.
///
/// The error type is named through `$crate`, so `RouteError` does not have
/// to be imported at the call site.
#[macro_export]
macro_rules! expect {
    ($query:expr) => {
        $query.ok_or_else(|| {
            $crate::RouteError::not_found("The record you requested was not found.")
        })?
    };
    ($msg:literal, $query:expr) => {
        $query.ok_or_else(|| $crate::RouteError::not_found($msg))?
    };
}
26
/// Similar to [`expect`], but formats the error message with the supplied
/// object name for a friendlier response.
///
/// `expect_obj!("user", query)` evaluates to the inner value when `query` is
/// `Some`, and otherwise returns early (through `?`) with a
/// [`RouteError::not_found`] whose message is
/// `"The user you requested was not found."`.
///
/// The error type is named through `$crate`, so `RouteError` does not have
/// to be imported at the call site.
#[macro_export]
macro_rules! expect_obj {
    ($obj:literal, $query:expr) => {
        $query.ok_or_else(|| {
            $crate::RouteError::not_found(&format!(
                "The {} you requested was not found.",
                $obj
            ))
        })?
    };
}
37
/// A thin wrapper around [`deadpool_postgres`] used by Tusk.
///
/// The [`Database`] type manages a connection pool for your application and is
/// created through [`Database::new`]. Connections are retrieved via
/// [`Database::get_connection`] and passed into route handlers through the
/// [`Request`](crate::Request) type.
#[derive(Clone)]
pub struct Database {
    /// The underlying `deadpool_postgres` connection pool.
    pool: Pool,
    /// Debug flag taken from the configuration; it is copied into every
    /// [`DatabaseConnection`] handed out, where it turns on query logging.
    debug: bool,
}
49
50impl Database {
51    /// Create a new connection pool from the provided [`DatabaseConfig`].
52    ///
53    /// Returns `None` if the pool could not be created.
54    pub async fn new(config: DatabaseConfig) -> Option<Database> {
55        let mut cfg = deadpool_postgres::Config::new();
56        cfg.user = Some(config.username);
57        cfg.password = Some(config.password);
58        cfg.host = Some(config.host);
59        cfg.dbname = Some(config.database);
60
61        if config.ssl {
62            let mut builder = SslConnector::builder(SslMethod::tls()).ok()?;
63            let _ = builder.set_ca_file("/etc/ssl/cert.pem");
64            let connector = MakeTlsConnector::new(builder.build());
65            let pool = cfg.create_pool(None, connector).ok()?;
66            Some(Database {
67                pool,
68                debug: config.debug,
69            })
70        } else {
71            let pool = cfg.create_pool(None, NoTls).ok()?;
72            Some(Database {
73                pool,
74                debug: config.debug,
75            })
76        }
77    }
78
79    /// Retrieve a [`DatabaseConnection`] from the pool.
80    ///
81    /// This should be called for every incoming request and the returned
82    /// connection passed to your route handlers.
83    pub async fn get_connection(&self) -> Result<DatabaseConnection, deadpool_postgres::PoolError> {
84        Ok(DatabaseConnection {
85            cn: self.pool.get().await?,
86            debug: self.debug,
87        })
88    }
89}
90
/// Errors that may occur when reading from Postgres.
#[derive(Debug)]
pub enum PostgresReadError {
    /// Any error that is not mapped to one of the more specific variants;
    /// carries the original driver error.
    Unknown(tokio_postgres::Error),
    /// SQLSTATE `42702`. The payload is taken from between the first pair of
    /// double quotes in the server's error message (the column name).
    AmbigiousColumn(String),
    /// SQLSTATE `42501`. The payload is the table name reported by the server.
    PermissionDenied(String),
}
100impl PostgresReadError {
101    pub fn from_pg_err(err: tokio_postgres::Error) -> PostgresReadError {
102        dbg!(&err);
103        if let Some(code) = err.code() {
104            match code.code() {
105                "42702" => PostgresReadError::AmbigiousColumn(
106                    err.as_db_error()
107                        .unwrap()
108                        .message()
109                        .split('\"')
110                        .nth(1)
111                        .unwrap()
112                        .to_string(),
113                ),
114                "42501" => PostgresReadError::PermissionDenied(
115                    err.as_db_error().unwrap().table().unwrap().to_string(),
116                ),
117                _ => PostgresReadError::Unknown(err),
118            }
119        } else {
120            PostgresReadError::Unknown(err)
121        }
122    }
123}
/// Lets `?` convert a driver error into a [`PostgresReadError`]; the
/// classification is delegated to [`PostgresReadError::from_pg_err`].
impl From<tokio_postgres::Error> for PostgresReadError {
    fn from(value: tokio_postgres::Error) -> Self {
        PostgresReadError::from_pg_err(value)
    }
}
/// Converts a read error into a generic bad-request [`RouteError`].
///
/// The error itself is only dumped with `dbg!`; none of its details are put
/// into the returned message.
impl From<PostgresReadError> for RouteError {
    fn from(value: PostgresReadError) -> Self {
        // Diagnostic dump to stderr; the response message below is fixed.
        dbg!(&value);
        RouteError::bad_request("An error occurred and your request could not be fullfilled.")
    }
}
135
/// Errors that may occur when writing to Postgres.
#[derive(Debug)]
pub enum PostgresWriteError {
    /// An update was attempted with no filters and without calling
    /// [`QueryBuilder::force`].
    NoWhereProvided,
    /// SQLSTATE `42601` was reported by the server.
    InsertValueCountMismatch,
    /// SQLSTATE `23505`. Payload: (constraint name, detail text).
    UniqueConstraintViolation(String, String),
    /// SQLSTATE `23502`. Payload: the column name.
    NotNullConstraintViolation(String),
    /// SQLSTATE `42501`. Payload: the table name.
    PermissionDenied(String),
    /// Returned by [`DatabaseConnection::insert_vec`] when the bulk insert
    /// has no arguments to write.
    NoRows,
    /// Any error that is not mapped to one of the more specific variants;
    /// carries the original driver error.
    Unknown(tokio_postgres::Error),
}
150impl PostgresWriteError {
151    pub fn from_pg_err(err: tokio_postgres::Error) -> PostgresWriteError {
152        dbg!(&err);
153        if let Some(code) = err.code() {
154            match code.code() {
155                "42601" => PostgresWriteError::InsertValueCountMismatch,
156                "23505" => PostgresWriteError::UniqueConstraintViolation(
157                    err.as_db_error().unwrap().constraint().unwrap().to_string(),
158                    err.as_db_error().unwrap().detail().unwrap().to_string(),
159                ),
160                "23502" => PostgresWriteError::NotNullConstraintViolation(
161                    err.as_db_error().unwrap().column().unwrap().to_string(),
162                ),
163                "42501" => PostgresWriteError::PermissionDenied(
164                    err.as_db_error().unwrap().table().unwrap().to_string(),
165                ),
166                _ => PostgresWriteError::Unknown(err),
167            }
168        } else {
169            PostgresWriteError::Unknown(err)
170        }
171    }
172}
/// Lets `?` convert a driver error into a [`PostgresWriteError`]; the
/// classification is delegated to [`PostgresWriteError::from_pg_err`].
impl From<tokio_postgres::Error> for PostgresWriteError {
    fn from(value: tokio_postgres::Error) -> Self {
        PostgresWriteError::from_pg_err(value)
    }
}
/// Converts a write error into a generic bad-request [`RouteError`].
///
/// The error itself is only dumped with `dbg!`; none of its details are put
/// into the returned message.
impl From<PostgresWriteError> for RouteError {
    fn from(value: PostgresWriteError) -> Self {
        // Diagnostic dump to stderr; the response message below is fixed.
        dbg!(&value);
        RouteError::bad_request("An error occurred and your request could not be fullfilled.")
    }
}
184
/// Trait used by [`QueryBuilder`] to represent column identifiers.
///
/// The returned name is written verbatim into the generated SQL text (it is
/// not passed as a bound parameter), so implementations should only return
/// fixed, trusted identifiers.
pub trait ColumnKeys {
    /// Return the column name as it exists in the database.
    fn name(&self) -> &'static str;
}
190
/// Helper trait for models that can be used with [`QueryBuilder`].
pub trait Columned: PostgresReadable + PostgresReadFields + FromPostgres + PostgresTable {
    /// Column identifiers accepted by the filter methods
    /// ([`QueryBuilder::where_eq`], [`QueryBuilder::where_ne`]).
    type ReadKeys: ColumnKeys;
    /// Column identifiers accepted by [`QueryBuilder::set`].
    type WriteKeys: ColumnKeys;
}
196
/// One piece of the clause that [`QueryBuilder`] appends after the table
/// name. Components are rendered in the order they were added and joined
/// with single spaces.
enum QueryComponent<T: ColumnKeys> {
    /// A `column <condition> $n` comparison.
    Filter(QueryParam<T>),
    /// The `AND` connective.
    And,
    /// The `OR` connective.
    Or,
    /// A `LIMIT n` clause.
    Limit(i32),
    /// An `OFFSET n` clause.
    Offset(i32),
}
204impl<T: ColumnKeys> QueryComponent<T> {
205    fn to_query(&self) -> String {
206        match self {
207            Self::Filter(param) => {
208                format!("{} {} ${}", param.key.name(), param.condition, param.arg)
209            }
210            Self::And => "AND".to_string(),
211            Self::Or => "OR".to_string(),
212            Self::Limit(limit) => format!("LIMIT {}", limit),
213            Self::Offset(offset) => format!("OFFSET {}", offset),
214        }
215    }
216    fn is_filter(&self) -> bool {
217        matches!(self, Self::Filter(_))
218    }
219}
220
/// A single column comparison inside a [`QueryBuilder`] filter list.
struct QueryParam<T: ColumnKeys> {
    /// The column being compared.
    key: T,
    /// The SQL comparison operator text, e.g. `=` or `<>`.
    condition: String,
    /// 1-based index of the bound argument, rendered as `$arg`.
    arg: usize,
}
226
/// Builder used to easily construct simple `SELECT`, `UPDATE` and `DELETE` queries.
pub struct QueryBuilder<'a, T: Columned> {
    /// Already-rendered `column = $n` assignments for an `UPDATE ... SET`.
    set: Vec<String>,
    /// Filters, connectives and limit/offset, in the order they were added.
    filters: Vec<QueryComponent<T::ReadKeys>>,
    /// Bound arguments shared by `set` and `filters`; the `$n` placeholders
    /// are 1-based positions in this list.
    args: Vec<&'a (dyn ToSql + Sync)>,
    /// When true, updates are allowed even though no filter was added.
    force: bool,
}
/// The default builder is the empty builder returned by [`QueryBuilder::new`].
impl<T: Columned> Default for QueryBuilder<'_, T> {
    fn default() -> Self {
        Self::new()
    }
}
239
240impl<'a, T: Columned> QueryBuilder<'a, T> {
241    /// Create a new [`QueryBuilder`].
242    pub fn new() -> QueryBuilder<'a, T> {
243        QueryBuilder {
244            set: Vec::new(),
245            filters: Vec::new(),
246            args: Vec::new(),
247            force: false,
248        }
249    }
250    /// Set a column to a value for an upcoming [`update_one`](QueryBuilder::update_one)
251    /// or [`update_many`](QueryBuilder::update_many) call.
252    pub fn set(mut self, key: T::WriteKeys, val: &'a (dyn ToSql + Sync)) -> Self {
253        self.set
254            .push(format!("{} = ${}", key.name(), self.args.len() + 1));
255        self.args.push(val);
256        self
257    }
258    /// Filter results where the provided column equals the value.
259    pub fn where_eq(mut self, key: T::ReadKeys, val: &'a (dyn ToSql + Sync)) -> Self {
260        self.filters.push(QueryComponent::Filter(QueryParam {
261            key,
262            condition: "=".to_string(),
263            arg: self.args.len() + 1,
264        }));
265        self.args.push(val);
266        self
267    }
268    /// Filter results where the provided column does not equal the value.
269    pub fn where_ne(mut self, key: T::ReadKeys, val: &'a (dyn ToSql + Sync)) -> Self {
270        self.filters.push(QueryComponent::Filter(QueryParam {
271            key,
272            condition: "<>".to_string(),
273            arg: self.args.len() + 1,
274        }));
275        self.args.push(val);
276        self
277    }
278    /// Append an `AND` to the where clause.
279    pub fn and(mut self) -> Self {
280        self.filters.push(QueryComponent::And);
281        self
282    }
283    /// Append an `OR` to the where clause.
284    pub fn or(mut self) -> Self {
285        self.filters.push(QueryComponent::Or);
286        self
287    }
288    /// Apply a limit
289    pub fn limit(mut self, val: i32) -> Self {
290        self.filters.push(QueryComponent::Limit(val));
291        self
292    }
293    /// Apply an offset
294    pub fn offset(mut self, val: i32) -> Self {
295        self.filters.push(QueryComponent::Offset(val));
296        self
297    }
298    /// By default, write operations without a "WHERE" clause
299    /// will be rejected. Call this function
300    /// to force it to work.
301    pub fn force(mut self) -> Self {
302        self.force = true;
303        self
304    }
305
306    fn build_trail(&self) -> String {
307        if !self.filters.is_empty() {
308            format!(
309                "{} {}",
310                if self.filters.first().map(|x| x.is_filter()).unwrap_or(false) {
311                    "WHERE "
312                } else {
313                    ""
314                },
315                self.filters
316                    .iter()
317                    .map(|x| x.to_query())
318                    .collect::<Vec<_>>()
319                    .join(" ")
320            )
321        } else {
322            String::new()
323        }
324    }
325
326    /// Select one will fetch an object from the database, and return an Option indicating whether it's
327    /// been found.
328    pub async fn get(self, db: &DatabaseConnection) -> Result<Option<T>, PostgresReadError> {
329        Ok(db
330            .query(
331                &format!(
332                    "SELECT {} FROM {} {} {}",
333                    T::read_fields().as_syntax(T::table_name()),
334                    T::table_name(),
335                    T::joins()
336                        .iter()
337                        .map(|j| j.to_read(T::table_name()))
338                        .collect::<Vec<String>>()
339                        .join(" "),
340                    self.build_trail()
341                ),
342                &self.args,
343            )
344            .await?
345            .iter()
346            .map(|x| T::from_postgres(x))
347            .next())
348    }
349    /// Select all will fetch many objects from the database, and return a Vec. If no options are
350    /// found, an empty Vec is returned.
351    pub async fn select_all(self, db: &DatabaseConnection) -> Result<Vec<T>, PostgresReadError> {
352        Ok(db
353            .query(
354                &format!(
355                    "SELECT {} FROM {} {} {}",
356                    T::read_fields().as_syntax(T::table_name()),
357                    T::table_name(),
358                    T::joins()
359                        .iter()
360                        .map(|j| j.to_read(T::table_name()))
361                        .collect::<Vec<String>>()
362                        .join(" "),
363                    self.build_trail(),
364                ),
365                &self.args,
366            )
367            .await?
368            .iter()
369            .map(|x| T::from_postgres(x))
370            .collect())
371    }
372    /// Update rows and return all updated records.
373    pub async fn update_many(self, db: &DatabaseConnection) -> Result<Vec<T>, PostgresWriteError> {
374        let temp_table = format!("write_{}", T::table_name());
375        if self.filters.is_empty() && !self.force {
376            return Err(PostgresWriteError::NoWhereProvided);
377        }
378        Ok(db
379            .query(
380                &format!(
381                    "WITH {} AS (UPDATE {} SET {} {} RETURNING *) SELECT {} FROM {} {}",
382                    temp_table,
383                    T::table_name(),
384                    self.set.join(", "),
385                    self.build_trail(),
386                    T::read_fields().as_syntax(&temp_table),
387                    temp_table,
388                    T::joins().as_syntax(&temp_table),
389                ),
390                self.args.as_slice(),
391            )
392            .await?
393            .iter()
394            .map(|x| T::from_postgres(x))
395            .collect())
396    }
397    /// Update row and return first updated record.
398    pub async fn update_one(
399        self,
400        db: &DatabaseConnection,
401    ) -> Result<Option<T>, PostgresWriteError> {
402        let temp_table = format!("write_{}", T::table_name());
403        if self.filters.is_empty() && !self.force {
404            return Err(PostgresWriteError::NoWhereProvided);
405        }
406        Ok(db
407            .query(
408                &format!(
409                    "WITH {} AS (UPDATE {} SET {} {} RETURNING *) SELECT {} FROM {} {}",
410                    temp_table,
411                    T::table_name(),
412                    self.set.join(", "),
413                    self.build_trail(),
414                    T::read_fields().as_syntax(&temp_table),
415                    temp_table,
416                    T::joins().as_syntax(&temp_table),
417                ),
418                self.args.as_slice(),
419            )
420            .await?
421            .iter()
422            .map(|x| T::from_postgres(x))
423            .next())
424    }
425    /// Delete rows matching the provided condition.
426    pub async fn delete(&self, db: &DatabaseConnection) -> Result<(), PostgresWriteError> {
427        _ = db
428            .query(
429                &format!("DELETE FROM {} {}", T::table_name(), self.build_trail()),
430                &self.args,
431            )
432            .await?;
433        Ok(())
434    }
435}
436
/// Wrapper around a single pooled database connection.
///
/// Instances of this type are passed to route handlers and expose helper
/// methods for common CRUD operations.
pub struct DatabaseConnection {
    /// The pooled connection object checked out from `deadpool_postgres`.
    cn: Object,
    /// When true, every statement and its arguments are printed to stdout
    /// before being executed.
    debug: bool,
}
445impl DatabaseConnection {
446    /// Execute a raw SQL query and return the resulting rows.
447    pub async fn query<T: AsRef<str>>(
448        &self,
449        query: T,
450        args: &[&(dyn ToSql + Sync)],
451    ) -> Result<Vec<Row>, tokio_postgres::Error> {
452        if self.debug {
453            println!("[DEBUG: QUERY] {}", query.as_ref());
454            println!("[DEBUG: ARGS] Args: {:?}", args);
455        }
456        self.cn.query(query.as_ref(), args).await
457    }
458    /// Insert a single record and return the inserted row.
459    pub async fn insert<T: FromPostgres + PostgresTable + PostgresReadFields>(
460        &self,
461        write: PostgresWrite,
462    ) -> Result<T, PostgresWriteError> {
463        let (insert_q, insert_a) = write.into_insert(T::table_name());
464        if self.debug {
465            println!(
466                "[DEBUG: QUERY] (insert) {} RETURNING {}",
467                insert_q,
468                T::read_fields().as_syntax(T::table_name())
469            );
470            println!("[DEBUG: ARGS] (insert) Args: {:?}", insert_a);
471        }
472        Ok(self
473            .cn
474            .query(
475                &format!(
476                    "{} RETURNING {}",
477                    insert_q,
478                    T::read_fields().as_syntax(T::table_name())
479                ),
480                insert_a.as_slice(),
481            )
482            .await?
483            .iter()
484            .map(|x| T::from_postgres(x))
485            .next()
486            .unwrap())
487    }
488
489    /// Insert many records and return the inserted rows.
490    pub async fn insert_vec<T: FromPostgres + PostgresTable + PostgresReadable>(
491        &self,
492        write: PostgresWrite,
493    ) -> Result<Vec<T>, PostgresWriteError> {
494        let (insert_q, insert_a) = write.into_bulk_insert(T::table_name());
495        if insert_a.is_empty() {
496            return Err(PostgresWriteError::NoRows);
497        }
498        let temp_table = format!("write_{}", T::table_name());
499        let join_str = if !T::joins().is_empty() {
500            T::joins().as_syntax(&temp_table)
501        } else {
502            "".to_string()
503        };
504        if self.debug {
505            println!(
506                "[DEBUG: QUERY] (insert_vec) WITH {} AS ({} RETURNING *) SELECT {} FROM {} {}",
507                temp_table,
508                insert_q,
509                T::read_fields().as_syntax(&temp_table),
510                temp_table,
511                join_str
512            );
513            println!("[DEBUG: ARGS] (insert_vec) Args: {:?}", insert_a);
514        }
515        Ok(self
516            .cn
517            .query(
518                &format!(
519                    "WITH {} AS ({} RETURNING *) SELECT {} FROM {} {}",
520                    temp_table,
521                    insert_q,
522                    T::read_fields().as_syntax(&temp_table),
523                    temp_table,
524                    join_str
525                ),
526                insert_a.as_slice(),
527            )
528            .await?
529            .iter()
530            .map(|x| T::from_postgres(x))
531            .collect())
532    }
533}
534
/// Generic errors that may occur during database operations.
///
/// NOTE(review): nothing in this file constructs or matches on this type;
/// the variant notes below are inferred from the names and should be
/// confirmed against the callers.
pub enum DatabaseError {
    /// An otherwise unclassified failure.
    Unknown,
    /// Presumably a foreign-key violation; the meaning of the string payload
    /// is not visible here. TODO confirm.
    ForeignKey(String),
    /// Presumably signals that a query produced no rows. TODO confirm.
    NoResults,
}