1use crate::column::{reindex_params, FilterExpr, OrderExpr, SqlValue};
8use crate::error::{OrmError, OrmResult};
9use crate::pagination::{Page};
10use crate::scope::Scope;
11use sqlx::postgres::{PgArguments, PgRow};
12use sqlx::{PgPool, Postgres};
13use std::marker::PhantomData;
14
/// Which SQL JOIN keyword a [`JoinClause`] is rendered with.
#[derive(Clone, Debug)]
enum JoinType {
    /// Rendered as `INNER JOIN`.
    Inner,
    /// Rendered as `LEFT JOIN`.
    Left,
    /// Rendered as `RIGHT JOIN`.
    Right,
}
25
26#[derive(Debug, Clone)]
27struct JoinClause {
28 join_type: JoinType,
29 table: String,
30 alias: Option<String>,
31 on: String,
33}
34
35#[derive(Debug, Clone)]
40pub struct UpdateSet {
41 pub col: String,
42 pub val: SqlValue,
43}
44
45#[derive(Debug)]
50pub struct QueryBuilder<T> {
51 pub(crate) table: String,
52 pub(crate) pk: String,
53
54 select_cols: Vec<String>,
56 distinct: bool,
57
58 filters: Vec<FilterExpr>,
60
61 with_deleted: bool,
63 only_deleted: bool,
64 soft_delete_col: Option<String>,
65
66 joins: Vec<JoinClause>,
68
69 group_by: Vec<String>,
71 having: Option<FilterExpr>,
72
73 order_by: Vec<String>,
75 order_random: bool,
76
77 limit: Option<i64>,
79 offset: Option<i64>,
80
81 op: QueryOp,
83
84 update_sets: Vec<UpdateSet>,
86
87 _marker: PhantomData<T>,
88}
89
90impl<T> Clone for QueryBuilder<T> {
93 fn clone(&self) -> Self {
94 Self {
95 table: self.table.clone(),
96 pk: self.pk.clone(),
97 select_cols: self.select_cols.clone(),
98 distinct: self.distinct,
99 filters: self.filters.clone(),
100 with_deleted: self.with_deleted,
101 only_deleted: self.only_deleted,
102 soft_delete_col: self.soft_delete_col.clone(),
103 joins: self.joins.clone(),
104 group_by: self.group_by.clone(),
105 having: self.having.clone(),
106 order_by: self.order_by.clone(),
107 order_random: self.order_random,
108 limit: self.limit,
109 offset: self.offset,
110 op: self.op.clone(),
111 update_sets: self.update_sets.clone(),
112 _marker: PhantomData,
113 }
114 }
115}
116
/// Statement kind recorded on a [`QueryBuilder`].
#[derive(Clone, Debug, PartialEq)]
enum QueryOp {
    Select,
    Update,
    Delete,
    Count,
}
124
125impl<T> QueryBuilder<T>
126where
127 T: Send + Sync + Unpin + 'static,
128{
129 pub fn new(table: impl Into<String>, pk: impl Into<String>) -> Self {
130 Self {
131 table: table.into(),
132 pk: pk.into(),
133 select_cols: vec!["*".into()],
134 distinct: false,
135 filters: vec![],
136 with_deleted: false,
137 only_deleted: false,
138 soft_delete_col: None,
139 joins: vec![],
140 group_by: vec![],
141 having: None,
142 order_by: vec![],
143 order_random: false,
144 limit: None,
145 offset: None,
146 op: QueryOp::Select,
147 update_sets: vec![],
148 _marker: PhantomData,
149 }
150 }
151
152 pub fn with_soft_delete_col(mut self, col: impl Into<String>) -> Self {
155 self.soft_delete_col = Some(col.into());
156 self
157 }
158
159 pub fn filter<F>(mut self, f: F) -> Self
163 where
164 F: FnOnce(&T::Columns) -> FilterExpr,
165 T: HasColumns,
166 {
167 let cols = T::columns_proxy();
168 let expr = f(&cols);
169 self.filters.push(expr);
170 self
171 }
172
173 pub fn filter_raw(mut self, sql: impl Into<String>) -> Self {
175 self.filters.push(FilterExpr::raw(sql));
176 self
177 }
178
179 pub fn filter_if<F>(self, condition: bool, f: F) -> Self
181 where
182 F: FnOnce(&T::Columns) -> FilterExpr,
183 T: HasColumns,
184 {
185 if condition {
186 self.filter(f)
187 } else {
188 self
189 }
190 }
191
192 pub fn apply(self, scope: Scope<T>) -> Self {
196 (scope.apply_fn)(self)
197 }
198
199 pub fn with_deleted(mut self) -> Self {
203 self.with_deleted = true;
204 self
205 }
206
207 pub fn only_deleted(mut self) -> Self {
209 self.only_deleted = true;
210 self
211 }
212
213 pub fn select_cols(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
217 self.select_cols = cols.into_iter().map(|c| c.into()).collect();
218 self
219 }
220
221 pub fn select_distinct_col(mut self, col: impl Into<String>) -> Self {
222 self.distinct = true;
223 self.select_cols = vec![col.into()];
224 self
225 }
226
227 pub fn inner_join(mut self, table: impl Into<String>, on: impl Into<String>) -> Self {
231 self.joins.push(JoinClause {
232 join_type: JoinType::Inner,
233 table: table.into(),
234 alias: None,
235 on: on.into(),
236 });
237 self
238 }
239
240 pub fn left_join(mut self, table: impl Into<String>, on: impl Into<String>) -> Self {
242 self.joins.push(JoinClause {
243 join_type: JoinType::Left,
244 table: table.into(),
245 alias: None,
246 on: on.into(),
247 });
248 self
249 }
250
251 pub fn group_by_col(mut self, col: impl Into<String>) -> Self {
254 self.group_by.push(col.into());
255 self
256 }
257
258 pub fn having_raw(mut self, sql: impl Into<String>) -> Self {
259 self.having = Some(FilterExpr::raw(sql));
260 self
261 }
262
263 pub fn order_by<F>(mut self, f: F) -> Self
266 where
267 F: FnOnce(&T::Columns) -> OrderExpr,
268 T: HasColumns,
269 {
270 let cols = T::columns_proxy();
271 let expr = f(&cols);
272 self.order_by.push(expr.sql);
273 self
274 }
275
276 pub fn order_by_raw(mut self, sql: impl Into<String>) -> Self {
277 self.order_by.push(sql.into());
278 self
279 }
280
281 pub fn order_by_random(mut self) -> Self {
282 self.order_random = true;
283 self
284 }
285
286 pub fn limit(mut self, n: i64) -> Self {
289 self.limit = Some(n);
290 self
291 }
292
293 pub fn offset(mut self, n: i64) -> Self {
294 self.offset = Some(n);
295 self
296 }
297
298 pub fn paginate(mut self, page: i64, per_page: i64) -> Self {
302 let page = page.max(1);
303 self.limit = Some(per_page);
304 self.offset = Some((page - 1) * per_page);
305 self
306 }
307
308 pub fn build_select(&self) -> (String, Vec<SqlValue>) {
312 let mut bindings: Vec<SqlValue> = vec![];
313
314 let mut all_filters = self.filters.clone();
316 if let Some(ref col) = self.soft_delete_col {
317 if !self.with_deleted && !self.only_deleted {
318 all_filters.push(FilterExpr::raw(format!("\"{}\" IS NULL", col)));
319 } else if self.only_deleted {
320 all_filters.push(FilterExpr::raw(format!("\"{}\" IS NOT NULL", col)));
321 }
322 }
323
324 let distinct_kw = if self.distinct { "DISTINCT " } else { "" };
326 let cols = self.select_cols.join(", ");
327 let mut sql = format!("SELECT {}{} FROM \"{}\"", distinct_kw, cols, self.table);
328
329 for j in &self.joins {
331 let kw = match j.join_type {
332 JoinType::Inner => "INNER JOIN",
333 JoinType::Left => "LEFT JOIN",
334 JoinType::Right => "RIGHT JOIN",
335 };
336 let alias_part = j
337 .alias
338 .as_deref()
339 .map(|a| format!(" AS \"{}\"", a))
340 .unwrap_or_default();
341 sql.push_str(&format!(
342 " {} \"{}\"{} ON {}",
343 kw, j.table, alias_part, j.on
344 ));
345 }
346
347 if !all_filters.is_empty() {
349 let mut parts: Vec<String> = vec![];
350 for expr in &all_filters {
351 let offset = bindings.len();
352 let reindexed = reindex_params(&expr.sql, offset);
353 parts.push(reindexed);
354 bindings.extend(expr.bindings.clone());
355 }
356 sql.push_str(" WHERE ");
357 sql.push_str(&parts.join(" AND "));
358 }
359
360 if !self.group_by.is_empty() {
362 sql.push_str(" GROUP BY ");
363 sql.push_str(&self.group_by.join(", "));
364 }
365
366 if let Some(ref hav) = self.having {
368 let offset = bindings.len();
369 let reindexed = reindex_params(&hav.sql, offset);
370 sql.push_str(&format!(" HAVING {}", reindexed));
371 bindings.extend(hav.bindings.clone());
372 }
373
374 if self.order_random {
376 sql.push_str(" ORDER BY RANDOM()");
377 } else if !self.order_by.is_empty() {
378 sql.push_str(" ORDER BY ");
379 sql.push_str(&self.order_by.join(", "));
380 }
381
382 if let Some(l) = self.limit {
384 sql.push_str(&format!(" LIMIT {}", l));
385 }
386 if let Some(o) = self.offset {
387 sql.push_str(&format!(" OFFSET {}", o));
388 }
389
390 (sql, bindings)
391 }
392
393 fn build_count(self) -> (String, Vec<SqlValue>) {
395 let count_builder = QueryBuilder::<T> {
396 select_cols: vec!["COUNT(*)".into()],
397 order_by: vec![],
398 order_random: false,
399 limit: None,
400 offset: None,
401 ..self.clone()
402 };
403 count_builder.build_select()
404 }
405
406 pub async fn fetch_all(self, pool: &PgPool) -> OrmResult<Vec<T>>
410 where
411 T: for<'r> sqlx::FromRow<'r, PgRow>,
412 {
413 let (sql, bindings) = self.build_select();
414 let mut q = sqlx::query_as::<Postgres, T>(&sql);
415 for b in bindings {
416 q = bind_value(q, b);
417 }
418 q.fetch_all(pool).await.map_err(OrmError::from_sqlx)
419 }
420
421 pub async fn first(mut self, pool: &PgPool) -> OrmResult<Option<T>>
423 where
424 T: for<'r> sqlx::FromRow<'r, PgRow>,
425 {
426 self.limit = Some(1);
427 let (sql, bindings) = self.build_select();
428 let mut q = sqlx::query_as::<Postgres, T>(&sql);
429 for b in bindings {
430 q = bind_value(q, b);
431 }
432 q.fetch_optional(pool).await.map_err(OrmError::from_sqlx)
433 }
434
435 pub async fn first_or_fail(self, pool: &PgPool) -> OrmResult<T>
437 where
438 T: for<'r> sqlx::FromRow<'r, PgRow>,
439 {
440 self.first(pool).await?.ok_or(OrmError::NotFound)
441 }
442
443 pub async fn last(mut self, pool: &PgPool) -> OrmResult<Option<T>>
445 where
446 T: for<'r> sqlx::FromRow<'r, PgRow>,
447 {
448 if self.order_by.is_empty() && !self.order_random {
450 self.order_by.push(format!("\"{}\" DESC", self.pk));
451 }
452 self.limit = Some(1);
453 let (sql, bindings) = self.build_select();
454 let mut q = sqlx::query_as::<Postgres, T>(&sql);
455 for b in bindings {
456 q = bind_value(q, b);
457 }
458 q.fetch_optional(pool).await.map_err(OrmError::from_sqlx)
459 }
460
461 pub async fn count(self, pool: &PgPool) -> OrmResult<i64> {
463 let (sql, bindings) = self.build_count();
464 let mut q = sqlx::query_as::<Postgres, (i64,)>(&sql);
465 for b in bindings {
466 q = bind_i64_value(q, b);
467 }
468 let row = q.fetch_one(pool).await.map_err(OrmError::from_sqlx)?;
469 Ok(row.0)
470 }
471
472 pub async fn exists(self, pool: &PgPool) -> OrmResult<bool> {
474 Ok(self.count(pool).await? > 0)
475 }
476
477 pub async fn fetch_page(self, page: i64, per_page: i64, pool: &PgPool) -> OrmResult<Page<T>>
479 where
480 T: for<'r> sqlx::FromRow<'r, PgRow>,
481 {
482 let page = page.max(1);
483 let total = self.clone().count(pool).await?;
484 let items = self.paginate(page, per_page).fetch_all(pool).await?;
485 Ok(Page::new(items, total, page, per_page))
486 }
487
488 pub async fn update_all<F>(self, f: F, pool: &PgPool) -> OrmResult<u64>
492 where
493 F: FnOnce(&mut UpdateBuilder<T>),
494 T: HasColumns,
495 {
496 let mut ub = UpdateBuilder::new();
497 f(&mut ub);
498
499 let mut set_parts: Vec<String> = vec![];
501 let mut bindings: Vec<SqlValue> = vec![];
502 for us in ub.sets {
503 let idx = bindings.len() + 1;
504 set_parts.push(format!("\"{}\" = ${}", us.col, idx));
505 bindings.push(us.val);
506 }
507
508 let mut where_parts: Vec<String> = vec![];
510 for expr in &self.filters {
511 let offset = bindings.len();
512 let reindexed = reindex_params(&expr.sql, offset);
513 where_parts.push(reindexed);
514 bindings.extend(expr.bindings.clone());
515 }
516
517 let mut sql = format!("UPDATE \"{}\" SET {}", self.table, set_parts.join(", "));
518 if !where_parts.is_empty() {
519 sql.push_str(" WHERE ");
520 sql.push_str(&where_parts.join(" AND "));
521 }
522
523 let mut q = sqlx::query(&sql);
524 for b in bindings {
525 q = bind_query_value(q, b);
526 }
527 let result = q.execute(pool).await.map_err(OrmError::from_sqlx)?;
528 Ok(result.rows_affected())
529 }
530
531 pub async fn delete_all(self, pool: &PgPool) -> OrmResult<u64> {
533 let mut bindings: Vec<SqlValue> = vec![];
534 let mut where_parts: Vec<String> = vec![];
535 for expr in &self.filters {
536 let offset = bindings.len();
537 let reindexed = reindex_params(&expr.sql, offset);
538 where_parts.push(reindexed);
539 bindings.extend(expr.bindings.clone());
540 }
541
542 let mut sql = format!("DELETE FROM \"{}\"", self.table);
543 if !where_parts.is_empty() {
544 sql.push_str(" WHERE ");
545 sql.push_str(&where_parts.join(" AND "));
546 }
547
548 let mut q = sqlx::query(&sql);
549 for b in bindings {
550 q = bind_query_value(q, b);
551 }
552 let result = q.execute(pool).await.map_err(OrmError::from_sqlx)?;
553 Ok(result.rows_affected())
554 }
555}
556
557pub struct UpdateBuilder<T> {
562 pub sets: Vec<UpdateSet>,
563 _m: PhantomData<T>,
564}
565
566impl<T> UpdateBuilder<T> {
567 fn new() -> Self {
568 Self {
569 sets: vec![],
570 _m: PhantomData,
571 }
572 }
573}
574
/// Models that provide a column proxy to the builder's typed closures.
///
/// `QueryBuilder::filter` and `QueryBuilder::order_by` call
/// [`HasColumns::columns_proxy`] once per invocation and pass the result by
/// reference to the caller's closure.
pub trait HasColumns {
    /// Proxy value handed to filter/order closures (presumably one accessor
    /// per column; confirm against the derive that implements it).
    type Columns;

