Skip to main content

rust_rel8/
lib.rs

1#![feature(unboxed_closures)]
2#![allow(internal_features)]
3#![cfg_attr(any(docsrs, docsrs_dep), feature(rustdoc_internals))]
4// #![deny(missing_docs)]
5
6//! Welcome to Rust Rel8
7//!
8//! This is a port of Haskell's excellent
9//! [Rel8](https://rel8.readthedocs.io/en/latest/cookbook.html) library which
10//! provides a type safe and expressive interface for constructing SQL queries.
11//!
//! Unlike other ORMs and query builders, this library does not provide a
//! builder pattern on top of the AST for a SQL query; instead it lets you
//! write queries as if the tables themselves were just arrays in Rust.
15
16pub mod is_nullable;
17
18use std::{
19    borrow::Cow,
20    marker::PhantomData,
21    sync::{Arc, atomic::AtomicU32},
22};
23
24use bytemuck::TransparentWrapper as _;
25use sea_query::{ExprTrait, OverStatement};
26
/// A numeric identifier used to make generated table and column aliases unique.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)]
struct Binder(u32);

/// Global counter from which every [`Binder`] is allocated.
static BINDER_COUNT: AtomicU32 = AtomicU32::new(0);

impl Binder {
    /// Take the next unused binder from the global counter.
    fn new() -> Self {
        use std::sync::atomic::Ordering;
        let next = BINDER_COUNT.fetch_add(1, Ordering::SeqCst);
        Binder(next)
    }

    /// Restart binder numbering from zero.
    #[allow(unused)]
    fn reset() {
        use std::sync::atomic::Ordering;
        BINDER_COUNT.store(0, Ordering::SeqCst);
    }
}
41
42#[derive(Clone, Eq, PartialEq)]
43struct TableName {
44    binder: Binder,
45    name: String,
46}
47
48impl std::fmt::Debug for TableName {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "{}", self.name)
51    }
52}
53
54impl<'a> From<TableName> for Cow<'a, str> {
55    fn from(val: TableName) -> Self {
56        format!("t{}", val.binder.0).into()
57    }
58}
59
60impl TableName {
61    fn new(binder: Binder) -> Self {
62        Self {
63            binder,
64            name: format!("t{}", binder.0),
65        }
66    }
67}
68
69impl sea_query::Iden for TableName {
70    fn unquoted(&self) -> &str {
71        &self.name
72    }
73}
74
/// A column alias.
///
/// `name` is the base name the alias was derived from; `rendered` is the
/// alias text that actually appears in SQL.
#[derive(Clone, PartialEq, Eq, Hash)]
struct ColumnName {
    name: String,
    rendered: String,
}

impl std::fmt::Display for ColumnName {
    /// Shows only the rendered alias.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.rendered)
    }
}

impl std::fmt::Debug for ColumnName {
    /// Identical to `Display`: only the rendered alias is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.rendered)
    }
}
92
93impl sea_query::Iden for ColumnName {
94    fn unquoted(&self) -> &str {
95        &self.rendered
96    }
97}
98
99impl ColumnName {
100    fn new(binder: Binder, name: String) -> Self {
101        Self {
102            rendered: format!("{}{}", name, binder.0),
103            name,
104        }
105    }
106}
107
108/// A trait abstracting the mode of a user defined table, which allows us to
109/// talk about the same table in two different modes.
110///
111/// ```rs
112/// impl<'scope, T: Column<'scope>> TableHKT<'scope> for MyTable<'scope, T> {
113///     type InMode<Mode: Column<'scope>> = MyTable<'scope, Mode>;
114///
115///     type Mode = T;
116/// }
117/// ```
118pub trait TableHKT {
119    /// The current mode of this table.
120    type Mode: TableMode;
121
122    /// Replace the mode with another.
123    type InMode<Mode: TableMode>;
124}
125
#[cfg(feature = "sqlx")]
/// Proxy module allowing us to have a supertrait conditional on feature flags.
mod value_sqlx {
    /// Proxy for the sqlx bounds every column value needs: [sqlx::Decode] and
    /// [sqlx::Type], both for Postgres.
    ///
    /// It is implemented for every type satisfying them, and becomes an
    /// empty bound when the `sqlx` feature is disabled.
    pub trait Value: for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres> {}
    impl<T> Value for T where T: for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres> {}
}
133
#[cfg(not(feature = "sqlx"))]
/// Proxy module allowing us to have a supertrait conditional on feature flags.
mod value_sqlx {
    /// Stand-in for the sqlx bounds (`sqlx::Decode` / `sqlx::Type`) when the
    /// `sqlx` feature is disabled. It imposes nothing and is implemented for
    /// every type.
    ///
    /// The sqlx items are named in plain backticks, not intra-doc links,
    /// because the `sqlx` crate may not be available in this configuration.
    pub trait Value {}
    impl<T> Value for T where T: ?Sized {}
}
141
142pub use value_sqlx::Value as SqlxValueIfEnabled;
143
144/// This trait represents values we know to encode and decode from their database type.
145///
146/// Depending on features, this will have supertraits of the encode/decode
147/// traits of the backends.
148pub trait Value: Into<sea_query::Value> + SqlxValueIfEnabled + IsNullable {}
149
150impl<T> Value for T where T: Into<sea_query::Value> + value_sqlx::Value + IsNullable {}
151
152/// This trait allows us to write a mapping function between two column modes.
153///
154/// We use a trait as we need this to work for all types of the type parameter `V`.
155pub trait ModeMapper<'scope, SrcMode: TableMode, DestMode: TableMode> {
156    /// Map from `SrcMode` to `DestMode`, consuming the value.
157    fn map_mode<V>(&mut self, src: SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
158    where
159        V: Value;
160}
161
162/// This trait allows us to write a mapping function between two column modes.
163///
164/// We use a trait as we need this to work for all types of the type parameter `V`.
165pub trait ModeMapperRef<'scope, SrcMode: TableMode, DestMode: TableMode> {
166    /// Map from `SrcMode` to `DestMode`, taking the value as a reference
167    fn map_mode_ref<V>(&mut self, src: &SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
168    where
169        V: Value;
170}
171
172/// This trait allows us to write a mapping function between two column modes.
173///
174/// We use a trait as we need this to work for all types of the type parameter `V`.
175pub trait ModeMapperMut<'scope, SrcMode: TableMode, DestMode: TableMode> {
176    /// Map from `SrcMode` to `DestMode`, taking the value as a mutable reference
177    fn map_mode_mut<V>(&mut self, src: &mut SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
178    where
179        V: Value;
180}
181
182/// This trait allows us to change the mode of a table by mapping all the fields
183/// with [ModeMapper], [ModeMapperRef], or [ModeMapperMut].
184pub trait MapTable<'scope>: TableHKT {
185    /// Map each field of the table
186    ///
187    /// The order and number of fields visited must always remain the same,
188    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
189    fn map_modes<Mapper, DestMode>(self, mapper: &mut Mapper) -> Self::InMode<DestMode>
190    where
191        Mapper: ModeMapper<'scope, Self::Mode, DestMode>,
192        DestMode: TableMode;
193
194    /// Map each field of the table, with a reference
195    ///
196    /// The order and number of fields visited must always remain the same,
197    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
198    fn map_modes_ref<Mapper, DestMode>(&self, mapper: &mut Mapper) -> Self::InMode<DestMode>
199    where
200        Mapper: ModeMapperRef<'scope, Self::Mode, DestMode>,
201        DestMode: TableMode;
202
203    /// Map each field of the table, with a mutable reference
204    ///
205    /// The order and number of fields visited must always remain the same,
206    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
207    fn map_modes_mut<Mapper, DestMode>(&mut self, mapper: &mut Mapper) -> Self::InMode<DestMode>
208    where
209        Mapper: ModeMapperMut<'scope, Self::Mode, DestMode>,
210        DestMode: TableMode;
211}
212
213/// A mapper which reads the column names of a table in [NameMode] and adds them to a select,
214/// then yields an expression referencing the column.
215struct NameToExprMapper {
216    binder: Binder,
217    query: sea_query::SelectStatement,
218}
219
220impl<'scope> ModeMapperRef<'scope, NameMode, ExprMode> for NameToExprMapper {
221    fn map_mode_ref<V>(
222        &mut self,
223        src: &<NameMode as TableMode>::T<'scope, V>,
224    ) -> <ExprMode as TableMode>::T<'scope, V> {
225        let col_name = ColumnName::new(self.binder, src.to_string());
226        self.query
227            .expr_as(sea_query::Expr::column(*src), col_name.clone());
228
229        Expr::new(ExprInner::Column(TableName::new(self.binder), col_name))
230    }
231}
232
233struct LitMapper {}
234
235impl<'scope> ModeMapper<'scope, ValueMode, ExprMode> for LitMapper {
236    fn map_mode<V>(
237        &mut self,
238        src: <ValueMode as TableMode>::T<'scope, V>,
239    ) -> <ExprMode as TableMode>::T<'scope, V>
240    where
241        V: Value,
242    {
243        Expr::new(ExprInner::Raw(sea_query::Expr::value(src.into())))
244    }
245}
246
247struct ExprCollectorMapper {
248    idx: usize,
249    table_binder: Binder,
250    columns: Vec<ColumnName>,
251    values: Vec<sea_query::Value>,
252}
253
254impl<'scope> ModeMapper<'scope, ValueMode, ExprMode> for ExprCollectorMapper {
255    fn map_mode<V>(
256        &mut self,
257        src: <ValueMode as TableMode>::T<'scope, V>,
258    ) -> <ExprMode as TableMode>::T<'scope, V>
259    where
260        V: Value,
261    {
262        let idx = self.idx;
263        self.idx += 1;
264
265        self.values.push(src.into());
266
267        let column_name = ColumnName::new(self.table_binder, format!("values_{idx}_"));
268
269        self.columns.push(column_name.clone());
270
271        Expr::new(ExprInner::Column(
272            TableName::new(self.table_binder),
273            column_name,
274        ))
275    }
276}
277
278struct ExprCollectorRemainingMapper {
279    values: Vec<sea_query::Value>,
280}
281
282impl<'scope> ModeMapper<'scope, ValueMode, EmptyMode> for ExprCollectorRemainingMapper {
283    fn map_mode<V>(
284        &mut self,
285        src: <ValueMode as TableMode>::T<'scope, V>,
286    ) -> <EmptyMode as TableMode>::T<'scope, V>
287    where
288        V: Value,
289    {
290        self.values.push(src.into());
291    }
292}
293
294struct NameCollectorMapper {
295    names: Vec<&'static str>,
296}
297
298impl<'scope> ModeMapperRef<'scope, NameMode, EmptyMode> for NameCollectorMapper {
299    fn map_mode_ref<V>(
300        &mut self,
301        src: &<NameMode as TableMode>::T<'scope, V>,
302    ) -> <EmptyMode as TableMode>::T<'scope, V> {
303        self.names.push(*src);
304    }
305}
306
307/// A mapper which visits each expression of a table in [ExprMode].
308struct VisitTableMapper<'a, F> {
309    f: &'a mut F,
310    mode: VisitTableMode,
311}
312
313impl<'a, 'scope, F> ModeMapperRef<'scope, ExprMode, EmptyMode> for VisitTableMapper<'a, F>
314where
315    F: FnMut(&ErasedExpr),
316{
317    fn map_mode_ref<V>(
318        &mut self,
319        src: &<ExprMode as TableMode>::T<'scope, V>,
320    ) -> <EmptyMode as TableMode>::T<'scope, V>
321    where
322        V: Value,
323    {
324        match self.mode {
325            VisitTableMode::All => (self.f)(src.as_erased()),
326            VisitTableMode::NonNull => {
327                if !<V as IsNullable>::IS_NULLABLE {
328                    (self.f)(src.as_erased());
329                }
330            }
331        }
332    }
333}
334
335/// A mapper which mutably visits each expression of a table in [ExprMode].
336struct VisitTableMapperMut<'a, F> {
337    f: &'a mut F,
338    mode: VisitTableMode,
339}
340
341impl<'a, 'scope, F> ModeMapperMut<'scope, ExprMode, EmptyMode> for VisitTableMapperMut<'a, F>
342where
343    F: FnMut(&mut ErasedExpr),
344{
345    fn map_mode_mut<V>(
346        &mut self,
347        src: &mut <ExprMode as TableMode>::T<'scope, V>,
348    ) -> <EmptyMode as TableMode>::T<'scope, V>
349    where
350        V: Value,
351    {
352        match self.mode {
353            VisitTableMode::All => (self.f)(src.as_erased_mut()),
354            VisitTableMode::NonNull => {
355                if !<V as IsNullable>::IS_NULLABLE {
356                    (self.f)(src.as_erased_mut());
357                }
358            }
359        }
360    }
361}
362
#[cfg(feature = "sqlx")]
/// Mapper that decodes a result row into a table, turning it from [ExprMode]
/// into [ValueMode]. Each field consumes one value from `it`.
struct LoadingMapper<'a, IT> {
    it: &'a mut IT,
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprMode, ValueMode> for LoadingMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::postgres::PgValueRef<'b>>,
{
    /// # Panics
    /// If the row has no value left, or the value fails to decode as `V`.
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprMode as TableMode>::T<'scope, V>,
    ) -> <ValueMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        let raw = self.it.next().unwrap();
        <V as sqlx::Decode<sqlx::Postgres>>::decode(raw).unwrap()
    }
}
384
#[cfg(feature = "sqlx")]
/// Mapper that decodes a result row into a [ValueManyMode] table.
///
/// Every [ExprMode] column is read as one array value and decoded into a `Vec`.
struct LoadingManyMapper<'a, IT> {
    it: &'a mut IT,
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprMode, ValueManyMode>
    for LoadingManyMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::postgres::PgValueRef<'b>>,
{
    /// # Panics
    /// If the row has no value left.
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprMode as TableMode>::T<'scope, V>,
    ) -> <ValueManyMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        let column = self.it.next().unwrap();
        // A column that cannot be decoded as `Vec<V>` is treated as "no
        // values". This mirrors the `TableLoaderManySqlx` impl for `Expr`, so
        // both loading paths agree on the same input.
        // NOTE(review): the case this targets is presumably a NULL array
        // (e.g. an aggregate over zero rows) — confirm against the query builder.
        <Vec<V> as sqlx::Decode<sqlx::Postgres>>::decode(column).unwrap_or_default()
    }
}
407
#[cfg(feature = "sqlx")]
/// Mapper that steps over a table's values in a result row without decoding
/// them. One value is consumed per field.
struct SkippingMapper<'a, IT> {
    it: &'a mut IT,
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprMode, EmptyMode> for SkippingMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::postgres::PgValueRef<'b>>,
{
    /// # Panics
    /// If the row has no value left.
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprMode as TableMode>::T<'scope, V>,
    ) -> <EmptyMode as TableMode>::T<'scope, V> {
        let _ = self.it.next().unwrap();
    }
}
426
#[cfg(feature = "sqlx")]
/// Mapper that checks whether every column of a [ValueManyMode] table is empty.
///
/// Start with `is_empty: true`. It stays `true` only if no visited column has
/// any values left.
struct ManyRemainingMapper {
    is_empty: bool,
}

