rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use ref_cast::RefCast;
8use sea_query::{Func, SelectStatement, SimpleExpr};
9
10use crate::{
11    Expr, Table,
12    alias::{Field, MyAlias},
13    ast::MySelect,
14    rows::Rows,
15    value::{EqTyp, IntoExpr, MyTyp, NumTyp, Typed, ValueBuilder},
16};
17
18/// This is the argument type used for [aggregate].
19pub struct Aggregate<'outer, 'inner, S> {
20    // pub(crate) outer_ast: &'inner MySelect,
21    pub(crate) conds: Vec<(Field, Rc<dyn Fn(ValueBuilder) -> SimpleExpr>)>,
22    pub(crate) query: Rows<'inner, S>,
23    // pub(crate) table: MyAlias,
24    pub(crate) phantom2: PhantomData<fn(&'outer ()) -> &'outer ()>,
25}
26
27impl<'outer, 'inner, S> Deref for Aggregate<'outer, 'inner, S> {
28    type Target = Rows<'inner, S>;
29
30    fn deref(&self) -> &Self::Target {
31        &self.query
32    }
33}
34
35impl<'outer, 'inner, S> DerefMut for Aggregate<'outer, 'inner, S> {
36    fn deref_mut(&mut self) -> &mut Self::Target {
37        &mut self.query
38    }
39}
40
41impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
42    fn select<T>(&self, expr: impl Into<SimpleExpr>) -> Aggr<S, Option<T>> {
43        let alias = self
44            .ast
45            .select
46            .get_or_init(expr.into(), || self.ast.scope.new_field());
47        Aggr {
48            _p2: PhantomData,
49            select: self.query.ast.build_select(true),
50            field: *alias,
51            conds: self.conds.clone(),
52        }
53    }
54
55    /// Filter the rows of this sub-query based on a value from the outer query.
56    pub fn filter_on<T: EqTyp + 'static>(
57        &mut self,
58        val: impl IntoExpr<'inner, S, Typ = T>,
59        on: impl IntoExpr<'outer, S, Typ = T>,
60    ) {
61        let on = on.into_expr().inner;
62        let val = val.into_expr().inner;
63        let alias = self.ast.scope.new_alias();
64        self.conds
65            .push((Field::U64(alias), Rc::new(move |b| on.build_expr(b))));
66        self.ast
67            .filter_on
68            .push(Box::new((val.build_expr(self.ast.builder()), alias)))
69    }
70
71    /// Return the average value in a column, this is [None] if there are zero rows.
72    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
73        let val = val.into_expr().inner;
74        let expr = Func::avg(val.build_expr(self.ast.builder()));
75        Expr::new(self.select(expr))
76    }
77
78    /// Return the maximum value in a column, this is [None] if there are zero rows.
79    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
80    where
81        T: NumTyp,
82    {
83        let val = val.into_expr().inner;
84        let expr = Func::max(val.build_expr(self.ast.builder()));
85        Expr::new(self.select(expr))
86    }
87
88    /// Return the sum of a column.
89    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
90    where
91        T: NumTyp,
92    {
93        let val = val.into_expr().inner;
94        let expr = Func::sum(val.build_expr(self.ast.builder()));
95        let val = self.select::<T>(expr);
96        Expr::adhoc(move |b| {
97            sea_query::Expr::expr(val.build_expr(b))
98                .if_null(SimpleExpr::Constant(T::ZERO.into_sea_value()))
99        })
100    }
101
102    /// Return the number of distinct values in a column.
103    pub fn count_distinct<T: 'static>(
104        &self,
105        val: impl IntoExpr<'inner, S, Typ = T>,
106    ) -> Expr<'outer, S, i64>
107    where
108        T: EqTyp,
109    {
110        let val = val.into_expr().inner;
111        let expr = Func::count_distinct(val.build_expr(self.ast.builder()));
112        let val = self.select::<i64>(expr);
113        Expr::adhoc(move |b| {
114            sea_query::Expr::expr(val.build_expr(b))
115                .if_null(SimpleExpr::Constant(0i64.into_sea_value()))
116        })
117    }
118
119    /// Return whether there are any rows.
120    pub fn exists(&self) -> Expr<'outer, S, bool> {
121        let expr = SimpleExpr::Constant(1.into_sea_value());
122        let val = self.select::<i64>(expr);
123        Expr::adhoc(move |b| sea_query::Expr::expr(val.build_expr(b)).is_not_null())
124    }
125}
126
127pub struct Aggr<S, T> {
128    pub(crate) _p2: PhantomData<(S, T)>,
129    pub(crate) select: SelectStatement,
130    pub(crate) conds: Vec<(Field, Rc<dyn Fn(ValueBuilder) -> SimpleExpr>)>,
131    pub(crate) field: Field,
132}
133
134impl<S, T> Clone for Aggr<S, T> {
135    fn clone(&self) -> Self {
136        Self {
137            _p2: PhantomData,
138            select: self.select.clone(),
139            conds: self.conds.clone(),
140            field: self.field,
141        }
142    }
143}
144
145impl<S, T: MyTyp> Typed for Aggr<S, T> {
146    type Typ = T;
147    fn build_expr(&self, b: crate::value::ValueBuilder) -> SimpleExpr {
148        sea_query::Expr::col((self.build_table(b), self.field)).into()
149    }
150}
151
152impl<S, T> Aggr<S, T> {
153    fn build_table(&self, b: crate::value::ValueBuilder) -> MyAlias {
154        let conds = self.conds.iter().map(|(field, expr)| (*field, expr(b)));
155        b.get_aggr(self.select.clone(), conds.collect())
156    }
157}
158
159impl<S, T: Table> Deref for Aggr<S, T> {
160    type Target = T::Ext<Self>;
161
162    fn deref(&self) -> &Self::Target {
163        RefCast::ref_cast(self)
164    }
165}
166
167/// Perform an aggregate that returns a single result for each of the current rows.
168///
169/// You can filter the rows in the aggregate based on values from the outer query.
170/// That is the only way to get a different aggregate for each outer row.
171///
172/// ```
173/// # use rust_query::{Table, aggregate};
174/// # use rust_query::private::doctest::*;
175/// # let mut client = get_client();
176/// # let txn = get_txn(&mut client);
177/// let res = txn.query_one(aggregate(|rows| {
178///     let user = User::join(rows);
179///     rows.count_distinct(user)
180/// }));
181/// assert_eq!(res, 1, "there is one user in the database");
182/// ```
183pub fn aggregate<'outer, S, F, R>(f: F) -> R
184where
185    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
186{
187    let ast = MySelect::default();
188    let inner = Rows {
189        phantom: PhantomData,
190        ast,
191        _p: PhantomData,
192    };
193    let mut group = Aggregate {
194        conds: Vec::new(),
195        query: inner,
196        phantom2: PhantomData,
197    };
198    f(&mut group)
199}