Skip to main content

sqlxo/
select.rs

1use smallvec::SmallVec;
2use sqlx::{
3	Encode,
4	Postgres,
5	QueryBuilder,
6	Type,
7};
8use sqlxo_traits::{
9	QueryModel,
10	SqlWrite,
11};
12use std::{
13	marker::PhantomData,
14	sync::Arc,
15};
16
17use crate::blocks::SqlWriter;
18
19/// Marker trait for model columns that can participate in `take!`.
20pub trait Column: Copy {
21	type Model: QueryModel;
22	type Type;
23	const NAME: &'static str;
24	const TABLE: &'static str;
25}
26
/// A column reference qualified by its table: `"table"."column"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SelectionColumn {
	/// Name of the table that owns the column.
	pub table:  &'static str,
	/// Name of the column itself.
	pub column: &'static str,
}

impl SelectionColumn {
	/// Builds a reference to `column` in `table`.
	pub const fn new(table: &'static str, column: &'static str) -> Self {
		Self { column, table }
	}
}
38
/// SQL aggregate functions that can appear in selections and `HAVING`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
	/// `COUNT`.
	Count,
	/// Distinct count. Its [`sql_name`](Self::sql_name) is plain `COUNT`; the
	/// `DISTINCT` keyword is presumably added by the SQL writer (not shown
	/// here).
	CountDistinct,
	/// `SUM`.
	Sum,
	/// `AVG`.
	Avg,
	/// `MIN`.
	Min,
	/// `MAX`.
	Max,
}