#[cfg(feature = "sqlx")]
impl<'scope> ModeMapperRef<'scope, ValueManyMode, EmptyMode> for ManyRemainingMapper {
    fn map_mode_ref<V>(
        &mut self,
        src: &<ValueManyMode as TableMode>::T<'scope, V>,
    ) -> <EmptyMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        if !src.is_empty() {
            self.is_empty = false;
        }
    }
}
444
#[cfg(feature = "sqlx")]
/// Mapper that pops one value off the back of every column of a
/// [ValueManyMode] table, producing a single [ValueMode] row.
struct SpreadingManyMapper {}

#[cfg(feature = "sqlx")]
impl<'scope> ModeMapperMut<'scope, ValueManyMode, ValueMode> for SpreadingManyMapper {
    /// Take the last remaining value of this column.
    ///
    /// # Panics
    /// If the column is already empty.
    fn map_mode_mut<V>(
        &mut self,
        src: &mut <ValueManyMode as TableMode>::T<'scope, V>,
    ) -> <ValueMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        src.pop().unwrap()
    }
}
460
461/// Trait for tables containing values which we can turn into literals
462pub trait LitTable: TableHKT<Mode = ValueMode> {
463    /// Lift a [ValueMode] table to an [ExprMode] table
464    fn lit(self) -> Self::InMode<ExprMode>;
465}
466
467impl<'scope, T: TableHKT<Mode = ValueMode> + MapTable<'scope>> LitTable for T {
468    fn lit(self) -> Self::InMode<ExprMode> {
469        self.map_modes(&mut LitMapper {})
470    }
471}
472
/// Describes a database table: its SQL name together with its column names.
pub struct TableSchema<Table> {
    /// Name of the table in the database.
    pub name: &'static str,

    /// The column names, given as the user's table type in [NameMode].
    pub columns: Table,
}
481
/// The modes a table can be in.
///
/// Every mode is an uninhabited enum that is only ever used as a type-level tag.
pub mod table_modes {
    /// Name mode, where all columns are `&'static str`, representing the column names.
    pub enum NameMode {}

    #[derive(Debug, PartialEq)]
    /// Value mode, representing a table row that has been loaded from the query.
    ///
    /// This enum implements [Debug] and [PartialEq] so that your types can
    /// derive them. Without this, the derive would fail spuriously on the mode
    /// type parameter not implementing the trait, even though no value of
    /// that type appears in the data.
    pub enum ValueMode {}

    #[derive(Debug, PartialEq)]
    /// Value-many mode, where every column holds a `Vec` of loaded values
    /// rather than a single one.
    ///
    /// Like [ValueMode], this implements [Debug] and [PartialEq] so that
    /// tables in this mode can derive those traits as well.
    pub enum ValueManyMode {}

    // /// Value mode, but the value might be null.
    // pub enum ValueNullifiedMode {}

    /// Expr mode, the columns are [crate::Expr]s.
    pub enum ExprMode {}

    // /// Expr mode, but the value might be null, the columns are [crate::Expr]s.
    // pub enum ExprNullifiedMode {}

