1use smallvec::SmallVec;
2use sqlx::{
3 Encode,
4 Postgres,
5 QueryBuilder,
6 Type,
7};
8use sqlxo_traits::{
9 QueryModel,
10 SqlWrite,
11};
12use std::{
13 marker::PhantomData,
14 sync::Arc,
15};
16
17use crate::blocks::SqlWriter;
18
19pub trait Column: Copy {
21 type Model: QueryModel;
22 type Type;
23 const NAME: &'static str;
24 const TABLE: &'static str;
25}
26
/// A fully qualified column reference: owning table plus column name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SelectionColumn {
    /// Table the column belongs to.
    pub table: &'static str,
    /// Column name within `table`.
    pub column: &'static str,
}

impl SelectionColumn {
    /// Builds a reference to `column` in `table`; usable in `const` contexts.
    pub const fn new(table: &'static str, column: &'static str) -> Self {
        SelectionColumn { column, table }
    }
}
38
/// SQL aggregate functions that a selection or `HAVING` predicate may use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    Count,
    CountDistinct,
    Sum,
    Avg,
    Min,
    Max,
}

impl AggregateFunction {
    /// Returns the SQL function name for this aggregate.
    ///
    /// Both count variants map to `COUNT`; this method does not emit the
    /// `DISTINCT` keyword for [`AggregateFunction::CountDistinct`].
    pub const fn sql_name(&self) -> &'static str {
        match self {
            Self::Sum => "SUM",
            Self::Avg => "AVG",
            Self::Min => "MIN",
            Self::Max => "MAX",
            Self::Count | Self::CountDistinct => "COUNT",
        }
    }
}
60
61#[derive(Clone, Copy, Debug, Eq, PartialEq)]
62pub struct AggregateSelection {
63 pub function: AggregateFunction,
64 pub column: Option<SelectionColumn>,
65}
66
67impl AggregateSelection {
68 pub const fn new(
69 function: AggregateFunction,
70 column: Option<SelectionColumn>,
71 ) -> Self {
72 Self { function, column }
73 }
74
75 pub const fn with_column(
76 function: AggregateFunction,
77 column: SelectionColumn,
78 ) -> Self {
79 Self {
80 function,
81 column: Some(column),
82 }
83 }
84}
85
86#[derive(Clone, Copy, Debug, Eq, PartialEq)]
87pub enum SelectionEntry {
88 Column(SelectionColumn),
89 Aggregate(AggregateSelection),
90}
91
92#[derive(Debug, Clone)]
93pub struct SelectionList<Output, Store = SelectionColumn> {
94 pub(crate) entries: SmallVec<[Store; 4]>,
95 _marker: PhantomData<Output>,
96}
97
98impl<Output, Store> SelectionList<Output, Store> {
99 pub fn new(entries: SmallVec<[Store; 4]>) -> Self {
100 Self {
101 entries,
102 _marker: PhantomData,
103 }
104 }
105
106 pub fn entries(&self) -> &[Store] {
107 &self.entries
108 }
109
110 pub fn clone_entries(&self) -> SmallVec<[Store; 4]>
111 where
112 Store: Clone,
113 {
114 self.entries.clone()
115 }
116
117 pub fn len(&self) -> usize {
118 self.entries.len()
119 }
120
121 pub fn is_empty(&self) -> bool {
122 self.entries.is_empty()
123 }
124}
125
126impl<Output> SelectionList<Output, SelectionColumn> {
127 pub fn columns(&self) -> &[SelectionColumn] {
128 &self.entries
129 }
130
131 pub fn clone_columns(&self) -> SmallVec<[SelectionColumn; 4]> {
132 self.entries.clone()
133 }
134
135 pub fn push_returning(
136 &self,
137 qb: &mut QueryBuilder<'static, Postgres>,
138 table: &str,
139 ) {
140 qb.push(" RETURNING ");
141 for (idx, col) in self.entries.iter().enumerate() {
142 assert_eq!(
143 col.table, table,
144 "`RETURNING` may only use columns from `{}` but got `{}`",
145 table, col.table,
146 );
147
148 if idx > 0 {
149 qb.push(", ");
150 }
151 qb.push(format!(r#""{}"."{}""#, table, col.column));
152 }
153 }
154}
155
156impl<Output> SelectionList<Output, SelectionEntry> {
157 pub fn expect_columns(self) -> SelectionList<Output, SelectionColumn> {
158 let mut cols = SmallVec::<[SelectionColumn; 4]>::new();
159 for entry in self.entries {
160 match entry {
161 SelectionEntry::Column(col) => cols.push(col),
162 SelectionEntry::Aggregate(_) => {
163 panic!("aggregates are not supported in this context")
164 }
165 }
166 }
167 SelectionList::new(cols)
168 }
169}
170
171pub fn push_returning<Output>(
172 qb: &mut QueryBuilder<'static, Postgres>,
173 table: &str,
174 selection: Option<&SelectionList<Output, SelectionColumn>>,
175) {
176 if let Some(sel) = selection {
177 sel.push_returning(qb, table);
178 } else {
179 qb.push(" RETURNING *");
180 }
181}
182
183#[derive(Debug, Clone)]
184pub struct GroupByList {
185 columns: SmallVec<[SelectionColumn; 4]>,
186}
187
188impl GroupByList {
189 pub fn new(columns: SmallVec<[SelectionColumn; 4]>) -> Self {
190 Self { columns }
191 }
192
193 pub fn columns(&self) -> &[SelectionColumn] {
194 &self.columns
195 }
196
197 pub fn into_columns(self) -> SmallVec<[SelectionColumn; 4]> {
198 self.columns
199 }
200}
201
202#[derive(Clone)]
203pub struct HavingValue {
204 binder: Arc<dyn Fn(&mut SqlWriter) + Send + Sync>,
205}
206
207impl HavingValue {
208 pub fn new<T>(value: T) -> Self
209 where
210 T: Clone + Send + Sync + 'static,
211 T: Encode<'static, Postgres>,
212 T: Type<Postgres>,
213 {
214 let value = Arc::new(value);
215 Self {
216 binder: Arc::new(move |writer: &mut SqlWriter| {
217 writer.bind((value.as_ref()).clone());
218 }),
219 }
220 }
221
222 pub fn bind(&self, writer: &mut SqlWriter) {
223 (self.binder)(writer);
224 }
225}
226
/// Comparison operator used in a `HAVING` predicate.
///
/// Derives `Eq`/`PartialEq` like the file's other fieldless enums, so
/// operators can be compared directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComparisonOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
}

impl ComparisonOp {
    /// Returns the SQL spelling of the operator.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
        }
    }
}
249
250#[derive(Clone)]
251pub struct HavingPredicate {
252 pub selection: AggregateSelection,
253 pub comparator: ComparisonOp,
254 value: HavingValue,
255}
256
257impl HavingPredicate {
258 pub fn new(
259 selection: AggregateSelection,
260 comparator: ComparisonOp,
261 value: HavingValue,
262 ) -> Self {
263 Self {
264 selection,
265 comparator,
266 value,
267 }
268 }
269
270 pub fn bind_value(&self, writer: &mut SqlWriter) {
271 self.value.bind(writer);
272 }
273}
274
275#[derive(Clone)]
276pub struct HavingList {
277 predicates: Vec<HavingPredicate>,
278}
279
280impl HavingList {
281 pub fn new(predicates: Vec<HavingPredicate>) -> Self {
282 Self { predicates }
283 }
284
285 pub fn predicates(&self) -> &[HavingPredicate] {
286 &self.predicates
287 }
288
289 pub fn into_predicates(self) -> Vec<HavingPredicate> {
290 self.predicates
291 }
292}
293
294impl From<HavingPredicate> for HavingList {
295 fn from(predicate: HavingPredicate) -> Self {
296 Self::new(vec![predicate])
297 }
298}
299
300pub trait SelectionExpr {
301 type Output;
302 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>);
303}
304
305impl<T> SelectionExpr for T
306where
307 T: Column,
308{
309 type Output = T::Type;
310
311 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
312 let column = SelectionColumn::new(T::TABLE, T::NAME);
313 entries.push(SelectionEntry::Column(column));
314 }
315}
316
/// Column-less count aggregate (`AggregateFunction::Count` with no column).
#[derive(Clone, Copy)]
pub struct CountAllExpr;