    /// Builds a fresh column proxy.
    fn columns_proxy() -> Self::Columns;
}
580
581fn bind_value<'q, T>(
586 q: sqlx::query::QueryAs<'q, Postgres, T, PgArguments>,
587 val: SqlValue,
588) -> sqlx::query::QueryAs<'q, Postgres, T, PgArguments>
589where
590 T: Send + Unpin,
591{
592 match val {
593 SqlValue::Int(v) => q.bind(v),
594 SqlValue::Float(v) => q.bind(v),
595 SqlValue::Text(v) => q.bind(v),
596 SqlValue::Bool(v) => q.bind(v),
597 SqlValue::Null => q.bind(Option::<String>::None),
598 SqlValue::Json(v) => q.bind(sqlx::types::Json(v)),
599 SqlValue::Bytes(v) => q.bind(v),
600 }
601}
602
603fn bind_i64_value<'q>(
604 q: sqlx::query::QueryAs<'q, Postgres, (i64,), PgArguments>,
605 val: SqlValue,
606) -> sqlx::query::QueryAs<'q, Postgres, (i64,), PgArguments> {
607 match val {
608 SqlValue::Int(v) => q.bind(v),
609 SqlValue::Float(v) => q.bind(v),
610 SqlValue::Text(v) => q.bind(v),
611 SqlValue::Bool(v) => q.bind(v),
612 SqlValue::Null => q.bind(Option::<String>::None),
613 SqlValue::Json(v) => q.bind(sqlx::types::Json(v)),
614 SqlValue::Bytes(v) => q.bind(v),
615 }
616}
617
618fn bind_query_value<'q>(
619 q: sqlx::query::Query<'q, Postgres, PgArguments>,
620 val: SqlValue,
621) -> sqlx::query::Query<'q, Postgres, PgArguments> {
622 match val {
623 SqlValue::Int(v) => q.bind(v),
624 SqlValue::Float(v) => q.bind(v),
625 SqlValue::Text(v) => q.bind(v),
626 SqlValue::Bool(v) => q.bind(v),
627 SqlValue::Null => q.bind(Option::<String>::None),
628 SqlValue::Json(v) => q.bind(sqlx::types::Json(v)),
629 SqlValue::Bytes(v) => q.bind(v),
630 }
631}