    /// Empty mode, all fields are `()`. This is used when a mapper doesn't
    /// want to produce a value.
    pub enum EmptyMode {}
}
510
511pub use table_modes::*;
512
/// A type-level switch that decides what every field of a table struct holds
/// (column names, loaded values, expressions, ...).
///
/// Use it in a table struct like so:
///
/// ```rust
/// use rust_rel8::*;
///
/// struct MyTable<'scope, Mode: TableMode> {
///   id: Mode::T<'scope, i32>,
///   name: Mode::T<'scope, String>,
///   age: Mode::T<'scope, i32>,
/// }
/// ```
pub trait TableMode {
    /// The field type for a column whose value type is `V`.
    ///
    /// A mode is free to ignore `V`; for example [NameMode] always uses
    /// `&'static str`.
    type T<'scope, V>;
}
530
531impl TableMode for NameMode {
532    /// a string representing the column name.
533    type T<'scope, V> = &'static str;
534}
535
536impl TableMode for ValueMode {
537    type T<'scope, V> = V;
538}
539
540impl TableMode for ValueManyMode {
541    type T<'scope, V> = Vec<V>;
542}
543
544impl TableMode for ExprMode {
545    type T<'scope, V> = Expr<'scope, V>;
546}
547
548impl TableMode for EmptyMode {
549    type T<'scope, V> = ();
550}
551
552#[derive(bytemuck::TransparentWrapper)]
553#[repr(transparent)]
554/// A wrapper which implements [Table] for any type implementing [MapTable] in [ExprMode].
555pub struct TableUsingMapper<T>(pub T);
556
557#[allow(missing_docs)]
558impl<T> TableUsingMapper<T> {
559    pub fn wrap(t: T) -> Self {
560        <Self as bytemuck::TransparentWrapper<T>>::wrap(t)
561    }
562
563    pub fn wrap_ref(t: &T) -> &Self {
564        <Self as bytemuck::TransparentWrapper<T>>::wrap_ref(t)
565    }
566
567    pub fn wrap_mut(t: &mut T) -> &mut Self {
568        <Self as bytemuck::TransparentWrapper<T>>::wrap_mut(t)
569    }
570}
571
572impl<'scope, T> Table<'scope> for TableUsingMapper<T>
573where
574    T: Table<'scope> + MapTable<'scope> + TableHKT<Mode = ExprMode>,
575{
576    type Result = T::InMode<ValueMode>;
577
578    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
579        let mut mapper = VisitTableMapper { f, mode };
580        self.0.map_modes_ref(&mut mapper);
581    }
582
583    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
584        let mut mapper = VisitTableMapperMut { f, mode };
585        self.0.map_modes_mut(&mut mapper);
586    }
587}
588
#[cfg(feature = "sqlx")]
impl<T> TableLoaderSqlx for TableUsingMapper<T>
where
    T: Table<'static> + MapTable<'static> + TableHKT<Mode = ExprMode>,
{
    /// Decode one value per field, in field order.
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Self::Result {
        self.0.map_modes_ref(&mut LoadingMapper { it: values })
    }

    /// Consume one value per field without decoding anything.
    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
        self.0.map_modes_ref(&mut SkippingMapper { it: values });
    }
}
607
#[cfg(feature = "sqlx")]
impl<T> TableLoaderManySqlx for TableUsingMapper<T>
where
    T: Table<'static> + MapTable<'static> + TableHKT<Mode = ExprMode>,
    T::InMode<ValueManyMode>: MapTable<'static> + TableHKT<Mode = ValueManyMode>,
    // Lets a row spread back out of the `ValueManyMode` table be coerced to
    // `Self::Result` (`T::InMode<ValueMode>`). This is presumably needed
    // because the compiler cannot prove the two projections equal on its own.
    <<T as TableHKT>::InMode<table_modes::ValueManyMode> as TableHKT>::InMode<
        table_modes::ValueMode,
    >: type_equalities::IsEqual<Self::Result>,
{
    /// Load every row of a many-valued result.
    ///
    /// Steps:
    /// 1. Decode each column as a `Vec`.
    /// 2. Rebuild rows by popping one value off every column until all
    ///    columns are empty.
    /// 3. Reverse the result. Popping takes values from the back, so rows are
    ///    produced last-first; reversing restores database order.
    ///
    /// Database order is the order `Expr`'s `load_many` returns. This matters
    /// because the tuple impl zips the per-element results, so both must agree
    /// or rows get mispaired.
    fn load_many<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Vec<Self::Result> {
        let mut collected = self.0.map_modes_ref(&mut LoadingManyMapper { it: values });

        let mut results: Vec<Self::Result> = vec![];

        loop {
            let mut remaining = ManyRemainingMapper { is_empty: true };
            collected.map_modes_ref(&mut remaining);
            if remaining.is_empty {
                break;
            }

            let mut spread = SpreadingManyMapper {};
            results.push(type_equalities::coerce(
                collected.map_modes_mut(&mut spread),
                type_equalities::trivial_eq(),
            ));
        }

        results.reverse();
        results
    }
}
642
/// Selects which columns of a table a visit reaches.
#[derive(Debug, Copy, Clone)]
pub enum VisitTableMode {
    /// Every column.
    All,
    /// Only columns whose value type is not nullable.
    NonNull,
}
651
652/// A trait that represents a database result row.
653///
654/// If you implement [Table] on your type, you must also implement [ForLifetimeTable].
655pub trait Table<'scope> {
656    /// The value a row of this table has when loaded from the database.
657    type Result;
658
659    /// Visit each expr in the table.
660    ///
661    /// The order and number of expressions visited must always remain the same,
662    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
663    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode);
664
665    /// Visit each expr in the table, with a mutable reference.
666    ///
667    /// The order and number of expressions visited must always remain the same,
668    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
669    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode);
670}
671
#[cfg(feature = "sqlx")]
/// Trait allowing loading of [Table]s when using sqlx.
pub trait TableLoaderSqlx: Table<'static> {
    /// Load the table from an iterator over a row's column values.
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Self::Result;

    /// Discard this table's columns from the iterator over a row's values
    /// without decoding them.
    ///
    /// This is used when this value was discarded by a [MaybeTable].
    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>);
}
685
#[cfg(feature = "sqlx")]
/// sqlx loading for [Table]s that come back as several rows at once.
pub trait TableLoaderManySqlx: Table<'static> + TableLoaderSqlx {
    /// Load all rows of this table from the row iterator. Each column is read
    /// as an array of values.
    fn load_many<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Vec<Self::Result>;
}
695
696/// attach each of the tables columns to the select, and renames the table to
697/// use the new names.
698///
699/// This embeds expressions, so after this the table contains only Column exprs.
700fn subst_table<'scope, T: Table<'scope>>(
701    table: &mut T,
702    table_name: TableName,
703    dest_select: &mut sea_query::SelectStatement,
704) {
705    table.visit_mut(
706        &mut |ErasedExpr(inner)| {
707            let new_column_name = match inner {
708                ExprInner::Raw(..) => ColumnName::new(Binder::new(), "lit".to_owned()),
709                ExprInner::Column(_table_name, column_name) => {
710                    ColumnName::new(Binder::new(), column_name.name.clone())
711                }
712                ExprInner::BinOp(..) => ColumnName::new(Binder::new(), "expr".to_owned()),
713                ExprInner::NOp(..) => ColumnName::new(Binder::new(), "expr".to_owned()),
714            };
715            let r = inner.render();
716            dest_select.expr_as(r, new_column_name.clone());
717            *inner = ExprInner::Column(table_name.clone(), new_column_name);
718        },
719        VisitTableMode::All,
720    )
721}
722
723// not sure if we should do this by having `q` wrap the query and `subst_table` everything, or this
724fn insert_table_name<'scope, T: Table<'scope>>(table: &mut T, new_table_name: TableName) {
725    table.visit_mut(
726        &mut |ErasedExpr(inner)| match inner {
727            ExprInner::Raw(_) => {}
728            ExprInner::Column(table_name, _column_name) => {
729                *table_name = new_table_name.clone();
730            }
731            ExprInner::BinOp(_, expr_inner, expr_inner1) => {
732                expr_inner.visit_mut(&mut |table_name, _| *table_name = new_table_name.clone());
733                expr_inner1.visit_mut(&mut |table_name, _| *table_name = new_table_name.clone());
734            }
735            ExprInner::NOp(_, expr_inners) => {
736                for expr_inner in expr_inners {
737                    expr_inner.visit_mut(&mut |table_name, _| *table_name = new_table_name.clone());
738                }
739            }
740        },
741        VisitTableMode::All,
742    )
743}
744
745fn collect_exprs<'scope, T: Table<'scope>>(table: &T) -> Vec<ExprInner> {
746    let mut exprs = vec![];
747
748    table.visit(
749        &mut |ErasedExpr(e)| {
750            exprs.push(e.clone());
751        },
752        VisitTableMode::All,
753    );
754
755    exprs
756}
757
/// Token that keeps calls to [ForLifetimeTable::with_lt] inside this library.
///
/// `#[non_exhaustive]` and the private constructor mean that no other crate
/// can create one.
#[non_exhaustive]
pub struct WithLtMarker {}