impl CountAllExpr {
    /// Creates the expression; usable in `const` contexts.
    pub const fn new() -> Self {
        CountAllExpr
    }
}

impl Default for CountAllExpr {
    fn default() -> Self {
        CountAllExpr
    }
}
331
332#[derive(Clone, Copy)]
333pub struct CountExpr<C: Column>(PhantomData<C>);
334
335impl<C: Column> CountExpr<C> {
336 pub const fn new() -> Self {
337 Self(PhantomData)
338 }
339}
340
341impl<C: Column> Default for CountExpr<C> {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347#[derive(Clone, Copy)]
348pub struct CountDistinctExpr<C: Column>(PhantomData<C>);
349
350impl<C: Column> CountDistinctExpr<C> {
351 pub const fn new() -> Self {
352 Self(PhantomData)
353 }
354}
355
356impl<C: Column> Default for CountDistinctExpr<C> {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
362#[derive(Clone, Copy)]
363pub struct SumExpr<C: Column>(PhantomData<C>);
364
365impl<C: Column> SumExpr<C> {
366 pub const fn new() -> Self {
367 Self(PhantomData)
368 }
369}
370
371impl<C: Column> Default for SumExpr<C> {
372 fn default() -> Self {
373 Self::new()
374 }
375}
376
377#[derive(Clone, Copy)]
378pub struct AvgExpr<C: Column>(PhantomData<C>);
379
380impl<C: Column> AvgExpr<C> {
381 pub const fn new() -> Self {
382 Self(PhantomData)
383 }
384}
385
386impl<C: Column> Default for AvgExpr<C> {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
392#[derive(Clone, Copy)]
393pub struct MinExpr<C: Column>(PhantomData<C>);
394
395impl<C: Column> MinExpr<C> {
396 pub const fn new() -> Self {
397 Self(PhantomData)
398 }
399}
400
401impl<C: Column> Default for MinExpr<C> {
402 fn default() -> Self {
403 Self::new()
404 }
405}
406
407#[derive(Clone, Copy)]
408pub struct MaxExpr<C: Column>(PhantomData<C>);
409
410impl<C: Column> MaxExpr<C> {
411 pub const fn new() -> Self {
412 Self(PhantomData)
413 }
414}
415
416impl<C: Column> Default for MaxExpr<C> {
417 fn default() -> Self {
418 Self::new()
419 }
420}
421
422impl SelectionExpr for CountAllExpr {
423 type Output = i64;
424
425 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
426 entries.push(SelectionEntry::Aggregate(self.selection()));
427 }
428}
429
430impl AggregateSelectionExpr for CountAllExpr {
431 fn selection(&self) -> AggregateSelection {
432 AggregateSelection::new(AggregateFunction::Count, None)
433 }
434}
435
436impl<C> SelectionExpr for CountExpr<C>
437where
438 C: Column,
439{
440 type Output = i64;
441
442 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
443 entries.push(SelectionEntry::Aggregate(self.selection()));
444 }
445}
446
447impl<C> AggregateSelectionExpr for CountExpr<C>
448where
449 C: Column,
450{
451 fn selection(&self) -> AggregateSelection {
452 let column = SelectionColumn::new(C::TABLE, C::NAME);
453 AggregateSelection::with_column(AggregateFunction::Count, column)
454 }
455}
456
457impl<C> SelectionExpr for CountDistinctExpr<C>
458where
459 C: Column,
460{
461 type Output = i64;
462
463 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
464 entries.push(SelectionEntry::Aggregate(self.selection()));
465 }
466}
467
468impl<C> AggregateSelectionExpr for CountDistinctExpr<C>
469where
470 C: Column,
471{
472 fn selection(&self) -> AggregateSelection {
473 let column = SelectionColumn::new(C::TABLE, C::NAME);
474 AggregateSelection::with_column(
475 AggregateFunction::CountDistinct,
476 column,
477 )
478 }
479}
480
481impl<C> SelectionExpr for SumExpr<C>
482where
483 C: Column,
484{
485 type Output = Option<C::Type>;
486
487 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
488 entries.push(SelectionEntry::Aggregate(self.selection()));
489 }
490}
491
492impl<C> AggregateSelectionExpr for SumExpr<C>
493where
494 C: Column,
495{
496 fn selection(&self) -> AggregateSelection {
497 let column = SelectionColumn::new(C::TABLE, C::NAME);
498 AggregateSelection::with_column(AggregateFunction::Sum, column)
499 }
500}
501
502impl<C> SelectionExpr for AvgExpr<C>
503where
504 C: Column,
505{
506 type Output = Option<C::Type>;
507
508 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
509 entries.push(SelectionEntry::Aggregate(self.selection()));
510 }
511}
512
513impl<C> AggregateSelectionExpr for AvgExpr<C>
514where
515 C: Column,
516{
517 fn selection(&self) -> AggregateSelection {
518 let column = SelectionColumn::new(C::TABLE, C::NAME);
519 AggregateSelection::with_column(AggregateFunction::Avg, column)
520 }
521}
522
523impl<C> SelectionExpr for MinExpr<C>
524where
525 C: Column,
526{
527 type Output = Option<C::Type>;
528
529 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
530 entries.push(SelectionEntry::Aggregate(self.selection()));
531 }
532}
533
534impl<C> AggregateSelectionExpr for MinExpr<C>
535where
536 C: Column,
537{
538 fn selection(&self) -> AggregateSelection {
539 let column = SelectionColumn::new(C::TABLE, C::NAME);
540 AggregateSelection::with_column(AggregateFunction::Min, column)
541 }
542}
543
544impl<C> SelectionExpr for MaxExpr<C>
545where
546 C: Column,
547{
548 type Output = Option<C::Type>;
549
550 fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
551 entries.push(SelectionEntry::Aggregate(self.selection()));
552 }
553}
554
555impl<C> AggregateSelectionExpr for MaxExpr<C>
556where
557 C: Column,
558{
559 fn selection(&self) -> AggregateSelection {
560 let column = SelectionColumn::new(C::TABLE, C::NAME);
561 AggregateSelection::with_column(AggregateFunction::Max, column)
562 }
563}
564
565pub trait AggregateSelectionExpr: Copy {
566 fn selection(&self) -> AggregateSelection;
567}
568
569pub trait AggregatePredicateBuilder: AggregateSelectionExpr + Copy {
570 fn compare<T>(self, op: ComparisonOp, value: T) -> HavingPredicate
571 where
572 T: Clone + Send + Sync + 'static,
573 T: Encode<'static, Postgres>,
574 T: Type<Postgres>,
575 {
576 HavingPredicate::new(self.selection(), op, HavingValue::new(value))
577 }
578
579 fn eq<T>(self, value: T) -> HavingPredicate
580 where
581 T: Clone + Send + Sync + 'static,
582 T: Encode<'static, Postgres>,
583 T: Type<Postgres>,
584 {
585 self.compare(ComparisonOp::Eq, value)
586 }
587
588 fn ne<T>(self, value: T) -> HavingPredicate
589 where
590 T: Clone + Send + Sync + 'static,
591 T: Encode<'static, Postgres>,
592 T: Type<Postgres>,
593 {
594 self.compare(ComparisonOp::Ne, value)
595 }
596
597 fn gt<T>(self, value: T) -> HavingPredicate
598 where
599 T: Clone + Send + Sync + 'static,
600 T: Encode<'static, Postgres>,
601 T: Type<Postgres>,
602 {
603 self.compare(ComparisonOp::Gt, value)
604 }
605
606 fn ge<T>(self, value: T) -> HavingPredicate
607 where
608 T: Clone + Send + Sync + 'static,
609 T: Encode<'static, Postgres>,
610 T: Type<Postgres>,
611 {
612 self.compare(ComparisonOp::Ge, value)
613 }
614
615 fn lt<T>(self, value: T) -> HavingPredicate
616 where
617 T: Clone + Send + Sync + 'static,
618 T: Encode<'static, Postgres>,
619 T: Type<Postgres>,
620 {
621 self.compare(ComparisonOp::Lt, value)
622 }
623
624 fn le<T>(self, value: T) -> HavingPredicate
625 where
626 T: Clone + Send + Sync + 'static,
627 T: Encode<'static, Postgres>,
628 T: Type<Postgres>,
629 {
630 self.compare(ComparisonOp::Le, value)
631 }
632}
633
634impl<T> AggregatePredicateBuilder for T where T: AggregateSelectionExpr + Copy {}
635
/// Expands, inside an `impl` block, to inherent `eq`/`ne`/`gt`/`ge`/`lt`/`le`
/// methods that forward to [`AggregatePredicateBuilder`].
///
/// Inherent methods are callable without importing the trait.
macro_rules! aggregate_predicate_methods_body {
    () => {
        aggregate_predicate_methods_body!(@forward eq, ne, gt, ge, lt, le);
    };
    (@forward $($method:ident),+) => {
        $(
            #[doc = concat!(
                "Forwards to `AggregatePredicateBuilder::",
                stringify!($method),
                "`."
            )]
            pub fn $method<T>(self, value: T) -> HavingPredicate
            where
                T: Clone + Send + Sync + 'static,
                T: Encode<'static, Postgres>,
                T: Type<Postgres>,
            {
                AggregatePredicateBuilder::$method(self, value)
            }
        )+
    };
}
693
/// Adds the forwarding comparison methods as inherent methods of a type.
///
/// Accepts either a plain type (`CountAllExpr`) or a single-parameter generic
/// type followed by its `where` bounds (`SumExpr<C> where C: Column`).
macro_rules! impl_aggregate_predicate_methods {
    ($target:ty) => {
        impl $target {
            aggregate_predicate_methods_body!();
        }
    };
    ($target:ident < $param:ident > where $($bound:tt)+) => {
        impl<$param> $target<$param>
        where
            $($bound)+
        {
            aggregate_predicate_methods_body!();
        }
    };
}
709
710impl_aggregate_predicate_methods!(CountAllExpr);
711impl_aggregate_predicate_methods!(CountExpr<C> where C: Column);
712impl_aggregate_predicate_methods!(CountDistinctExpr<C> where C: Column);
713impl_aggregate_predicate_methods!(SumExpr<C> where C: Column);
714impl_aggregate_predicate_methods!(AvgExpr<C> where C: Column);
715impl_aggregate_predicate_methods!(MinExpr<C> where C: Column);
716impl_aggregate_predicate_methods!(MaxExpr<C> where C: Column);
717
/// Zero-sized marker carrying the output type `T` of a recorded selection.
///
/// `Clone` and `Copy` are implemented by hand so the marker is copyable for
/// every `T`. The derived impls would have required `T: Clone` / `T: Copy`,
/// even though no `T` is ever stored.
pub struct SelectionOutput<T>(pub PhantomData<T>);

