rust_query/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use ref_cast::RefCast;
8use sea_query::{Expr, Func, SelectStatement, SimpleExpr};
9
10use crate::{
11    alias::{Field, MyAlias},
12    ast::MySelect,
13    rows::Rows,
14    value::{
15        operations::{Const, IsNotNull, UnwrapOr},
16        EqTyp, IntoColumn, MyTyp, NumTyp, Typed, ValueBuilder,
17    },
18    Column, Table,
19};
20
21/// This is the argument type used for aggregates.
22///
23/// While it is possible to join many tables in an aggregate, there can be only one result.
24/// (The result can be a tuple or struct with multiple values though).
25pub struct Aggregate<'outer, 'inner, S> {
26    // pub(crate) outer_ast: &'inner MySelect,
27    pub(crate) conds: Vec<(Field, Rc<dyn 'outer + Fn(ValueBuilder) -> SimpleExpr>)>,
28    pub(crate) query: Rows<'inner, S>,
29    // pub(crate) table: MyAlias,
30    pub(crate) phantom2: PhantomData<fn(&'outer ()) -> &'outer ()>,
31}
32
33impl<'outer, 'inner, S> Deref for Aggregate<'outer, 'inner, S> {
34    type Target = Rows<'inner, S>;
35
36    fn deref(&self) -> &Self::Target {
37        &self.query
38    }
39}
40
41impl<'outer, 'inner, S> DerefMut for Aggregate<'outer, 'inner, S> {
42    fn deref_mut(&mut self) -> &mut Self::Target {
43        &mut self.query
44    }
45}
46
47impl<'outer: 'inner, 'inner, S: 'outer> Aggregate<'outer, 'inner, S> {
48    fn select<T>(&'inner self, expr: impl Into<SimpleExpr>) -> Aggr<'outer, S, Option<T>> {
49        let alias = self
50            .ast
51            .select
52            .get_or_init(expr.into(), || self.ast.scope.new_field());
53        Aggr {
54            _p: PhantomData,
55            _p2: PhantomData,
56            select: self.query.ast.build_select(true),
57            field: *alias,
58            conds: self.conds.clone(),
59        }
60    }
61
62    /// Filter the rows of this sub-query based on a value from the outer query.
63    pub fn filter_on<T>(
64        &mut self,
65        val: impl IntoColumn<'inner, S, Typ = T>,
66        on: impl IntoColumn<'outer, S, Typ = T>,
67    ) {
68        let on = on.into_owned();
69        let alias = self.ast.scope.new_alias();
70        self.conds
71            .push((Field::U64(alias), Rc::new(move |b| on.build_expr(b))));
72        self.ast
73            .filter_on
74            .push(Box::new((val.build_expr(self.ast.builder()), alias)))
75    }
76
77    /// Return the average value in a column, this is [None] if there are zero rows.
78    pub fn avg(
79        &'inner self,
80        val: impl IntoColumn<'inner, S, Typ = f64>,
81    ) -> Column<'outer, S, Option<f64>> {
82        let expr = Func::avg(val.build_expr(self.ast.builder()));
83        self.select(expr).into_column()
84    }
85
86    /// Return the maximum value in a column, this is [None] if there are zero rows.
87    pub fn max<T>(
88        &'inner self,
89        val: impl IntoColumn<'inner, S, Typ = T>,
90    ) -> Column<'outer, S, Option<T>>
91    where
92        T: NumTyp,
93    {
94        let expr = Func::max(val.build_expr(self.ast.builder()));
95        self.select(expr).into_column()
96    }
97
98    /// Return the sum of a column.
99    pub fn sum<T>(&'inner self, val: impl IntoColumn<'inner, S, Typ = T>) -> Column<'outer, S, T>
100    where
101        T: NumTyp,
102    {
103        let expr = Func::sum(val.build_expr(self.ast.builder()));
104        UnwrapOr(self.select::<T>(expr), Const(T::ZERO)).into_column()
105    }
106
107    /// Return the number of distinct values in a column.
108    pub fn count_distinct<T>(
109        &'inner self,
110        val: impl IntoColumn<'inner, S, Typ = T>,
111    ) -> Column<'outer, S, i64>
112    where
113        T: EqTyp,
114    {
115        let expr = Func::count_distinct(val.build_expr(self.ast.builder()));
116        UnwrapOr(self.select::<i64>(expr), Const(0)).into_column()
117    }
118
119    /// Return whether there are any rows.
120    pub fn exists(&'inner self) -> Column<'outer, S, bool> {
121        let expr = SimpleExpr::Constant(1.into_sea_value());
122        IsNotNull(self.select::<i64>(expr)).into_column()
123    }
124}
125
126pub struct Aggr<'t, S, T> {
127    pub(crate) _p: PhantomData<fn(&'t S) -> &'t S>,
128    pub(crate) _p2: PhantomData<T>,
129    pub(crate) select: SelectStatement,
130    pub(crate) conds: Vec<(Field, Rc<dyn 't + Fn(ValueBuilder) -> SimpleExpr>)>,
131    pub(crate) field: Field,
132}
133
134impl<S, T> Clone for Aggr<'_, S, T> {
135    fn clone(&self) -> Self {
136        Self {
137            _p: PhantomData,
138            _p2: PhantomData,
139            select: self.select.clone(),
140            conds: self.conds.clone(),
141            field: self.field,
142        }
143    }
144}
145
146impl<'t, S, T: MyTyp> Typed for Aggr<'t, S, T> {
147    type Typ = T;
148    fn build_expr(&self, b: crate::value::ValueBuilder) -> SimpleExpr {
149        Expr::col((self.build_table(b), self.field)).into()
150    }
151}
152
153impl<'t, S, T> Aggr<'t, S, T> {
154    fn build_table(&self, b: crate::value::ValueBuilder) -> MyAlias {
155        let conds = self.conds.iter().map(|(field, expr)| (*field, expr(b)));
156        b.get_aggr(self.select.clone(), conds.collect())
157    }
158}
159
160impl<'t, S, T: MyTyp> IntoColumn<'t, S> for Aggr<'t, S, T> {
161    type Owned = Self;
162
163    fn into_owned(self) -> Self::Owned {
164        self
165    }
166}
167
168impl<S, T: Table> Deref for Aggr<'_, S, T> {
169    type Target = T::Ext<Self>;
170
171    fn deref(&self) -> &Self::Target {
172        RefCast::ref_cast(self)
173    }
174}
175
176/// Perform an aggregate that returns a single result for each of the current rows.
177///
178/// You can filter the rows in the aggregate based on values from the outer query.
179/// That is the only way to get a different aggregate for each outer row.
180pub fn aggregate<'outer, S, F, R>(f: F) -> R
181where
182    F: for<'a> FnOnce(&'a mut Aggregate<'outer, 'a, S>) -> R,
183{
184    let ast = MySelect::default();
185    let inner = Rows {
186        phantom: PhantomData,
187        ast,
188    };
189    let mut group = Aggregate {
190        conds: Vec::new(),
191        query: inner,
192        phantom2: PhantomData,
193    };
194    f(&mut group)
195}