impl WithLtMarker {
    /// Create a marker (crate-private).
    fn new() -> Self {
        WithLtMarker {}
    }
}
768
769/// A helper trait that allows us to talk about a [Table] with different
770/// lifetimes. Conceptually it is a type level function of `lt -> T where T:
771/// Table<'lt>`.
772///
773/// If you implement [Table] on your type, you must also implement [ForLifetimeTable].
774pub trait ForLifetimeTable {
775    /// Substitute the lifetime of this table with `'lt`.
776    type WithLt<'lt>: ForLifetimeTable + Table<'lt> + Sized;
777
778    /// Coerce the lifetime of this table. This is used internally by the library.
779    fn with_lt<'lt>(self, marker: &mut WithLtMarker) -> Self::WithLt<'lt>;
780}
781
782impl<'scope, T> Table<'scope> for Expr<'scope, T>
783where
784    T: Value,
785{
786    type Result = T;
787
788    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
789        match mode {
790            VisitTableMode::All => f(self.as_erased()),
791            VisitTableMode::NonNull => {
792                if !<T as IsNullable>::IS_NULLABLE {
793                    f(self.as_erased());
794                }
795            }
796        }
797    }
798
799    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
800        match mode {
801            VisitTableMode::All => f(self.as_erased_mut()),
802            VisitTableMode::NonNull => {
803                if !<T as IsNullable>::IS_NULLABLE {
804                    f(self.as_erased_mut());
805                }
806            }
807        }
808    }
809}
810
811impl<'scope, T: Value> ForLifetimeTable for Expr<'scope, T> {
812    type WithLt<'lt> = Expr<'lt, T>;
813
814    fn with_lt<'lt>(self, _marker: &mut WithLtMarker) -> Self::WithLt<'lt> {
815        Expr::new(self.expr)
816    }
817}
818
819impl<'scope, T> ShortenLifetime for Expr<'scope, T> {
820    type Shortened<'small>
821        = Expr<'small, T>
822    where
823        Self: 'small;
824
825    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
826    where
827        Self: 'large,
828    {
829        Expr::new(self.expr)
830    }
831}
832
#[cfg(feature = "sqlx")]
impl<T: Value> TableLoaderSqlx for Expr<'static, T> {
    /// Decode the next value of the row as `T`.
    ///
    /// # Panics
    /// If the row is exhausted or the value does not decode as `T`.
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Self::Result {
        let raw = values.next().unwrap();
        T::decode(raw).unwrap()
    }

    /// Step over this expression's single value.
    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
        values.next().unwrap();
    }
}
846
#[cfg(feature = "sqlx")]
impl<T: Value> TableLoaderManySqlx for Expr<'static, T> {
    /// Decode the next value of the row as an array of `T`.
    ///
    /// A value that fails to decode yields an empty `Vec`. NOTE(review): this
    /// is presumably aimed at NULL arrays.
    fn load_many<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
    ) -> Vec<Self::Result> {
        let column = values.next().unwrap();
        <Vec<T> as sqlx::Decode<'_, sqlx::Postgres>>::decode(column).unwrap_or_default()
    }
}
857
/// Zip any number of iterables, yielding flat tuples `(a, b, c, ...)` instead
/// of nested pairs `((a, b), c)`. It uses the same technique as itertools'
/// `izip!`.
///
/// Relies on macro hygiene: each recursive `@closure` step introduces its own
/// `b` binding, so the nested pattern `((a, b), b)` does not clash.
macro_rules! izip_priv {
    // Closure builder, base case: every iterator has been accounted for, so
    // emit the flattening closure.
    ( @closure $p:pat => $tup:expr ) => {
        |$p| $tup
    };

    // Closure builder, recursive case: nest the pattern one level deeper and
    // append the new binding to the output tuple. The matched iterator
    // expression is only used for counting; it is never emitted here.
    ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
        izip_priv!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
    };

    // One iterable: wrap each item in a 1-tuple.
    ($first:expr $(,)*) => {
        $first.into_iter().map(|a| (a,))
    };

    // Two or more iterables: evaluate each expression once, in order, while
    // chaining `zip`s, then flatten the nested pairs with the generated closure.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        {
            let iter = $first.into_iter();
            $(
                let iter = iter.zip($rest);
            )*
            iter.map(izip_priv!(@closure a => (a) $( , $rest )*))
        }
    };
}
881
/// Implements [Table], [ForLifetimeTable] and [ShortenLifetime] for one tuple
/// arity. With the `sqlx` feature it also implements the sqlx loader traits.
///
/// Input:
/// - the tuple elements, as a list of `(index, TypeParam)` pairs;
/// - optionally, leading attributes, which are copied onto every generated impl.
macro_rules! impl_tuples {
    ($(#[$meta:meta])* $(($idx:tt, $T:ident)),*) => {
        $(#[$meta])*
        impl<'scope, $($T,)*> Table<'scope> for ($($T,)*)
        where
            $($T: Table<'scope>,)*
        {
            type Result = ($($T::Result,)*);

            fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
                // Elements are visited left to right.
                $( self.$idx.visit(f, mode); )*
            }

            fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
                $( self.$idx.visit_mut(f, mode); )*
            }
        }

        $(#[$meta])*
        impl<$($T,)*> ForLifetimeTable for ($($T,)*)
        where
            $($T: ForLifetimeTable,)*
        {
            type WithLt<'lt> = ($($T::WithLt<'lt>,)*);

            fn with_lt<'lt>(self, marker: &mut WithLtMarker) -> Self::WithLt<'lt> {
                ($(self.$idx.with_lt(marker),)*)
            }
        }

        $(#[$meta])*
        impl<$($T,)*> ShortenLifetime for ($($T,)*)
        where
            $($T: ShortenLifetime,)*
        {
            type Shortened<'small> = ($($T::Shortened<'small>,)*)
            where
                Self: 'small;

            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
            where
                Self: 'large,
            {
                ($(self.$idx.shorten_lifetime(),)*)
            }
        }

        #[cfg(feature = "sqlx")]
        $(#[$meta])*
        impl<$($T,)*> TableLoaderSqlx for ($($T,)*)
        where
            $($T: TableLoaderSqlx,)*
        {
            fn load<'a>(
                &self,
                values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
            ) -> Self::Result {
                ($(self.$idx.load(values),)*)
            }

            fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
                $( self.$idx.skip(values); )*
            }
        }

        #[cfg(feature = "sqlx")]
        $(#[$meta])*
        impl<$($T,)*> TableLoaderManySqlx for ($($T,)*)
        where
            $($T: TableLoaderManySqlx,)*
        {
            fn load_many<'a>(
                &self,
                values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
            ) -> Vec<Self::Result> {
                // Each element loads its own Vec of rows; zip them back into tuples.
                izip_priv!($(self.$idx.load_many(values),)*).collect()
            }
        }
    };
}
962
// Generate the tuple impls from `impl_tuples` for every arity from 1 to 15,
// with type parameters named from `T`. The `#[doc(fake_variadic)]` attribute
// is passed through as `$meta`; that attribute needs the `rustdoc_internals`
// feature enabled for docs builds at the top of this file. See the
// variadics_please docs for exactly which generated impls receive it.
variadics_please::all_tuples_enumerated!(
    #[doc(fake_variadic)]
    impl_tuples,
    1,
    15,
    T
);
970
/// A table which contains a tag indicating whether a query returned rows.
/// Use [Query::optional] to construct this.
pub struct MaybeTable<'scope, T> {
    // `TRUE` when the optional side produced a row and SQL `NULL` otherwise:
    // `Query::optional` selects a constant `true AS tag` inside the subquery
    // it left-joins.
    tag: Expr<'scope, Option<bool>>,
    // The wrapped table; its values are only meaningful when `tag` is `TRUE`.
    inner: T,
}
977
978impl<'scope, T: Table<'scope>> MaybeTable<'scope, T> {
979    /// Project out an expression from the maybe table
980    pub fn project<U, F>(&self, f: F) -> Expr<'scope, Option<U>>
981    where
982        F: FnOnce(&T) -> Expr<U>,
983    {
984        let tag = self.tag.clone();
985        let e = f(&self.inner);
986
987        let cb = Arc::new(|cond: sea_query::SimpleExpr, expr: sea_query::SimpleExpr| {
988            sea_query::CaseStatement::new()
989                .case(cond.is_not_null(), expr)
990                .finally(sea_query::Expr::null())
991                .into()
992        });
993
994        Expr::new(ExprInner::BinOp(cb, Box::new(tag.expr), Box::new(e.expr)))
995    }
996
997    /// Test if the table is None, if it is, use the `fallback`, otherwise call
998    /// `f`.
999    pub fn maybe<U: Table<'scope>, F>(self, fallback: U, f: F) -> U
1000    where
1001        F: FnOnce(T) -> U,
1002    {
1003        let tag = self.tag.clone();
1004        let mut e = f(self.inner);
1005
1006        let mut defaults = collect_exprs(&fallback).into_iter();
1007
1008        e.visit_mut(
1009            &mut |ErasedExpr(inner)| {
1010                let tag = tag.expr.clone();
1011                let default = defaults.next().unwrap();
1012                let cb = Arc::new(|exprs: Vec<sea_query::Expr>| {
1013                    let [tag, expr, default] = exprs.try_into().unwrap();
1014                    sea_query::CaseStatement::new()
1015                        .case(tag.is_not_null(), expr)
1016                        .finally(default)
1017                        .into()
1018                });
1019
1020                *inner = ExprInner::NOp(cb, vec![tag, inner.clone(), default]);
1021            },
1022            VisitTableMode::All,
1023        );
1024
1025        e
1026    }
1027}
1028
1029impl<'scope, T: Table<'scope>> Table<'scope> for MaybeTable<'scope, T> {
1030    type Result = Option<T::Result>;
1031
1032    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
1033        self.tag.visit(f, mode);
1034        self.inner.visit(f, mode);
1035    }
1036
1037    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
1038        self.tag.visit_mut(f, mode);
1039        self.inner.visit_mut(f, mode);
1040    }
1041}
1042
1043impl<'scope, T: ForLifetimeTable + Table<'scope>> ForLifetimeTable for MaybeTable<'scope, T>
1044where
1045    for<'lt> T::WithLt<'lt>: Table<'lt>,
1046{
1047    type WithLt<'lt> = MaybeTable<'lt, T::WithLt<'lt>>;
1048
1049    fn with_lt<'lt>(self, marker: &mut WithLtMarker) -> Self::WithLt<'lt> {
1050        MaybeTable {
1051            tag: self.tag.with_lt(marker),
1052            inner: self.inner.with_lt(marker),
1053        }
1054    }
1055}
1056
1057impl<'scope, T: ShortenLifetime> ShortenLifetime for MaybeTable<'scope, T> {
1058    type Shortened<'small>
1059        = MaybeTable<'small, T::Shortened<'small>>
1060    where
1061        Self: 'small;
1062
1063    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
1064    where
1065        Self: 'large,
1066    {
1067        MaybeTable {
1068            tag: self.tag.shorten_lifetime(),
1069            inner: self.inner.shorten_lifetime(),
1070        }
1071    }
1072}
1073
1074#[cfg(feature = "sqlx")]
1075impl<T: TableLoaderSqlx> TableLoaderSqlx for MaybeTable<'static, T> {
1076    fn load<'a>(
1077        &self,
1078        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
1079    ) -> Self::Result {
1080        let tag =
1081            <Option<bool> as sqlx::Decode<sqlx::Postgres>>::decode(values.next().unwrap()).unwrap();
1082
1083        if tag == Some(true) {
1084            Some(self.inner.load(values))
1085        } else {
1086            self.inner.skip(values);
1087            None
1088        }
1089    }
1090
1091    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
1092        // the tag
1093        values.next().unwrap();
1094        self.inner.skip(values);
1095    }
1096}
1097
1098/// A table wrapping another. This table's rows are assumed null if any of the
1099/// non-nullable columns in `T` are null.
1100///
1101/// Use [Query::nullable] to construct this.
1102pub struct NullTable<'scope, T> {
1103    // nulltable still has a tag, but it is derived from the non-null columns of
1104    // inner
1105    tag: Expr<'scope, Option<bool>>,
1106    inner: T,
1107}
1108
1109impl<'scope, T: Table<'scope>> NullTable<'scope, T> {
1110    /// Project out an expression from the null table
1111    pub fn project<U, F>(&self, f: F) -> Expr<'scope, Option<U>>
1112    where
1113        F: FnOnce(&T) -> Expr<U>,
1114    {
1115        let e = f(&self.inner);
1116
1117        Expr::new(e.expr)
1118    }
1119
1120    /// Test if the table is None, if it is, use the `fallback`, otherwise call
1121    /// `f`.
1122    pub fn maybe<U: Table<'scope>, F>(self, fallback: U, f: F) -> U
1123    where
1124        F: FnOnce(T) -> U,
1125    {
1126        let tag = self.tag.clone();
1127        let mut e = f(self.inner);
1128
1129        let mut defaults = collect_exprs(&fallback).into_iter();
1130
1131        e.visit_mut(
1132            &mut |ErasedExpr(inner)| {
1133                let tag = tag.expr.clone();
1134                let default = defaults.next().unwrap();
1135                let cb = Arc::new(|exprs: Vec<sea_query::Expr>| {
1136                    let [tag, expr, default] = exprs.try_into().unwrap();
1137                    sea_query::CaseStatement::new()
1138                        .case(tag.is_not_null(), expr)
1139                        .finally(default)
1140                        .into()
1141                });
1142
1143                *inner = ExprInner::NOp(cb, vec![tag, inner.clone(), default]);
1144            },
1145            VisitTableMode::All,
1146        );
1147
1148        e
1149    }
1150}
1151
1152impl<'scope, T: Table<'scope>> Table<'scope> for NullTable<'scope, T> {
1153    type Result = Option<T::Result>;
1154
1155    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
1156        self.tag.visit(f, mode);
1157        self.inner.visit(f, mode);
1158    }
1159
1160    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
1161        self.tag.visit_mut(f, mode);
1162        self.inner.visit_mut(f, mode);
1163    }
1164}
1165
1166impl<'scope, T: ForLifetimeTable + Table<'scope>> ForLifetimeTable for NullTable<'scope, T>
1167where
1168    for<'lt> T::WithLt<'lt>: Table<'lt>,
1169{
1170    type WithLt<'lt> = NullTable<'lt, T::WithLt<'lt>>;
1171
1172    fn with_lt<'lt>(self, marker: &mut WithLtMarker) -> Self::WithLt<'lt> {
1173        NullTable {
1174            tag: self.tag.with_lt(marker),
1175            inner: self.inner.with_lt(marker),
1176        }
1177    }
1178}
1179
1180impl<'scope, T: ShortenLifetime> ShortenLifetime for NullTable<'scope, T> {
1181    type Shortened<'small>
1182        = NullTable<'small, T::Shortened<'small>>
1183    where
1184        Self: 'small;
1185
1186    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
1187    where
1188        Self: 'large,
1189    {
1190        NullTable {
1191            tag: self.tag.shorten_lifetime(),
1192            inner: self.inner.shorten_lifetime(),
1193        }
1194    }
1195}
1196
1197#[cfg(feature = "sqlx")]
1198impl<T: TableLoaderSqlx> TableLoaderSqlx for NullTable<'static, T> {
1199    fn load<'a>(
1200        &self,
1201        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
1202    ) -> Self::Result {
1203        let tag =
1204            <Option<bool> as sqlx::Decode<sqlx::Postgres>>::decode(values.next().unwrap()).unwrap();
1205
1206        if tag == Some(true) {
1207            Some(self.inner.load(values))
1208        } else {
1209            self.inner.skip(values);
1210            None
1211        }
1212    }
1213
1214    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
1215        // the tag
1216        values.next().unwrap();
1217        self.inner.skip(values);
1218    }
1219}
1220
1221/// A table representing an array aggregation.
1222///
1223/// Use [Query::many] or [Query::aggregate] to construct this.
1224pub struct ListTable<T> {
1225    inner: T,
1226}
1227
1228impl<'scope, T: Table<'scope>> Table<'scope> for ListTable<T> {
1229    type Result = Vec<T::Result>;
1230
1231    fn visit(&self, f: &mut impl FnMut(&ErasedExpr), mode: VisitTableMode) {
1232        self.inner.visit(f, mode);
1233    }
1234
1235    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr), mode: VisitTableMode) {
1236        self.inner.visit_mut(f, mode);
1237    }
1238}
1239
1240impl<'scope, T: ForLifetimeTable + Table<'scope>> ForLifetimeTable for ListTable<T>
1241where
1242    for<'lt> T::WithLt<'lt>: Table<'lt>,
1243{
1244    type WithLt<'lt> = ListTable<T::WithLt<'lt>>;
1245
1246    fn with_lt<'lt>(self, marker: &mut WithLtMarker) -> Self::WithLt<'lt> {
1247        ListTable {
1248            inner: self.inner.with_lt(marker),
1249        }
1250    }
1251}
1252
1253impl<T: ShortenLifetime> ShortenLifetime for ListTable<T> {
1254    type Shortened<'small>
1255        = ListTable<T::Shortened<'small>>
1256    where
1257        Self: 'small;
1258
1259    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
1260    where
1261        Self: 'large,
1262    {
1263        ListTable {
1264            inner: self.inner.shorten_lifetime(),
1265        }
1266    }
1267}
1268
1269#[cfg(feature = "sqlx")]
1270impl<T: TableLoaderManySqlx> TableLoaderSqlx for ListTable<T> {
1271    fn load<'a>(
1272        &self,
1273        values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>,
1274    ) -> Self::Result {
1275        self.inner.load_many(values)
1276    }
1277
1278    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::postgres::PgValueRef<'a>>) {
1279        self.inner.skip(values);
1280    }
1281}
1282
1283/// A value representing a sql select statement which produces rows of type `T`.
1284#[derive(Clone)]
1285pub struct Query<T> {
1286    // Unique ID used to make the table name and columns unique
1287    binder: Binder,
1288    expr: sea_query::SelectStatement,
1289    inner: T,
1290    siblings_need_random: bool,
1291}
1292
1293impl<'scope, T> Query<T>
1294where
1295    T: Table<'scope> + ForLifetimeTable,
1296{
1297    fn new(binder: Binder, expr: sea_query::SelectStatement, inner: T) -> Self {
1298        Self {
1299            binder,
1300            expr,
1301            inner,
1302            siblings_need_random: false,
1303        }
1304    }
1305
1306    fn into_volatile(mut self) -> Self {
1307        self.siblings_need_random = true;
1308        self
1309    }
1310
1311    fn with_volatile(mut self, volatility: bool) -> Self {
1312        self.siblings_need_random = volatility;
1313        self
1314    }
1315
1316    fn erased(self) -> (ErasedQuery, T) {
1317        let r = ErasedQuery {
1318            expr: self.expr,
1319            siblings_need_random: self.siblings_need_random,
1320        };
1321
1322        (r, self.inner)
1323    }
1324
1325    /// Enter a context in which [A] can be used to build aggregations.
1326    ///
1327    /// The callback `f` will receive an [A], the methods of which can be used
1328    /// to build a table in the `'outer` scope.
1329    pub fn aggregate<U>(
1330        self,
1331        f: impl for<'inner, 'outer> FnOnce(
1332            &mut A<'inner, 'outer>,
1333            T::WithLt<'inner>,
1334        ) -> U::WithLt<'outer>,
1335    ) -> Query<U::WithLt<'scope>>
1336    where
1337        U: ForLifetimeTable,
1338    {
1339        let mut a = A {
1340            group_by: vec![],
1341            _phantom: PhantomData,
1342        };
1343
1344        let mut r = f(&mut a, self.inner.with_lt(&mut WithLtMarker::new()));
1345
1346        let binder = Binder::new();
1347        let mut expr = sea_query::SelectStatement::new();
1348
1349        expr.from_subquery(self.expr, TableName::new(self.binder));
1350
1351        subst_table(&mut r, TableName::new(binder), &mut expr);
1352
1353        expr.add_group_by(a.group_by.iter().map(ExprInner::render));
1354
1355        Query::new(binder, expr, r)
1356    }
1357
1358    /// Construct a simple list aggregation for this query.
1359    ///
1360    /// That is, all rows of this query will be aggregated into an array.
1361    pub fn many(mut self) -> Query<ListTable<T>> {
1362        let binder = Binder::new();
1363        let mut expr = sea_query::SelectStatement::new();
1364
1365        expr.from_subquery(self.expr, TableName::new(self.binder));
1366
1367        self.inner.visit_mut(
1368            &mut |ErasedExpr(inner)| {
1369                *inner = ExprInner::NOp(
1370                    Arc::new(|inners| {
1371                        let [inner] = inners.try_into().unwrap();
1372
1373                        sea_query::PgFunc::array_agg(inner).into()
1374                    }),
1375                    vec![inner.clone()],
1376                )
1377            },
1378            VisitTableMode::All,
1379        );
1380
1381        subst_table(&mut self.inner, TableName::new(binder), &mut expr);
1382
1383        Query::new(binder, expr, ListTable { inner: self.inner })
1384            .with_volatile(self.siblings_need_random)
1385    }
1386
1387    /// Enter a context in which [W] can be used to insert expressions using window functions.
1388    ///
1389    /// The callback `f` will receive an [F], the methods of which can be used
1390    /// to build a table in the `'outer` scope.
1391    pub fn window<U>(
1392        self,
1393        f: impl for<'inner, 'outer> FnOnce(
1394            &mut W<'inner, 'outer>,
1395            T::WithLt<'inner>,
1396        ) -> U::WithLt<'outer>,
1397    ) -> Query<U::WithLt<'scope>>
1398    where
1399        U: ForLifetimeTable + Table<'scope>,
1400    {
1401        let inner_binder = Binder::new();
1402        let middle_binder = Binder::new();
1403        let outer_binder = Binder::new();
1404
1405        let mut w = W {
1406            inner_query: sea_query::SelectStatement::new(),
1407            inner_table: TableName::new(inner_binder),
1408            middle_query: sea_query::SelectStatement::new(),
1409            middle_table: TableName::new(middle_binder),
1410            _phantom: PhantomData,
1411        };
1412
1413        let mut r = f(&mut w, self.inner.with_lt(&mut WithLtMarker::new()));
1414
1415        w.inner_query
1416            .from_subquery(self.expr, TableName::new(self.binder));
1417        w.middle_query.from_subquery(w.inner_query, w.inner_table);
1418
1419        let mut outer_query = sea_query::SelectStatement::new();
1420        outer_query.from_subquery(w.middle_query, w.middle_table.clone());
1421
1422        subst_table(&mut r, w.middle_table, &mut outer_query);
1423
1424        Query::new(outer_binder, outer_query, r)
1425    }
1426
1427    /// Lift a table into a query which ensures side effects happen and are not shared.
1428    pub fn evaluate(mut table: T) -> Self {
1429        let binder = Binder::new();
1430        let mut expr = sea_query::SelectStatement::new();
1431        subst_table(&mut table, TableName::new(binder), &mut expr);
1432        Self::new(binder, expr, table).into_volatile()
1433    }
1434
1435    /// Transform this query into one which produces rows of either [`Some<T>`] or [None].
1436    ///
1437    /// That is, this turns the query into a left join.
1438    pub fn optional(mut self) -> Query<MaybeTable<'scope, T>> {
1439        let binder = Binder::new();
1440
1441        let mut filler = sea_query::Query::select()
1442            .from_values(
1443                vec![sea_query::ValueTuple::One(true.into())],
1444                TableName::new(binder),
1445            )
1446            .to_owned();
1447
1448        let tag = ColumnName::new(binder, "tag".to_owned());
1449
1450        let mut expr = self.expr;
1451        expr.expr_as(sea_query::Value::Bool(Some(true)), tag.clone());
1452
1453        let table_name = TableName::new(self.binder);
1454
1455        filler.join_subquery(
1456            sea_query::JoinType::LeftJoin,
1457            expr,
1458            table_name.clone(),
1459            sea_query::Condition::all(),
1460        );
1461
1462        // important: the tag must be the first column
1463        filler.expr_as(
1464            sea_query::Expr::column((table_name.clone(), tag.clone())),
1465            tag.clone(),
1466        );
1467
1468        // the rest can come later
1469        subst_table(&mut self.inner, table_name, &mut filler);
1470
1471        let maybe_table = MaybeTable {
1472            inner: self.inner,
1473            tag: Expr {
1474                expr: ExprInner::Column(TableName::new(binder), tag),
1475                _phantom: PhantomData,
1476            },
1477        };
1478
1479        Query {
1480            binder,
1481            expr: filler,
1482            inner: maybe_table,
1483            siblings_need_random: self.siblings_need_random,
1484        }
1485    }
1486
1487    /// Make this table nullable. Unlike [MaybeTable], this isn't a left join,
1488    /// but instead considers the table null if any of the non-nullable columns
1489    /// are null.
1490    ///
1491    /// If the table only contains `Option<T>` then this table cannot distinguish.
1492    pub fn nullable(mut self) -> Query<NullTable<'scope, T>> {
1493        let binder = Binder::new();
1494
1495        let tag = ColumnName::new(binder, "tag".to_owned());
1496
1497        let mut non_null_exprs = vec![];
1498        self.inner.visit(
1499            &mut |ErasedExpr(e)| {
1500                non_null_exprs.push(e.render().is_not_null());
1501            },
1502            VisitTableMode::NonNull,
1503        );
1504
1505        let non_null_expr = non_null_exprs
1506            .into_iter()
1507            .fold(sea_query::Condition::all(), |a, b| a.add(b));
1508
1509        let mut expr = self.expr;
1510        expr.expr_as(non_null_expr, tag.clone());
1511
1512        let table_name = TableName::new(self.binder);
1513
1514        let mut query = sea_query::Query::select();
1515        query.from_subquery(expr, table_name.clone());
1516
1517        // important: the tag must be the first column
1518        query.expr_as(
1519            sea_query::Expr::column((table_name.clone(), tag.clone())),
1520            tag.clone(),
1521        );
1522
1523        // the rest can come later
1524        subst_table(&mut self.inner, table_name, &mut query);
1525
1526        let null_table = NullTable {
1527            inner: self.inner,
1528            tag: Expr::new(ExprInner::Column(TableName::new(binder), tag)),
1529        };
1530
1531        Query {
1532            binder,
1533            expr: query,
1534            inner: null_table,
1535            siblings_need_random: self.siblings_need_random,
1536        }
1537    }
1538
1539    /// Add an order by clause to the query, the given function should return a
1540    /// table, the query will be ordered by each column of the table.
1541    pub fn order_by<U, F>(self, f: F) -> Query<T>
1542    where
1543        U: Table<'scope>,
1544        F: FnOnce(&T) -> (U, sea_query::Order),
1545    {
1546        let binder = Binder::new();
1547
1548        let mut outer = sea_query::Query::select();
1549        outer.from_subquery(self.expr, TableName::new(self.binder));
1550
1551        let (order_expr, order) = f(&self.inner);
1552
1553        order_expr.visit(
1554            &mut |ErasedExpr(e)| {
1555                outer.order_by_expr(e.render(), order.clone());
1556            },
1557            VisitTableMode::All,
1558        );
1559
1560        let mut e = self.inner;
1561        subst_table(&mut e, TableName::new(binder), &mut outer);
1562
1563        Query {
1564            binder,
1565            expr: outer,
1566            inner: e,
1567            siblings_need_random: self.siblings_need_random,
1568        }
1569    }
1570}
1571
1572impl<'scope, T> Query<T>
1573where
1574    T: MapTable<'scope> + TableHKT<Mode = NameMode>,
1575    T::InMode<ExprMode>: ForLifetimeTable + Table<'scope>,
1576{
1577    /// Given a [TableSchema], build a query that selects all columns of every row.
1578    pub fn each(schema: &TableSchema<T>) -> Query<T::InMode<ExprMode>> {
1579        let binder = Binder::new();
1580        let mut query = sea_query::Query::select();
1581        query.from(schema.name);
1582
1583        let mut mapper = NameToExprMapper { binder, query };
1584
1585        let expr = schema.columns.map_modes_ref(&mut mapper);
1586
1587        Query::new(binder, mapper.query, expr)
1588    }
1589}
1590
1591impl<'scope, T: TableHKT<Mode = ValueMode>> Query<T> {
1592    /// Construct a query yielding the given rows
1593    pub fn values(vals: impl IntoIterator<Item = T>) -> Query<T::InMode<ExprMode>>
1594    where
1595        T: MapTable<'scope>,
1596        T::InMode<ExprMode>: ForLifetimeTable + Table<'scope>,
1597    {
1598        let binder = Binder::new();
1599        let mut iter = vals.into_iter();
1600
1601        let Some(first) = iter.next() else {
1602            panic!("Don't do that");
1603        };
1604
1605        let mut mapper = ExprCollectorMapper {
1606            idx: 0,
1607            table_binder: binder,
1608            columns: Vec::new(),
1609            values: Vec::new(),
1610        };
1611
1612        let result_expr = first.map_modes(&mut mapper);
1613
1614        let mut all_values = vec![mapper.values];
1615
1616        for v in iter {
1617            let mut mapper = ExprCollectorRemainingMapper { values: Vec::new() };
1618            v.map_modes(&mut mapper);
1619            all_values.push(mapper.values);
1620        }
1621
1622        let mut select = sea_query::Query::select();
1623        for (idx, col) in mapper.columns.into_iter().enumerate() {
1624            select.expr_as(sea_query::Expr::column(format!("column{}", idx + 1)), col);
1625        }
1626
1627        select.from_values(
1628            all_values
1629                .into_iter()
1630                .map(|v| sea_query::ValueTuple::Many(v)),
1631            TableName::new(binder),
1632        );
1633
1634        Query::new(binder, select, result_expr)
1635    }
1636}
1637
1638#[cfg(feature = "sqlx")]
1639impl<T: TableLoaderSqlx> Query<T> {
1640    /// Load all rows of the query
1641    pub async fn all(&self, pool: &mut sqlx::PgConnection) -> sqlx::Result<Vec<T::Result>> {
1642        use sea_query::PostgresQueryBuilder;
1643        use sea_query_sqlx::SqlxBinder as _;
1644        use sqlx::Row as _;
1645
1646        let (sql, values) = self.expr.build_sqlx(PostgresQueryBuilder);
1647
1648        let all = sqlx::query_with(&sql, values).fetch_all(pool).await?;
1649
1650        Ok(all
1651            .into_iter()
1652            .map(|row| {
1653                let len = row.len();
1654                let mut it = (0..len).map(|x| row.try_get_raw(x).unwrap());
1655                self.inner.load(&mut it)
1656            })
1657            .collect::<Vec<_>>())
1658    }
1659}
1660
1661#[derive(Debug)]
1662struct ErasedQuery {
1663    expr: sea_query::SelectStatement,
1664    siblings_need_random: bool,
1665}
1666
1667/// A publicly exposed opaque type that is used by [Table::visit] and
1668/// [Table::visit_mut]. Its purpose is to allow you to store [Expr]s in your
1669/// types which implement the [Table] trait.
1670#[derive(bytemuck::TransparentWrapper)]
1671#[repr(transparent)]
1672pub struct ErasedExpr(ExprInner);
1673
1674#[derive(Clone)]
1675enum ExprInner {
1676    Raw(sea_query::Expr),
1677    Column(TableName, ColumnName),
1678    BinOp(
1679        Arc<dyn Fn(sea_query::SimpleExpr, sea_query::SimpleExpr) -> sea_query::SimpleExpr>,
1680        Box<ExprInner>,
1681        Box<ExprInner>,
1682    ),
1683    NOp(
1684        Arc<dyn Fn(Vec<sea_query::SimpleExpr>) -> sea_query::SimpleExpr>,
1685        Vec<ExprInner>,
1686    ),
1687}
1688
1689impl ExprInner {
1690    fn visit_mut(&mut self, f: &mut impl FnMut(&mut TableName, &mut ColumnName)) {
1691        match self {
1692            ExprInner::Column(table_name, column_name) => f(table_name, column_name),
1693            ExprInner::BinOp(_, expr_inner, expr_inner1) => {
1694                expr_inner.visit_mut(f);
1695                expr_inner1.visit_mut(f);
1696            }
1697            ExprInner::NOp(_, inners) => {
1698                for inner in inners {
1699                    inner.visit_mut(f);
1700                }
1701            }
1702            ExprInner::Raw(_) => {}
1703        }
1704    }
1705
1706    fn render(&self) -> sea_query::SimpleExpr {
1707        match self {
1708            ExprInner::Raw(value) => value.clone(),
1709            ExprInner::Column(table_name, column_name) => {
1710                sea_query::Expr::column((table_name.clone(), column_name.clone()))
1711            }
1712            ExprInner::BinOp(cb, expr_inner, expr_inner1) => {
1713                cb(expr_inner.render(), expr_inner1.render())
1714            }
1715            ExprInner::NOp(cb, inners) => cb(inners.into_iter().map(|n| n.render()).collect()),
1716        }
1717    }
1718}
1719
1720/// A type representing an expression in the query, can be passed around on the
1721/// rust side to wire things up
1722#[derive(Clone)]
1723pub struct Expr<'scope, T> {
1724    expr: ExprInner,
1725    _phantom: PhantomData<(&'scope (), T)>,
1726}
1727
1728impl<'scope, T> Expr<'scope, T> {
1729    fn new(expr: ExprInner) -> Self {
1730        Self {
1731            expr,
1732            _phantom: PhantomData,
1733        }
1734    }
1735
1736    /// Construct a literal value from any value from any value that can be encoded.
1737    pub fn lit(value: T) -> Self
1738    where
1739        T: Into<sea_query::Value>,
1740    {
1741        Self::new(ExprInner::Raw(sea_query::Expr::value(value.into())))
1742    }
1743
1744    fn binop<U>(
1745        self,
1746        other: Self,
1747        binop: Arc<dyn Fn(sea_query::SimpleExpr, sea_query::SimpleExpr) -> sea_query::SimpleExpr>,
1748    ) -> Expr<'scope, U> {
1749        Expr::new(ExprInner::BinOp(
1750            binop,
1751            Box::new(self.expr),
1752            Box::new(other.expr),
1753        ))
1754    }
1755
1756    /// SQL equality
1757    pub fn equals(self, other: Self) -> Expr<'scope, bool> {
1758        self.binop(
1759            other,
1760            Arc::new(|a, b| a.binary(sea_query::BinOper::Equal, b)),
1761        )
1762    }
1763
1764    fn as_erased(&self) -> &ErasedExpr {
1765        ErasedExpr::wrap_ref(&self.expr)
1766    }
1767
1768    fn as_erased_mut(&mut self) -> &mut ErasedExpr {
1769        ErasedExpr::wrap_mut(&mut self.expr)
1770    }
1771}
1772
1773// TODO: num trait
1774impl<'scope> Expr<'scope, i32> {
1775    /// SQL numeric addition
1776    pub fn add(self, other: Self) -> Self {
1777        self.binop(other, Arc::new(|a, b| a.binary(sea_query::BinOper::Add, b)))
1778    }
1779
1780    /// generate `nextval('name')`, this must be used within [`Query::evaluate`]
1781    /// for it to behave properly.
1782    ///
1783    /// # Example
1784    ///
1785    /// ```rust
1786    /// use rust_rel8::{helper_tables::One, *};
1787    ///
1788    /// query::<(Expr<i32>, Expr<i32>)>(|q| {
1789    ///   let id = q.q(Query::evaluate(Expr::nextval("table_id_seq")));
1790    ///   let v = q.q(Query::values([1, 2, 3].map(|a| One { a })));
1791    ///   (id, v.a)
1792    /// });
1793    /// ```
1794    pub fn nextval(name: &str) -> Self {
1795        Self::new(ExprInner::Raw(
1796            sea_query::Func::cust("nextval").arg(name.to_owned()).into(),
1797        ))
1798    }
1799}
1800
1801/// An opaque value you can use to compose together queries.
1802///
1803/// To get a value of this type, use [query].
1804pub struct Q<'scope> {
1805    queries: Vec<(TableName, ErasedQuery)>,
1806    filters: Vec<ExprInner>,
1807    binder: Binder,
1808    _phantom: PhantomData<&'scope ()>,
1809}
1810
1811impl<'scope> Q<'scope> {
1812    /// Bind a query and give you a value representing each row it produces.
1813    ///
1814    /// The `'scope` lifetime prevents this value leaking out of its context,
1815    /// which would result in invalid queries.
1816    pub fn q<T: ForLifetimeTable + Table<'scope>>(&mut self, query: Query<T>) -> T {
1817        let binder = Binder::new();
1818        let name = TableName::new(binder);
1819        let (erased, mut inner) = query.erased();
1820        self.queries.push((name.clone(), erased));
1821        insert_table_name(&mut inner, name);
1822        inner
1823    }
1824
1825    /// Introduce a where clause for this query.
1826    ///
1827    /// If you introduce a clause that looks like `a.id = b.a_id` then you
1828    /// effectively create an inner join.
1829    pub fn where_<'a>(&mut self, expr: Expr<'a, bool>)
1830    where
1831        'scope: 'a,
1832    {
1833        self.filters.push(expr.expr);
1834    }
1835}
1836
1837/// Open a context allowing you to manipulate a query.
1838///
1839/// Inside you can use `q.q(...)` on as many [`Query<T>`] values as you wish, the
1840/// result of each call can be thought of as each value the query yields.
1841///
1842/// You can think of this as cross joining each query together, to create inner
1843/// joins or left joins, simply use [Q::where_] and [`Query<T>::optional`].
1844///
1845/// Unfortunately, rustc isn't able to infer the return type of this function as
1846/// there seems to be no good way to express that `T::WithLt<'a>` is the same type
1847/// as `T::WithLt<'b>` modulo lifetimes.
1848pub fn query<'outer, T: ForLifetimeTable>(
1849    f: impl for<'scope> FnOnce(&mut Q<'scope>) -> T::WithLt<'scope>,
1850) -> Query<T::WithLt<'outer>> {
1851    let mut q = Q {
1852        binder: Binder::new(),
1853        filters: Vec::new(),
1854        queries: Vec::new(),
1855        _phantom: PhantomData,
1856    };
1857
1858    let mut e = f(&mut q);
1859
1860    // if one of the selects needs the parents to have dummy columns to prevent
1861    // postgres evaluating it only once, we add to each query a `random() as dummy`.
1862    // Then, in those selects needing them,
1863    let needs_random = q.queries.iter().any(|(_, q)| q.siblings_need_random);
1864
1865    let mut random_binders: Vec<sea_query::Expr> = Vec::new();
1866
1867    let mut insert_dummy = |mut stmt: sea_query::SelectStatement, table: &TableName| {
1868        if needs_random {
1869            stmt.expr_as(sea_query::Func::random(), "dummy");
1870
1871            for binder in &random_binders {
1872                stmt.and_where(binder.clone());
1873            }
1874
1875            random_binders
1876                .push(sea_query::Expr::column((table.clone(), "dummy".to_string())).is_not_null())
1877        }
1878        stmt
1879    };
1880
1881    let mut iter = q.queries.into_iter();
1882    let mut table = sea_query::Query::select();
1883
1884    if let Some((first_table_name, first)) = iter.next() {
1885        let expr = insert_dummy(first.expr, &first_table_name);
1886        table.from_subquery(expr, first_table_name);
1887    };
1888
1889    for (table_name, q) in iter {
1890        let expr = insert_dummy(q.expr, &table_name);
1891        table.join_lateral(
1892            // normally a cross join, but sea_query doesn't support omitting the `ON` for cross joins (:
1893            // and CROSS JOIN is INNER JOIN ON TRUE
1894            sea_query::JoinType::InnerJoin,
1895            expr,
1896            table_name,
1897            sea_query::Condition::all(),
1898        );
1899    }
1900
1901    for filter in q.filters {
1902        table.and_where(filter.render());
1903    }
1904
1905    subst_table(&mut e, TableName::new(q.binder), &mut table);
1906
1907    // if we needed to add random calls, so do parents
1908    let mut q = Query::new(q.binder, table.to_owned(), e);
1909    q.siblings_need_random = needs_random;
1910    q
1911}
1912
1913/// Construct an insert statement, the result of the query `rows` will be inserted into `into`.
1914pub struct Insert<T: TableHKT<Mode = NameMode>> {
1915    /// The table to insert the rows into.
1916    pub into: TableSchema<T>,
1917    /// Query producing the rows to insert.
1918    pub rows: Query<T::InMode<ExprMode>>,
1919}
1920
#[cfg(feature = "sqlx")]
impl<T> Insert<T>
where
    T: TableHKT<Mode = NameMode> + MapTable<'static>,
{
    /// Run the insert statement on the given PostgreSQL connection.
    ///
    /// The destination column names are taken from the table schema in
    /// `into`, and the inserted values come from the `rows` query, so the
    /// whole operation is sent as a single `INSERT INTO ... SELECT ...`.
    ///
    /// # Errors
    ///
    /// Returns any error reported by sqlx while executing the statement.
    ///
    /// # Panics
    ///
    /// Panics if `sea_query` refuses `rows` as the source of the insert
    /// (`select_from` returning an error).
    pub async fn run(
        &self,
        pool: &mut sqlx::PgConnection,
    ) -> sqlx::Result<sqlx::postgres::PgQueryResult> {
        use sea_query::PostgresQueryBuilder;
        use sea_query_sqlx::SqlxBinder as _;

        // Gather the destination column names from the table schema.
        let mut collector = NameCollectorMapper { names: Vec::new() };
        self.into.columns.map_modes_ref(&mut collector);

        let mut statement = sea_query::Query::insert();
        statement
            .into_table(self.into.name)
            .columns(collector.names);
        statement.select_from(self.rows.expr.clone()).unwrap();

        let (sql, values) = statement.build_sqlx(PostgresQueryBuilder);
        sqlx::query_with(&sql, values).execute(pool).await
    }
}
1951
1952/// A struct allowing manipulation of an aggregation
1953///
1954/// # Lifetimes
1955///
1956/// The `'inner` lifetime represents the lifetime of the query you are aggregating on.
1957/// The `'outer` lifetime represents the lifetime of the aggregation query being build.
1958///
1959/// The two are separate to ensure that any expression in the output table is
1960/// either used in an aggregation function, or part of a group by clause.
1961pub struct A<'inner, 'outer> {
1962    group_by: Vec<ExprInner>,
1963    _phantom: PhantomData<(&'inner (), &'outer ())>,
1964}
1965
1966impl<'inner, 'outer> A<'inner, 'outer> {
1967    /// Group the aggregation by the given table.
1968    pub fn group_by<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
1969        expr.visit(
1970            &mut |ErasedExpr(e)| {
1971                self.group_by.push(e.clone());
1972            },
1973            VisitTableMode::All,
1974        );
1975
1976        expr.with_lt(&mut WithLtMarker::new())
1977    }
1978
1979    /// Aggregate rows of the given table into an array
1980    pub fn array_agg<T: Table<'inner> + ForLifetimeTable>(
1981        &mut self,
1982        mut expr: T,
1983    ) -> ListTable<T::WithLt<'outer>> {
1984        expr.visit_mut(
1985            &mut |ErasedExpr(inner)| {
1986                *inner = ExprInner::NOp(
1987                    Arc::new(|inners| {
1988                        let [inner] = inners.try_into().unwrap();
1989
1990                        sea_query::PgFunc::array_agg(inner).into()
1991                    }),
1992                    vec![inner.clone()],
1993                )
1994            },
1995            VisitTableMode::All,
1996        );
1997
1998        ListTable {
1999            inner: expr.with_lt(&mut WithLtMarker::new()),
2000        }
2001    }
2002
2003    fn simple_agg_fn<T: Table<'inner> + ForLifetimeTable>(
2004        &mut self,
2005        mut expr: T,
2006        f: impl Fn(sea_query::Expr) -> sea_query::Expr + Clone + 'static,
2007    ) -> T::WithLt<'outer> {
2008        expr.visit_mut(
2009            &mut |ErasedExpr(inner)| {
2010                let f = f.clone();
2011                *inner = ExprInner::NOp(
2012                    Arc::new(move |inners| {
2013                        let [inner] = inners.try_into().unwrap();
2014
2015                        f(inner)
2016                    }),
2017                    vec![inner.clone()],
2018                )
2019            },
2020            VisitTableMode::All,
2021        );
2022
2023        expr.with_lt(&mut WithLtMarker::new())
2024    }
2025
2026    /// Average the values of this table in the group.
2027    pub fn avg<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2028        self.simple_agg_fn(expr, |e| e.avg())
2029    }
2030
2031    /// Bit and the values of this table in the group.
2032    pub fn bit_and<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2033        self.simple_agg_fn(expr, |e| sea_query::Func::bit_and(e).into())
2034    }
2035
2036    /// Bit or the values of this table in the group.
2037    pub fn bit_or<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2038        self.simple_agg_fn(expr, |e| sea_query::Func::bit_or(e).into())
2039    }
2040
2041    /// Bit xor the values of this table in the group.
2042    pub fn bit_xor<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2043        self.simple_agg_fn(expr, |e| sea_query::Func::cust("BIT_XOR").arg(e).into())
2044    }
2045
2046    /// Boolean and the values of this table in the group.
2047    pub fn bool_and<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2048        self.simple_agg_fn(expr, |e| sea_query::Func::cust("BOOL_AND").arg(e).into())
2049    }
2050
2051    /// Boolean or the values of this table in the group.
2052    pub fn bool_or<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2053        self.simple_agg_fn(expr, |e| sea_query::Func::cust("BOOL_OR").arg(e).into())
2054    }
2055
2056    /// Count the number of rows in the group, including nulls.
2057    pub fn count_star(&mut self) -> Expr<'outer, i32> {
2058        Expr::new(ExprInner::Raw(sea_query::Expr::cust("*").count()))
2059    }
2060
2061    /// Count the number of rows in the group, excluding where the table is null.
2062    pub fn count<T>(&mut self, expr: Expr<'inner, T>) -> Expr<'outer, i32> {
2063        Expr::new(ExprInner::NOp(
2064            Arc::new(|inners| {
2065                let [inner] = inners.try_into().unwrap();
2066
2067                inner.count()
2068            }),
2069            vec![expr.expr],
2070        ))
2071    }
2072
2073    /// Count the distinct number of rows in the group, excluding where the table is null.
2074    pub fn count_distinct<T>(&mut self, expr: Expr<'inner, T>) -> Expr<'outer, i32> {
2075        Expr::new(ExprInner::NOp(
2076            Arc::new(|inners| {
2077                let [inner] = inners.try_into().unwrap();
2078
2079                inner.count_distinct()
2080            }),
2081            vec![expr.expr],
2082        ))
2083    }
2084
2085    /// The maximum value of each column of this table in its group.
2086    pub fn max<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2087        self.simple_agg_fn(expr, |e| e.max())
2088    }
2089
2090    /// The minimum value of each column of this table in its group.
2091    pub fn min<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2092        self.simple_agg_fn(expr, |e| e.min())
2093    }
2094
2095    /// The summation of the values of each column of this table in its group.
2096    pub fn sum<T: Table<'inner> + ForLifetimeTable>(&mut self, expr: T) -> T::WithLt<'outer> {
2097        self.simple_agg_fn(expr, |e| e.sum())
2098    }
2099}
2100
2101/// A struct representing an `OVER` used by a window function.
2102#[derive(Clone)]
2103pub struct Over<'scope> {
2104    partitions: Vec<ExprInner>,
2105    orderings: Vec<(ExprInner, sea_query::Order)>,
2106    _phantom: PhantomData<&'scope ()>,
2107}
2108
2109impl<'scope> Over<'scope> {
2110    /// Create a new [Over] for use with [W] in [Query::window]
2111    pub fn new() -> Self {
2112        Self {
2113            partitions: Vec::new(),
2114            orderings: Vec::new(),
2115            _phantom: PhantomData,
2116        }
2117    }
2118
2119    /// Add an `ORDER BY` clause to this `OVER`
2120    pub fn order_by<T: Table<'scope>>(mut self, expr: T, order: sea_query::Order) -> Self {
2121        expr.visit(
2122            &mut |ErasedExpr(e)| self.orderings.push((e.clone(), order.clone())),
2123            VisitTableMode::All,
2124        );
2125        self
2126    }
2127
2128    /// Add a `PARTITION BY` clause to this `OVER`
2129    pub fn partition_by<T: Table<'scope>>(mut self, expr: T) -> Self {
2130        expr.visit(
2131            &mut |ErasedExpr(e)| self.partitions.push(e.clone()),
2132            VisitTableMode::All,
2133        );
2134        self
2135    }
2136}
2137
2138/// A struct allowing insertion of window functions
2139///
2140/// # Lifetimes
2141///
2142/// The `'inner` lifetime represents the lifetime of the query you are starting with.
2143/// The `'outer` lifetime represents the lifetime of the the new query containing the window function expressions.
2144///
2145/// Ideally this would not be used and you would just use something like
2146/// `Expr::row_number().over(...)`, but the query builder being used requires
2147/// that `PARTITION BY` takes only column names, therefore to support any
2148/// arbitrary expression we need to place everything into a subquery in which we
2149/// just alias all expressions to single tables.
2150///
2151/// If you have a table in the `'inner` scope, you can freely use [W::id] to bring it to `'outer`.
2152pub struct W<'inner, 'outer> {
2153    // inner select, starts empty. when consuming a partition, or by use of
2154    // `id`, we add the contained expressions to this table.
2155    inner_query: sea_query::SelectStatement,
2156    inner_table: TableName,
2157    middle_query: sea_query::SelectStatement,
2158    middle_table: TableName,
2159    _phantom: PhantomData<(&'inner (), &'outer ())>,
2160}
2161
2162impl<'inner, 'outer> W<'inner, 'outer> {
2163    fn generic_window(&mut self, expr: sea_query::Expr, over: Over<'inner>) -> ExprInner {
2164        let mut window_statement = sea_query::WindowStatement::new();
2165
2166        for expr in over.partitions {
2167            let name = ColumnName::new(Binder::new(), "part".to_owned());
2168            self.inner_query.expr_as(expr.render(), name.clone());
2169            window_statement.partition_by((self.inner_table.clone(), name));
2170        }
2171
2172        for (expr, order) in over.orderings {
2173            let name = ColumnName::new(Binder::new(), "ord".to_owned());
2174            self.inner_query.expr_as(expr.render(), name.clone());
2175            window_statement.order_by((self.inner_table.clone(), name), order);
2176        }
2177
2178        let column = ColumnName::new(Binder::new(), "win".to_owned());
2179
2180        self.middle_query
2181            .expr_window_as(expr, window_statement, column.clone());
2182
2183        ExprInner::Column(self.middle_table.clone(), column)
2184    }
2185
2186    pub fn row_number(&mut self, over: Over<'inner>) -> Expr<'outer, i64> {
2187        Expr::new(self.generic_window(sea_query::Func::cust("row_number").into(), over))
2188    }
2189
2190    pub fn rank(&mut self, over: Over<'inner>) -> Expr<'outer, i64> {
2191        Expr::new(self.generic_window(sea_query::Func::cust("rank").into(), over))
2192    }
2193
2194    pub fn dense_rank(&mut self, over: Over<'inner>) -> Expr<'outer, i64> {
2195        Expr::new(self.generic_window(sea_query::Func::cust("dense_rank").into(), over))
2196    }
2197
2198    pub fn percent_rank(&mut self, over: Over<'inner>) -> Expr<'outer, f64> {
2199        Expr::new(self.generic_window(sea_query::Func::cust("percent_rank").into(), over))
2200    }
2201
2202    pub fn cume_dist(&mut self, over: Over<'inner>) -> Expr<'outer, f64> {
2203        Expr::new(self.generic_window(sea_query::Func::cust("cume_dist").into(), over))
2204    }
2205
2206    pub fn ntile(
2207        &mut self,
2208        num_buckets: Expr<'outer, i64>,
2209        over: Over<'inner>,
2210    ) -> Expr<'outer, f64> {
2211        Expr::new(
2212            self.generic_window(
2213                sea_query::Func::cust("ntlile")
2214                    .arg(num_buckets.expr.render())
2215                    .into(),
2216                over,
2217            ),
2218        )
2219    }
2220
2221    fn lag_lead_generic<T: Table<'inner> + ForLifetimeTable>(
2222        &mut self,
2223        fn_: &'static str,
2224        mut value: T,
2225        offset: Option<Expr<'outer, i64>>,
2226        default: Option<T>,
2227        over: Over<'inner>,
2228    ) -> T::WithLt<'outer> {
2229        let mut default_exprs = default.as_ref().map(collect_exprs).map(|x| x.into_iter());
2230
2231        let mut window_statement = sea_query::WindowStatement::new();
2232
2233        for expr in over.partitions {
2234            let name = ColumnName::new(Binder::new(), "part".to_owned());
2235            self.inner_query.expr_as(expr.render(), name.clone());
2236            window_statement.partition_by((self.inner_table.clone(), name));
2237        }
2238
2239        for (expr, order) in over.orderings {
2240            let name = ColumnName::new(Binder::new(), "ord".to_owned());
2241            self.inner_query.expr_as(expr.render(), name.clone());
2242            window_statement.order_by((self.inner_table.clone(), name), order);
2243        }
2244
2245        subst_table(&mut value, self.inner_table.clone(), &mut self.inner_query);
2246
2247        value.visit_mut(
2248            &mut |ErasedExpr(inner)| {
2249                let offset = offset.clone();
2250                let default = default_exprs.as_mut().map(|x| x.next().unwrap());
2251
2252                let r = sea_query::Func::cust(fn_).arg(inner.clone().render());
2253                let (r, set_default) = if let Some(offset) = offset {
2254                    (r.arg(offset.expr.render()), true)
2255                } else {
2256                    (r, false)
2257                };
2258                let r = if let Some(default) = default {
2259                    let r = if !set_default {
2260                        r.arg(sea_query::Expr::value(1))
2261                    } else {
2262                        r
2263                    };
2264                    r.arg(default.render())
2265                } else {
2266                    r
2267                };
2268
2269                let column = ColumnName::new(Binder::new(), "win".to_owned());
2270
2271                self.middle_query
2272                    .expr_window_as(r, window_statement.clone(), column.clone());
2273
2274                *inner = ExprInner::Column(self.middle_table.clone(), column);
2275            },
2276            VisitTableMode::All,
2277        );
2278
2279        value.with_lt(&mut WithLtMarker::new())
2280    }
2281
2282    pub fn lag<T: Table<'inner> + ForLifetimeTable>(
2283        &mut self,
2284        value: T,
2285        offset: Option<Expr<'outer, i64>>,
2286        default: Option<T>,
2287        over: Over<'inner>,
2288    ) -> T::WithLt<'outer> {
2289        self.lag_lead_generic("lag", value, offset, default, over)
2290    }
2291
2292    pub fn lead<T: Table<'inner> + ForLifetimeTable>(
2293        &mut self,
2294        value: T,
2295        offset: Option<Expr<'outer, i64>>,
2296        default: Option<T>,
2297        over: Over<'inner>,
2298    ) -> T::WithLt<'outer> {
2299        self.lag_lead_generic("lead", value, offset, default, over)
2300    }
2301
2302    pub fn first_value<T: Table<'inner> + ForLifetimeTable>(
2303        &mut self,
2304        value: T,
2305        over: Over<'inner>,
2306    ) -> T::WithLt<'outer> {
2307        self.lag_lead_generic("first_value", value, None, None, over)
2308    }
2309
2310    pub fn last_value<T: Table<'inner> + ForLifetimeTable>(
2311        &mut self,
2312        value: T,
2313        over: Over<'inner>,
2314    ) -> T::WithLt<'outer> {
2315        self.lag_lead_generic("first_value", value, None, None, over)
2316    }
2317
2318    // needs some thought
2319    // pub fn nth_value<T: Table<'outer>>(
2320    //     &mut self,
2321    //     value: T,
2322    //     offset: Option<Expr<'outer, i64>>,
2323    //     over: Over<'inner>,
2324    // ) -> NullTable<'outer, T> {
2325    //     self.lag_lead_generic("lag", value, offset, None, over)
2326    // }
2327
2328    /// Bring a table in the `'inner` scope to the `'outer` scope.
2329    ///
2330    /// Ideally this wouldn't be needed, but the sql builder used internally
2331    /// doesn't support arbitrary expressions in window function partitions, so
2332    /// we have to move everything into a subquery. This function is then used
2333    /// to reselect the input table from this subquery.
2334    pub fn id<T: Table<'inner> + ForLifetimeTable>(&mut self, mut table: T) -> T::WithLt<'outer> {
2335        subst_table(&mut table, self.inner_table.clone(), &mut self.inner_query);
2336        subst_table(
2337            &mut table,
2338            self.middle_table.clone(),
2339            &mut self.middle_query,
2340        );
2341        table.with_lt(&mut WithLtMarker::new())
2342    }
2343}
2344
/// A set of helper tables that are equivalent to tuples.
///
/// We need these as an alternative to just tuples in order to be parameterised
/// by the [TableMode].
pub mod helper_tables {
    use super::*;
    use rust_rel8_derive::TableStruct;

