1use deadpool_postgres::{Object, Pool};
2use openssl::ssl::{SslConnector, SslMethod};
3use postgres_openssl::MakeTlsConnector;
4use tokio_postgres::{types::ToSql, NoTls, Row};
5
6use crate::{
7 config::DatabaseConfig,
8 query::{IntoSyntax, PostgresReadable},
9 FromPostgres, PostgresReadFields, PostgresTable, PostgresWrite, RouteError,
10};
11
/// Unwraps an `Option`, returning early from the enclosing function with a
/// `RouteError::not_found` error when it is `None`.
///
/// * `expect!(opt)` uses the generic message "The record you requested was not found."
/// * `expect!("custom message", opt)` uses the given string literal instead.
///
/// The enclosing function must return a `Result` whose error type can be produced from
/// `RouteError` by `?`. The error type is referenced through `$crate`, so callers do not
/// need `RouteError` in scope.
#[macro_export]
macro_rules! expect {
    ($query:expr) => {
        $query.ok_or_else(|| {
            $crate::RouteError::not_found("The record you requested was not found.")
        })?
    };
    ($msg:literal, $query:expr) => {
        $query.ok_or_else(|| $crate::RouteError::not_found($msg))?
    };
}
26
/// Unwraps an `Option`, returning early with a `RouteError::not_found` whose message
/// names the missing object: `expect_obj!("user", opt)` fails with
/// "The user you requested was not found."
///
/// The enclosing function must return a `Result` whose error type can be produced from
/// `RouteError` by `?`. The error type is referenced through `$crate`, so callers do not
/// need `RouteError` in scope.
#[macro_export]
macro_rules! expect_obj {
    ($obj:literal, $query:expr) => {
        $query.ok_or_else(|| {
            $crate::RouteError::not_found(&format!(
                "The {} you requested was not found.",
                $obj
            ))
        })?
    };
}
37
38#[derive(Clone)]
45pub struct Database {
46 pool: Pool,
47 debug: bool,
48}
49
50impl Database {
51 pub async fn new(config: DatabaseConfig) -> Option<Database> {
55 let mut cfg = deadpool_postgres::Config::new();
56 cfg.user = Some(config.username);
57 cfg.password = Some(config.password);
58 cfg.host = Some(config.host);
59 cfg.dbname = Some(config.database);
60
61 if config.ssl {
62 let mut builder = SslConnector::builder(SslMethod::tls()).ok()?;
63 let _ = builder.set_ca_file("/etc/ssl/cert.pem");
64 let connector = MakeTlsConnector::new(builder.build());
65 let pool = cfg.create_pool(None, connector).ok()?;
66 Some(Database {
67 pool,
68 debug: config.debug,
69 })
70 } else {
71 let pool = cfg.create_pool(None, NoTls).ok()?;
72 Some(Database {
73 pool,
74 debug: config.debug,
75 })
76 }
77 }
78
79 pub async fn get_connection(&self) -> Result<DatabaseConnection, deadpool_postgres::PoolError> {
84 Ok(DatabaseConnection {
85 cn: self.pool.get().await?,
86 debug: self.debug,
87 })
88 }
89}
90
91#[derive(Debug)]
93pub enum PostgresReadError {
94 Unknown(tokio_postgres::Error),
95 AmbigiousColumn(String),
97 PermissionDenied(String),
99}
100impl PostgresReadError {
101 pub fn from_pg_err(err: tokio_postgres::Error) -> PostgresReadError {
102 dbg!(&err);
103 if let Some(code) = err.code() {
104 match code.code() {
105 "42702" => PostgresReadError::AmbigiousColumn(
106 err.as_db_error()
107 .unwrap()
108 .message()
109 .split('\"')
110 .nth(1)
111 .unwrap()
112 .to_string(),
113 ),
114 "42501" => PostgresReadError::PermissionDenied(
115 err.as_db_error().unwrap().table().unwrap().to_string(),
116 ),
117 _ => PostgresReadError::Unknown(err),
118 }
119 } else {
120 PostgresReadError::Unknown(err)
121 }
122 }
123}
124impl From<tokio_postgres::Error> for PostgresReadError {
125 fn from(value: tokio_postgres::Error) -> Self {
126 PostgresReadError::from_pg_err(value)
127 }
128}
129impl From<PostgresReadError> for RouteError {
130 fn from(value: PostgresReadError) -> Self {
131 dbg!(&value);
132 RouteError::bad_request("An error occurred and your request could not be fullfilled.")
133 }
134}
135
136#[derive(Debug)]
138pub enum PostgresWriteError {
139 NoWhereProvided,
140 InsertValueCountMismatch,
141 UniqueConstraintViolation(String, String),
143 NotNullConstraintViolation(String),
145 PermissionDenied(String),
147 NoRows,
148 Unknown(tokio_postgres::Error),
149}
150impl PostgresWriteError {
151 pub fn from_pg_err(err: tokio_postgres::Error) -> PostgresWriteError {
152 dbg!(&err);
153 if let Some(code) = err.code() {
154 match code.code() {
155 "42601" => PostgresWriteError::InsertValueCountMismatch,
156 "23505" => PostgresWriteError::UniqueConstraintViolation(
157 err.as_db_error().unwrap().constraint().unwrap().to_string(),
158 err.as_db_error().unwrap().detail().unwrap().to_string(),
159 ),
160 "23502" => PostgresWriteError::NotNullConstraintViolation(
161 err.as_db_error().unwrap().column().unwrap().to_string(),
162 ),
163 "42501" => PostgresWriteError::PermissionDenied(
164 err.as_db_error().unwrap().table().unwrap().to_string(),
165 ),
166 _ => PostgresWriteError::Unknown(err),
167 }
168 } else {
169 PostgresWriteError::Unknown(err)
170 }
171 }
172}
173impl From<tokio_postgres::Error> for PostgresWriteError {
174 fn from(value: tokio_postgres::Error) -> Self {
175 PostgresWriteError::from_pg_err(value)
176 }
177}
178impl From<PostgresWriteError> for RouteError {
179 fn from(value: PostgresWriteError) -> Self {
180 dbg!(&value);
181 RouteError::bad_request("An error occurred and your request could not be fullfilled.")
182 }
183}
184
/// A typed column identifier that knows its SQL name.
pub trait ColumnKeys {
    /// The column name spliced into generated SQL (e.g. `<name> = $1`).
    fn name(&self) -> &'static str;
}
190
191pub trait Columned: PostgresReadable + PostgresReadFields + FromPostgres + PostgresTable {
193 type ReadKeys: ColumnKeys;
194 type WriteKeys: ColumnKeys;
195}
196
197enum QueryComponent<T: ColumnKeys> {
198 Filter(QueryParam<T>),
199 And,
200 Or,
201 Limit(i32),
202 Offset(i32),
203}
204impl<T: ColumnKeys> QueryComponent<T> {
205 fn to_query(&self) -> String {
206 match self {
207 Self::Filter(param) => {
208 format!("{} {} ${}", param.key.name(), param.condition, param.arg)
209 }
210 Self::And => "AND".to_string(),
211 Self::Or => "OR".to_string(),
212 Self::Limit(limit) => format!("LIMIT {}", limit),
213 Self::Offset(offset) => format!("OFFSET {}", offset),
214 }
215 }
216 fn is_filter(&self) -> bool {
217 matches!(self, Self::Filter(_))
218 }
219}
220
221struct QueryParam<T: ColumnKeys> {
222 key: T,
223 condition: String,
224 arg: usize,
225}
226
227pub struct QueryBuilder<'a, T: Columned> {
229 set: Vec<String>,
230 filters: Vec<QueryComponent<T::ReadKeys>>,
231 args: Vec<&'a (dyn ToSql + Sync)>,
232 force: bool,
233}
234impl<T: Columned> Default for QueryBuilder<'_, T> {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
240impl<'a, T: Columned> QueryBuilder<'a, T> {
241 pub fn new() -> QueryBuilder<'a, T> {
243 QueryBuilder {
244 set: Vec::new(),
245 filters: Vec::new(),
246 args: Vec::new(),
247 force: false,
248 }
249 }
250 pub fn set(mut self, key: T::WriteKeys, val: &'a (dyn ToSql + Sync)) -> Self {
253 self.set
254 .push(format!("{} = ${}", key.name(), self.args.len() + 1));
255 self.args.push(val);
256 self
257 }
258 pub fn where_eq(mut self, key: T::ReadKeys, val: &'a (dyn ToSql + Sync)) -> Self {
260 self.filters.push(QueryComponent::Filter(QueryParam {
261 key,
262 condition: "=".to_string(),
263 arg: self.args.len() + 1,
264 }));
265 self.args.push(val);
266 self
267 }
268 pub fn where_ne(mut self, key: T::ReadKeys, val: &'a (dyn ToSql + Sync)) -> Self {
270 self.filters.push(QueryComponent::Filter(QueryParam {
271 key,
272 condition: "<>".to_string(),
273 arg: self.args.len() + 1,
274 }));
275 self.args.push(val);
276 self
277 }
278 pub fn and(mut self) -> Self {
280 self.filters.push(QueryComponent::And);
281 self
282 }
283 pub fn or(mut self) -> Self {
285 self.filters.push(QueryComponent::Or);
286 self
287 }
288 pub fn limit(mut self, val: i32) -> Self {
290 self.filters.push(QueryComponent::Limit(val));
291 self
292 }
293 pub fn offset(mut self, val: i32) -> Self {
295 self.filters.push(QueryComponent::Offset(val));
296 self
297 }
298 pub fn force(mut self) -> Self {
302 self.force = true;
303 self
304 }
305
306 fn build_trail(&self) -> String {
307 if !self.filters.is_empty() {
308 format!(
309 "{} {}",
310 if self.filters.first().map(|x| x.is_filter()).unwrap_or(false) {
311 "WHERE "
312 } else {
313 ""
314 },
315 self.filters
316 .iter()
317 .map(|x| x.to_query())
318 .collect::<Vec<_>>()
319 .join(" ")
320 )
321 } else {
322 String::new()
323 }
324 }
325
326 pub async fn get(self, db: &DatabaseConnection) -> Result<Option<T>, PostgresReadError> {
329 Ok(db
330 .query(
331 &format!(
332 "SELECT {} FROM {} {} {}",
333 T::read_fields().as_syntax(T::table_name()),
334 T::table_name(),
335 T::joins()
336 .iter()
337 .map(|j| j.to_read(T::table_name()))
338 .collect::<Vec<String>>()
339 .join(" "),
340 self.build_trail()
341 ),
342 &self.args,
343 )
344 .await?
345 .iter()
346 .map(|x| T::from_postgres(x))
347 .next())
348 }
349 pub async fn select_all(self, db: &DatabaseConnection) -> Result<Vec<T>, PostgresReadError> {
352 Ok(db
353 .query(
354 &format!(
355 "SELECT {} FROM {} {} {}",
356 T::read_fields().as_syntax(T::table_name()),
357 T::table_name(),
358 T::joins()
359 .iter()
360 .map(|j| j.to_read(T::table_name()))
361 .collect::<Vec<String>>()
362 .join(" "),
363 self.build_trail(),
364 ),
365 &self.args,
366 )
367 .await?
368 .iter()
369 .map(|x| T::from_postgres(x))
370 .collect())
371 }
372 pub async fn update_many(self, db: &DatabaseConnection) -> Result<Vec<T>, PostgresWriteError> {
374 let temp_table = format!("write_{}", T::table_name());
375 if self.filters.is_empty() && !self.force {
376 return Err(PostgresWriteError::NoWhereProvided);
377 }
378 Ok(db
379 .query(
380 &format!(
381 "WITH {} AS (UPDATE {} SET {} {} RETURNING *) SELECT {} FROM {} {}",
382 temp_table,
383 T::table_name(),
384 self.set.join(", "),
385 self.build_trail(),
386 T::read_fields().as_syntax(&temp_table),
387 temp_table,
388 T::joins().as_syntax(&temp_table),
389 ),
390 self.args.as_slice(),
391 )
392 .await?
393 .iter()
394 .map(|x| T::from_postgres(x))
395 .collect())
396 }
397 pub async fn update_one(
399 self,
400 db: &DatabaseConnection,
401 ) -> Result<Option<T>, PostgresWriteError> {
402 let temp_table = format!("write_{}", T::table_name());
403 if self.filters.is_empty() && !self.force {
404 return Err(PostgresWriteError::NoWhereProvided);
405 }
406 Ok(db
407 .query(
408 &format!(
409 "WITH {} AS (UPDATE {} SET {} {} RETURNING *) SELECT {} FROM {} {}",
410 temp_table,
411 T::table_name(),
412 self.set.join(", "),
413 self.build_trail(),
414 T::read_fields().as_syntax(&temp_table),
415 temp_table,
416 T::joins().as_syntax(&temp_table),
417 ),
418 self.args.as_slice(),
419 )
420 .await?
421 .iter()
422 .map(|x| T::from_postgres(x))
423 .next())
424 }
425 pub async fn delete(&self, db: &DatabaseConnection) -> Result<(), PostgresWriteError> {
427 _ = db
428 .query(
429 &format!("DELETE FROM {} {}", T::table_name(), self.build_trail()),
430 &self.args,
431 )
432 .await?;
433 Ok(())
434 }
435}
436
437pub struct DatabaseConnection {
442 cn: Object,
443 debug: bool,
444}
445impl DatabaseConnection {
446 pub async fn query<T: AsRef<str>>(
448 &self,
449 query: T,
450 args: &[&(dyn ToSql + Sync)],
451 ) -> Result<Vec<Row>, tokio_postgres::Error> {
452 if self.debug {
453 println!("[DEBUG: QUERY] {}", query.as_ref());
454 println!("[DEBUG: ARGS] Args: {:?}", args);
455 }
456 self.cn.query(query.as_ref(), args).await
457 }
458 pub async fn insert<T: FromPostgres + PostgresTable + PostgresReadFields>(
460 &self,
461 write: PostgresWrite,
462 ) -> Result<T, PostgresWriteError> {
463 let (insert_q, insert_a) = write.into_insert(T::table_name());
464 if self.debug {
465 println!(
466 "[DEBUG: QUERY] (insert) {} RETURNING {}",
467 insert_q,
468 T::read_fields().as_syntax(T::table_name())
469 );
470 println!("[DEBUG: ARGS] (insert) Args: {:?}", insert_a);
471 }
472 Ok(self
473 .cn
474 .query(
475 &format!(
476 "{} RETURNING {}",
477 insert_q,
478 T::read_fields().as_syntax(T::table_name())
479 ),
480 insert_a.as_slice(),
481 )
482 .await?
483 .iter()
484 .map(|x| T::from_postgres(x))
485 .next()
486 .unwrap())
487 }
488
489 pub async fn insert_vec<T: FromPostgres + PostgresTable + PostgresReadable>(
491 &self,
492 write: PostgresWrite,
493 ) -> Result<Vec<T>, PostgresWriteError> {
494 let (insert_q, insert_a) = write.into_bulk_insert(T::table_name());
495 if insert_a.is_empty() {
496 return Err(PostgresWriteError::NoRows);
497 }
498 let temp_table = format!("write_{}", T::table_name());
499 let join_str = if !T::joins().is_empty() {
500 T::joins().as_syntax(&temp_table)
501 } else {
502 "".to_string()
503 };
504 if self.debug {
505 println!(
506 "[DEBUG: QUERY] (insert_vec) WITH {} AS ({} RETURNING *) SELECT {} FROM {} {}",
507 temp_table,
508 insert_q,
509 T::read_fields().as_syntax(&temp_table),
510 temp_table,
511 join_str
512 );
513 println!("[DEBUG: ARGS] (insert_vec) Args: {:?}", insert_a);
514 }
515 Ok(self
516 .cn
517 .query(
518 &format!(
519 "WITH {} AS ({} RETURNING *) SELECT {} FROM {} {}",
520 temp_table,
521 insert_q,
522 T::read_fields().as_syntax(&temp_table),
523 temp_table,
524 join_str
525 ),
526 insert_a.as_slice(),
527 )
528 .await?
529 .iter()
530 .map(|x| T::from_postgres(x))
531 .collect())
532 }
533}
534
/// Coarse database error categories.
///
/// `Debug` is derived so values can be logged and used in assertions.
#[derive(Debug)]
pub enum DatabaseError {
    /// An unclassified failure.
    Unknown,
    /// A foreign-key related failure; the string carries accompanying detail.
    ForeignKey(String),
    /// The query produced no results.
    NoResults,
}