impl<T> Clone for SelectionOutput<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SelectionOutput<T> {}
720
721impl<T> SelectionOutput<T> {
722 pub fn into_selection_list<Store>(
723 self,
724 entries: SmallVec<[Store; 4]>,
725 ) -> SelectionList<T, Store> {
726 SelectionList::new(entries)
727 }
728}
729
730pub fn record_selection_expr<E>(
731 expr: E,
732 entries: &mut SmallVec<[SelectionEntry; 4]>,
733) -> SelectionOutput<E::Output>
734where
735 E: SelectionExpr,
736{
737 expr.record(entries);
738 SelectionOutput(PhantomData)
739}
740
741pub trait SelectionOutputTuple {
742 type Output;
743
744 fn flatten(self) -> SelectionOutput<Self::Output>;
745}
746
747impl<A> SelectionOutputTuple for (SelectionOutput<A>,) {
748 type Output = (A,);
749
750 fn flatten(self) -> SelectionOutput<(A,)> {
751 let _ = self;
752 SelectionOutput(PhantomData)
753 }
754}
755
/// Implements [`SelectionOutputTuple`] for a tuple of `SelectionOutput<_>`
/// markers, one element per type-parameter name given.
macro_rules! impl_selection_output_tuple {
    ($($T:ident),+) => {
        impl<$($T),+> SelectionOutputTuple for ($(SelectionOutput<$T>,)+) {
            type Output = ($($T,)+);

            fn flatten(self) -> SelectionOutput<Self::Output> {
                // The markers carry no data, so the input is simply dropped.
                SelectionOutput(PhantomData)
            }
        }
    };
}
768
769impl_selection_output_tuple!(A, B);
770impl_selection_output_tuple!(A, B, C);
771impl_selection_output_tuple!(A, B, C, D);
772impl_selection_output_tuple!(A, B, C, D, E);
773impl_selection_output_tuple!(A, B, C, D, E, F);
774impl_selection_output_tuple!(A, B, C, D, E, F, G);
775impl_selection_output_tuple!(A, B, C, D, E, F, G, H);
776impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I);
777impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J);
778impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J, K);
779impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
780
/// Records one or more selection expressions and returns a `SelectionList`
/// typed with the tuple of their outputs.
///
/// Expressions are recorded left to right, so entry order matches argument
/// order. Every path is absolute (`$crate::...`). The expansion therefore
/// needs no imports at the call site and introduces none. The previous
/// version imported `SelectionList` without using it, which raised
/// `unused_imports` wherever the macro was used inside this crate.
#[macro_export]
macro_rules! take {
    ($($sel:expr),+ $(,)?) => {{
        let mut __entries: $crate::smallvec::SmallVec<
            [$crate::select::SelectionEntry; 4]
        > = $crate::smallvec::SmallVec::new();

        // The trailing comma keeps a single expression a 1-tuple.
        let __outputs = (
            $(
                $crate::select::record_selection_expr($sel, &mut __entries),
            )+
        );

        $crate::select::SelectionOutputTuple::flatten(__outputs)
            .into_selection_list(__entries)
    }};
}
806
/// Builds a `GroupByList` from one or more `Column` types, in order.
#[macro_export]
macro_rules! group_by {
    ($($col:ty),+ $(,)?) => {{
        let mut __cols: $crate::smallvec::SmallVec<
            [$crate::select::SelectionColumn; 4]
        > = $crate::smallvec::SmallVec::new();

        $(
            __cols.push($crate::select::SelectionColumn::new(
                <$col as $crate::select::Column>::TABLE,
                <$col as $crate::select::Column>::NAME,
            ));
        )+

        $crate::select::GroupByList::new(__cols)
    }};
}
832
/// Builds a `HavingList` from one or more `HavingPredicate` expressions,
/// evaluated and stored left to right.
#[macro_export]
macro_rules! having {
    ($($pred:expr),+ $(,)?) => {{
        let __preds = ::std::vec![$($pred),+];
        $crate::select::HavingList::new(__preds)
    }};
}