    // Each helper derives `TableStruct`. `#[table(crate = "crate")]` presumably
    // points the derive at this crate's own paths because it is invoked from
    // inside rust_rel8 itself — confirm against `rust_rel8_derive`.
    // NOTE(review): the meaning of `#[table(proxy)]` on the value type
    // parameters is defined by the derive and is not visible here.
    #[derive(TableStruct)]
    #[table(crate = "crate")]
    #[perfect_derive::perfect_derive(Debug, PartialEq, Clone)]
    /// A helper table with one field.
    pub struct One<'scope, Mode: TableMode, #[table(proxy)] A: Value> {
        // Stored as `Mode::T<'scope, A>`, so its representation follows `Mode`.
        pub a: Mode::T<'scope, A>,
    }

    #[derive(TableStruct)]
    #[table(crate = "crate")]
    #[perfect_derive::perfect_derive(Debug, PartialEq, Clone)]
    /// A helper table with two fields.
    pub struct Two<'scope, Mode: TableMode, #[table(proxy)] A: Value, #[table(proxy)] B: Value> {
        // Fields mirror the tuple positions `(A, B)`, each stored per `Mode`.
        pub a: Mode::T<'scope, A>,
        pub b: Mode::T<'scope, B>,
    }

    #[derive(TableStruct)]
    #[table(crate = "crate")]
    #[perfect_derive::perfect_derive(Debug, PartialEq, Clone)]
    /// A helper table with three fields.
    pub struct Three<
        'scope,
        Mode: TableMode,
        #[table(proxy)] A: Value,
        #[table(proxy)] B: Value,
        #[table(proxy)] C: Value,
    > {
        // Fields mirror the tuple positions `(A, B, C)`, each stored per `Mode`.
        pub a: Mode::T<'scope, A>,
        pub b: Mode::T<'scope, B>,
        pub c: Mode::T<'scope, C>,
    }
}
2386
/// Helpers for working with the invariant lifetimes that appear in tables and
/// collections of tables.
pub mod helper_utilities {
    use std::{
        collections::HashMap,
        hash::{BuildHasher, Hash},
    };