impl AggregateFunction {
	/// Returns the bare SQL function name for this aggregate.
	///
	/// Both `Count` and `CountDistinct` map to `"COUNT"`.
	pub const fn sql_name(&self) -> &'static str {
		match *self {
			Self::Sum => "SUM",
			Self::Avg => "AVG",
			Self::Min => "MIN",
			Self::Max => "MAX",
			Self::Count | Self::CountDistinct => "COUNT",
		}
	}
}
60
61#[derive(Clone, Copy, Debug, Eq, PartialEq)]
62pub struct AggregateSelection {
63	pub function: AggregateFunction,
64	pub column:   Option<SelectionColumn>,
65}
66
67impl AggregateSelection {
68	pub const fn new(
69		function: AggregateFunction,
70		column: Option<SelectionColumn>,
71	) -> Self {
72		Self { function, column }
73	}
74
75	pub const fn with_column(
76		function: AggregateFunction,
77		column: SelectionColumn,
78	) -> Self {
79		Self {
80			function,
81			column: Some(column),
82		}
83	}
84}
85
86#[derive(Clone, Copy, Debug, Eq, PartialEq)]
87pub enum SelectionEntry {
88	Column(SelectionColumn),
89	Aggregate(AggregateSelection),
90}
91
92#[derive(Debug, Clone)]
93pub struct SelectionList<Output, Store = SelectionColumn> {
94	pub(crate) entries: SmallVec<[Store; 4]>,
95	_marker:            PhantomData<Output>,
96}
97
98impl<Output, Store> SelectionList<Output, Store> {
99	pub fn new(entries: SmallVec<[Store; 4]>) -> Self {
100		Self {
101			entries,
102			_marker: PhantomData,
103		}
104	}
105
106	pub fn entries(&self) -> &[Store] {
107		&self.entries
108	}
109
110	pub fn clone_entries(&self) -> SmallVec<[Store; 4]>
111	where
112		Store: Clone,
113	{
114		self.entries.clone()
115	}
116
117	pub fn len(&self) -> usize {
118		self.entries.len()
119	}
120
121	pub fn is_empty(&self) -> bool {
122		self.entries.is_empty()
123	}
124}
125
126impl<Output> SelectionList<Output, SelectionColumn> {
127	pub fn columns(&self) -> &[SelectionColumn] {
128		&self.entries
129	}
130
131	pub fn clone_columns(&self) -> SmallVec<[SelectionColumn; 4]> {
132		self.entries.clone()
133	}
134
135	pub fn push_returning(
136		&self,
137		qb: &mut QueryBuilder<'static, Postgres>,
138		table: &str,
139	) {
140		qb.push(" RETURNING ");
141		for (idx, col) in self.entries.iter().enumerate() {
142			assert_eq!(
143				col.table, table,
144				"`RETURNING` may only use columns from `{}` but got `{}`",
145				table, col.table,
146			);
147
148			if idx > 0 {
149				qb.push(", ");
150			}
151			qb.push(format!(r#""{}"."{}""#, table, col.column));
152		}
153	}
154}
155
156impl<Output> SelectionList<Output, SelectionEntry> {
157	pub fn expect_columns(self) -> SelectionList<Output, SelectionColumn> {
158		let mut cols = SmallVec::<[SelectionColumn; 4]>::new();
159		for entry in self.entries {
160			match entry {
161				SelectionEntry::Column(col) => cols.push(col),
162				SelectionEntry::Aggregate(_) => {
163					panic!("aggregates are not supported in this context")
164				}
165			}
166		}
167		SelectionList::new(cols)
168	}
169}
170
171pub fn push_returning<Output>(
172	qb: &mut QueryBuilder<'static, Postgres>,
173	table: &str,
174	selection: Option<&SelectionList<Output, SelectionColumn>>,
175) {
176	if let Some(sel) = selection {
177		sel.push_returning(qb, table);
178	} else {
179		qb.push(" RETURNING *");
180	}
181}
182
183#[derive(Debug, Clone)]
184pub struct GroupByList {
185	columns: SmallVec<[SelectionColumn; 4]>,
186}
187
188impl GroupByList {
189	pub fn new(columns: SmallVec<[SelectionColumn; 4]>) -> Self {
190		Self { columns }
191	}
192
193	pub fn columns(&self) -> &[SelectionColumn] {
194		&self.columns
195	}
196
197	pub fn into_columns(self) -> SmallVec<[SelectionColumn; 4]> {
198		self.columns
199	}
200}
201
202#[derive(Clone)]
203pub struct HavingValue {
204	binder: Arc<dyn Fn(&mut SqlWriter) + Send + Sync>,
205}
206
207impl HavingValue {
208	pub fn new<T>(value: T) -> Self
209	where
210		T: Clone + Send + Sync + 'static,
211		T: Encode<'static, Postgres>,
212		T: Type<Postgres>,
213	{
214		let value = Arc::new(value);
215		Self {
216			binder: Arc::new(move |writer: &mut SqlWriter| {
217				writer.bind((value.as_ref()).clone());
218			}),
219		}
220	}
221
222	pub fn bind(&self, writer: &mut SqlWriter) {
223		(self.binder)(writer);
224	}
225}
226
/// Comparison operators usable in `HAVING` predicates.
///
/// Derives `Eq`/`PartialEq` like the other enums of this module, so
/// operators can be compared (e.g. in tests or when inspecting predicates).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComparisonOp {
	/// `=`
	Eq,
	/// `!=`
	Ne,
	/// `>`
	Gt,
	/// `>=`
	Ge,
	/// `<`
	Lt,
	/// `<=`
	Le,
}

impl ComparisonOp {
	/// Returns the SQL spelling of this operator.
	pub const fn as_str(&self) -> &'static str {
		match self {
			Self::Eq => "=",
			Self::Ne => "!=",
			Self::Gt => ">",
			Self::Ge => ">=",
			Self::Lt => "<",
			Self::Le => "<=",
		}
	}
}
249
250#[derive(Clone)]
251pub struct HavingPredicate {
252	pub selection:  AggregateSelection,
253	pub comparator: ComparisonOp,
254	value:          HavingValue,
255}
256
257impl HavingPredicate {
258	pub fn new(
259		selection: AggregateSelection,
260		comparator: ComparisonOp,
261		value: HavingValue,
262	) -> Self {
263		Self {
264			selection,
265			comparator,
266			value,
267		}
268	}
269
270	pub fn bind_value(&self, writer: &mut SqlWriter) {
271		self.value.bind(writer);
272	}
273}
274
275#[derive(Clone)]
276pub struct HavingList {
277	predicates: Vec<HavingPredicate>,
278}
279
280impl HavingList {
281	pub fn new(predicates: Vec<HavingPredicate>) -> Self {
282		Self { predicates }
283	}
284
285	pub fn predicates(&self) -> &[HavingPredicate] {
286		&self.predicates
287	}
288
289	pub fn into_predicates(self) -> Vec<HavingPredicate> {
290		self.predicates
291	}
292}
293
294impl From<HavingPredicate> for HavingList {
295	fn from(predicate: HavingPredicate) -> Self {
296		Self::new(vec![predicate])
297	}
298}
299
300pub trait SelectionExpr {
301	type Output;
302	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>);
303}
304
305impl<T> SelectionExpr for T
306where
307	T: Column,
308{
309	type Output = T::Type;
310
311	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
312		let column = SelectionColumn::new(T::TABLE, T::NAME);
313		entries.push(SelectionEntry::Column(column));
314	}
315}
316
317#[derive(Clone, Copy)]
318pub struct CountAllExpr;
319
320impl CountAllExpr {
321	pub const fn new() -> Self {
322		Self
323	}
324}
325
326impl Default for CountAllExpr {
327	fn default() -> Self {
328		Self::new()
329	}
330}
331
332#[derive(Clone, Copy)]
333pub struct CountExpr<C: Column>(PhantomData<C>);
334
335impl<C: Column> CountExpr<C> {
336	pub const fn new() -> Self {
337		Self(PhantomData)
338	}
339}
340
341impl<C: Column> Default for CountExpr<C> {
342	fn default() -> Self {
343		Self::new()
344	}
345}
346
347#[derive(Clone, Copy)]
348pub struct CountDistinctExpr<C: Column>(PhantomData<C>);
349
350impl<C: Column> CountDistinctExpr<C> {
351	pub const fn new() -> Self {
352		Self(PhantomData)
353	}
354}
355
356impl<C: Column> Default for CountDistinctExpr<C> {
357	fn default() -> Self {
358		Self::new()
359	}
360}
361
362#[derive(Clone, Copy)]
363pub struct SumExpr<C: Column>(PhantomData<C>);
364
365impl<C: Column> SumExpr<C> {
366	pub const fn new() -> Self {
367		Self(PhantomData)
368	}
369}
370
371impl<C: Column> Default for SumExpr<C> {
372	fn default() -> Self {
373		Self::new()
374	}
375}
376
377#[derive(Clone, Copy)]
378pub struct AvgExpr<C: Column>(PhantomData<C>);
379
380impl<C: Column> AvgExpr<C> {
381	pub const fn new() -> Self {
382		Self(PhantomData)
383	}
384}
385
386impl<C: Column> Default for AvgExpr<C> {
387	fn default() -> Self {
388		Self::new()
389	}
390}
391
392#[derive(Clone, Copy)]
393pub struct MinExpr<C: Column>(PhantomData<C>);
394
395impl<C: Column> MinExpr<C> {
396	pub const fn new() -> Self {
397		Self(PhantomData)
398	}
399}
400
401impl<C: Column> Default for MinExpr<C> {
402	fn default() -> Self {
403		Self::new()
404	}
405}
406
407#[derive(Clone, Copy)]
408pub struct MaxExpr<C: Column>(PhantomData<C>);
409
410impl<C: Column> MaxExpr<C> {
411	pub const fn new() -> Self {
412		Self(PhantomData)
413	}
414}
415
416impl<C: Column> Default for MaxExpr<C> {
417	fn default() -> Self {
418		Self::new()
419	}
420}
421
422impl SelectionExpr for CountAllExpr {
423	type Output = i64;
424
425	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
426		entries.push(SelectionEntry::Aggregate(self.selection()));
427	}
428}
429
430impl AggregateSelectionExpr for CountAllExpr {
431	fn selection(&self) -> AggregateSelection {
432		AggregateSelection::new(AggregateFunction::Count, None)
433	}
434}
435
436impl<C> SelectionExpr for CountExpr<C>
437where
438	C: Column,
439{
440	type Output = i64;
441
442	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
443		entries.push(SelectionEntry::Aggregate(self.selection()));
444	}
445}
446
447impl<C> AggregateSelectionExpr for CountExpr<C>
448where
449	C: Column,
450{
451	fn selection(&self) -> AggregateSelection {
452		let column = SelectionColumn::new(C::TABLE, C::NAME);
453		AggregateSelection::with_column(AggregateFunction::Count, column)
454	}
455}
456
457impl<C> SelectionExpr for CountDistinctExpr<C>
458where
459	C: Column,
460{
461	type Output = i64;
462
463	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
464		entries.push(SelectionEntry::Aggregate(self.selection()));
465	}
466}
467
468impl<C> AggregateSelectionExpr for CountDistinctExpr<C>
469where
470	C: Column,
471{
472	fn selection(&self) -> AggregateSelection {
473		let column = SelectionColumn::new(C::TABLE, C::NAME);
474		AggregateSelection::with_column(
475			AggregateFunction::CountDistinct,
476			column,
477		)
478	}
479}
480
481impl<C> SelectionExpr for SumExpr<C>
482where
483	C: Column,
484{
485	type Output = Option<C::Type>;
486
487	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
488		entries.push(SelectionEntry::Aggregate(self.selection()));
489	}
490}
491
492impl<C> AggregateSelectionExpr for SumExpr<C>
493where
494	C: Column,
495{
496	fn selection(&self) -> AggregateSelection {
497		let column = SelectionColumn::new(C::TABLE, C::NAME);
498		AggregateSelection::with_column(AggregateFunction::Sum, column)
499	}
500}
501
502impl<C> SelectionExpr for AvgExpr<C>
503where
504	C: Column,
505{
506	type Output = Option<C::Type>;
507
508	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
509		entries.push(SelectionEntry::Aggregate(self.selection()));
510	}
511}
512
513impl<C> AggregateSelectionExpr for AvgExpr<C>
514where
515	C: Column,
516{
517	fn selection(&self) -> AggregateSelection {
518		let column = SelectionColumn::new(C::TABLE, C::NAME);
519		AggregateSelection::with_column(AggregateFunction::Avg, column)
520	}
521}
522
523impl<C> SelectionExpr for MinExpr<C>
524where
525	C: Column,
526{
527	type Output = Option<C::Type>;
528
529	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
530		entries.push(SelectionEntry::Aggregate(self.selection()));
531	}
532}
533
534impl<C> AggregateSelectionExpr for MinExpr<C>
535where
536	C: Column,
537{
538	fn selection(&self) -> AggregateSelection {
539		let column = SelectionColumn::new(C::TABLE, C::NAME);
540		AggregateSelection::with_column(AggregateFunction::Min, column)
541	}
542}
543
544impl<C> SelectionExpr for MaxExpr<C>
545where
546	C: Column,
547{
548	type Output = Option<C::Type>;
549
550	fn record(self, entries: &mut SmallVec<[SelectionEntry; 4]>) {
551		entries.push(SelectionEntry::Aggregate(self.selection()));
552	}
553}
554
555impl<C> AggregateSelectionExpr for MaxExpr<C>
556where
557	C: Column,
558{
559	fn selection(&self) -> AggregateSelection {
560		let column = SelectionColumn::new(C::TABLE, C::NAME);
561		AggregateSelection::with_column(AggregateFunction::Max, column)
562	}
563}
564
565pub trait AggregateSelectionExpr: Copy {
566	fn selection(&self) -> AggregateSelection;
567}
568
569pub trait AggregatePredicateBuilder: AggregateSelectionExpr + Copy {
570	fn compare<T>(self, op: ComparisonOp, value: T) -> HavingPredicate
571	where
572		T: Clone + Send + Sync + 'static,
573		T: Encode<'static, Postgres>,
574		T: Type<Postgres>,
575	{
576		HavingPredicate::new(self.selection(), op, HavingValue::new(value))
577	}
578
579	fn eq<T>(self, value: T) -> HavingPredicate
580	where
581		T: Clone + Send + Sync + 'static,
582		T: Encode<'static, Postgres>,
583		T: Type<Postgres>,
584	{
585		self.compare(ComparisonOp::Eq, value)
586	}
587
588	fn ne<T>(self, value: T) -> HavingPredicate
589	where
590		T: Clone + Send + Sync + 'static,
591		T: Encode<'static, Postgres>,
592		T: Type<Postgres>,
593	{
594		self.compare(ComparisonOp::Ne, value)
595	}
596
597	fn gt<T>(self, value: T) -> HavingPredicate
598	where
599		T: Clone + Send + Sync + 'static,
600		T: Encode<'static, Postgres>,
601		T: Type<Postgres>,
602	{
603		self.compare(ComparisonOp::Gt, value)
604	}
605
606	fn ge<T>(self, value: T) -> HavingPredicate
607	where
608		T: Clone + Send + Sync + 'static,
609		T: Encode<'static, Postgres>,
610		T: Type<Postgres>,
611	{
612		self.compare(ComparisonOp::Ge, value)
613	}
614
615	fn lt<T>(self, value: T) -> HavingPredicate
616	where
617		T: Clone + Send + Sync + 'static,
618		T: Encode<'static, Postgres>,
619		T: Type<Postgres>,
620	{
621		self.compare(ComparisonOp::Lt, value)
622	}
623
624	fn le<T>(self, value: T) -> HavingPredicate
625	where
626		T: Clone + Send + Sync + 'static,
627		T: Encode<'static, Postgres>,
628		T: Type<Postgres>,
629	{
630		self.compare(ComparisonOp::Le, value)
631	}
632}
633
634impl<T> AggregatePredicateBuilder for T where T: AggregateSelectionExpr + Copy {}
635
636macro_rules! aggregate_predicate_methods_body {
637	() => {
638		pub fn eq<T>(self, value: T) -> HavingPredicate
639		where
640			T: Clone + Send + Sync + 'static,
641			T: Encode<'static, Postgres>,
642			T: Type<Postgres>,
643		{
644			AggregatePredicateBuilder::eq(self, value)
645		}
646
647		pub fn ne<T>(self, value: T) -> HavingPredicate
648		where
649			T: Clone + Send + Sync + 'static,
650			T: Encode<'static, Postgres>,
651			T: Type<Postgres>,
652		{
653			AggregatePredicateBuilder::ne(self, value)
654		}
655
656		pub fn gt<T>(self, value: T) -> HavingPredicate
657		where
658			T: Clone + Send + Sync + 'static,
659			T: Encode<'static, Postgres>,
660			T: Type<Postgres>,
661		{
662			AggregatePredicateBuilder::gt(self, value)
663		}
664
665		pub fn ge<T>(self, value: T) -> HavingPredicate
666		where
667			T: Clone + Send + Sync + 'static,
668			T: Encode<'static, Postgres>,
669			T: Type<Postgres>,
670		{
671			AggregatePredicateBuilder::ge(self, value)
672		}
673
674		pub fn lt<T>(self, value: T) -> HavingPredicate
675		where
676			T: Clone + Send + Sync + 'static,
677			T: Encode<'static, Postgres>,
678			T: Type<Postgres>,
679		{
680			AggregatePredicateBuilder::lt(self, value)
681		}
682
683		pub fn le<T>(self, value: T) -> HavingPredicate
684		where
685			T: Clone + Send + Sync + 'static,
686			T: Encode<'static, Postgres>,
687			T: Type<Postgres>,
688		{
689			AggregatePredicateBuilder::le(self, value)
690		}
691	};
692}
693
694macro_rules! impl_aggregate_predicate_methods {
695	($ty:ty) => {
696		impl $ty {
697			aggregate_predicate_methods_body!();
698		}
699	};
700	($ty:ident < $gen:ident > where $($bounds:tt)+) => {
701		impl<$gen> $ty<$gen>
702		where
703			$($bounds)+
704		{
705			aggregate_predicate_methods_body!();
706		}
707	};
708}
709
710impl_aggregate_predicate_methods!(CountAllExpr);
711impl_aggregate_predicate_methods!(CountExpr<C> where C: Column);
712impl_aggregate_predicate_methods!(CountDistinctExpr<C> where C: Column);
713impl_aggregate_predicate_methods!(SumExpr<C> where C: Column);
714impl_aggregate_predicate_methods!(AvgExpr<C> where C: Column);
715impl_aggregate_predicate_methods!(MinExpr<C> where C: Column);
716impl_aggregate_predicate_methods!(MaxExpr<C> where C: Column);
717
718#[derive(Clone, Copy)]
719pub struct SelectionOutput<T>(pub PhantomData<T>);
720
721impl<T> SelectionOutput<T> {
722	pub fn into_selection_list<Store>(
723		self,
724		entries: SmallVec<[Store; 4]>,
725	) -> SelectionList<T, Store> {
726		SelectionList::new(entries)
727	}
728}
729
730pub fn record_selection_expr<E>(
731	expr: E,
732	entries: &mut SmallVec<[SelectionEntry; 4]>,
733) -> SelectionOutput<E::Output>
734where
735	E: SelectionExpr,
736{
737	expr.record(entries);
738	SelectionOutput(PhantomData)
739}
740
741pub trait SelectionOutputTuple {
742	type Output;
743
744	fn flatten(self) -> SelectionOutput<Self::Output>;
745}
746
747impl<A> SelectionOutputTuple for (SelectionOutput<A>,) {
748	type Output = (A,);
749
750	fn flatten(self) -> SelectionOutput<(A,)> {
751		let _ = self;
752		SelectionOutput(PhantomData)
753	}
754}
755
756macro_rules! impl_selection_output_tuple {
757	($($name:ident),+) => {
758		impl<$($name),+> SelectionOutputTuple for ($(SelectionOutput<$name>,)+) {
759			type Output = ($($name,)+);
760
761			fn flatten(self) -> SelectionOutput<($($name,)+)> {
762				let _ = self;
763				SelectionOutput(PhantomData)
764			}
765		}
766	};
767}
768
769impl_selection_output_tuple!(A, B);
770impl_selection_output_tuple!(A, B, C);
771impl_selection_output_tuple!(A, B, C, D);
772impl_selection_output_tuple!(A, B, C, D, E);
773impl_selection_output_tuple!(A, B, C, D, E, F);
774impl_selection_output_tuple!(A, B, C, D, E, F, G);
775impl_selection_output_tuple!(A, B, C, D, E, F, G, H);
776impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I);
777impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J);
778impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J, K);
779impl_selection_output_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
780
/// Builds a typed selection list from one or more selection expressions
/// (columns, or aggregate expressions such as `CountExpr`).
///
/// Expressions are recorded in argument order. The result is a
/// `SelectionList<_, SelectionEntry>` whose `Output` is the tuple of every
/// expression's `SelectionExpr::Output`.
#[macro_export]
macro_rules! take {
	($first:expr $(, $rest:expr)* $(,)?) => {{
		// Import only the names the expansion actually uses, so expanding
		// the macro does not trip the `unused_imports` lint.
		use $crate::select::{
			record_selection_expr as __sqlxo_record_selection_expr,
			SelectionEntry as __SqlxoSelectionEntry,
			SelectionOutputTuple as __SqlxoSelectionOutputTuple,
		};

		let mut __entries: $crate::smallvec::SmallVec<
			[__SqlxoSelectionEntry; 4]
		> = $crate::smallvec::SmallVec::new();

		let __outputs = (
			__sqlxo_record_selection_expr($first, &mut __entries),
			$(
				__sqlxo_record_selection_expr($rest, &mut __entries),
			)*
		);

		let __output_marker = __SqlxoSelectionOutputTuple::flatten(__outputs);
		__output_marker.into_selection_list(__entries)
	}};
}
806
/// Builds a `GroupByList` from one or more column *types* (implementors of
/// `select::Column`), preserving argument order.
#[macro_export]
macro_rules! group_by {
	($($column:ty),+ $(,)?) => {{
		let mut __cols: $crate::smallvec::SmallVec<
			[$crate::select::SelectionColumn; 4]
		> = $crate::smallvec::SmallVec::new();
		$(
			__cols.push($crate::select::SelectionColumn::new(
				<$column as $crate::select::Column>::TABLE,
				<$column as $crate::select::Column>::NAME,
			));
		)+
		$crate::select::GroupByList::new(__cols)
	}};
}
832
/// Builds a `HavingList` from one or more `HavingPredicate` expressions,
/// evaluated and stored in argument order.
#[macro_export]
macro_rules! having {
	($($predicate:expr),+ $(,)?) => {{
		let mut __preds = ::std::vec::Vec::new();
		$(
			__preds.push($predicate);
		)+
		$crate::select::HavingList::new(__preds)
	}};
}