    /// A helper trait which implements `shorten_lifetime` for some common wrapper types.
    pub trait ShortenLifetime {
        /// `Self` with its lifetime shortened to `'small`.
        type Shortened<'small>
        where
            Self: 'small;

        /// Shorten a lifetime, normally rust does this automatically, but if
        /// the lifetime is invariant due to being used in a Gat or trait, we
        /// need to do it manually.
        ///
        /// If rust complains about a lifetime being invariant, you should call
        /// this method at the use site where the lifetime error is generated.
        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large;
    }

    /// Shortens every element of the array.
    impl<T: ShortenLifetime, const N: usize> ShortenLifetime for [T; N] {
        type Shortened<'small>
            = [T::Shortened<'small>; N]
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            self.map(ShortenLifetime::shorten_lifetime)
        }
    }

    /// Shortens every element of the vector, preserving order.
    impl<T: ShortenLifetime> ShortenLifetime for Vec<T> {
        type Shortened<'small>
            = Vec<T::Shortened<'small>>
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            self.into_iter()
                .map(ShortenLifetime::shorten_lifetime)
                .collect::<Vec<_>>()
        }
    }

    /// Shortens every value of the map; keys are kept unchanged.
    ///
    /// Works with any hasher `S` that can be default-constructed. The map is
    /// rebuilt with `S::default()`, so state held by the original hasher
    /// instance is not carried over.
    impl<K: Hash + Eq, T: ShortenLifetime, S: BuildHasher + Default> ShortenLifetime
        for HashMap<K, T, S>
    {
        type Shortened<'small>
            = HashMap<K, T::Shortened<'small>, S>
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            self.into_iter()
                .map(|(k, v)| (k, v.shorten_lifetime()))
                .collect()
        }
    }
}
2453
2454pub use helper_utilities::ShortenLifetime;
2455
2456#[cfg(feature = "derive")]
2457pub use rust_rel8_derive::TableStruct;
2458
2459use self::is_nullable::IsNullable;