rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use sea_query::{Asterisk, ExprTrait, Func, SelectStatement};
8
9use crate::{
10    Expr,
11    alias::MyAlias,
12    rows::Rows,
13    value::{EqTyp, IntoExpr, MyTyp, NumTyp, Typed, ValueBuilder},
14};
15
16use super::DynTypedExpr;
17
18/// This is the argument type used for [aggregate].
19pub struct Aggregate<'outer, 'inner, S> {
20    pub(crate) query: Rows<'inner, S>,
21    _p: PhantomData<&'inner &'outer ()>,
22}
23
24impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
25    type Target = Rows<'inner, S>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.query
29    }
30}
31
32impl<S> DerefMut for Aggregate<'_, '_, S> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        &mut self.query
35    }
36}
37
38impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
39    /// This must be used with an aggregating expression.
40    /// otherwise there is a change that there are multiple rows.
41    fn select<T>(
42        &self,
43        expr: impl 'static + Fn(&mut ValueBuilder) -> sea_query::Expr,
44    ) -> Aggr<S, Option<T>> {
45        let expr = DynTypedExpr::new(expr);
46        let mut builder = self.query.ast.clone().full();
47        let (select, mut fields) = builder.build_select(vec![expr]);
48
49        let conds = builder.forwarded.into_iter().map(|x| x.1.1).collect();
50
51        Aggr {
52            _p2: PhantomData,
53            select: Rc::new(select),
54            field: {
55                debug_assert_eq!(fields.len(), 1);
56                fields.swap_remove(0)
57            },
58            conds,
59        }
60    }
61
62    /// Return the average value in a column, this is [None] if there are zero rows.
63    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
64        let val = val.into_expr().inner;
65        Expr::new(self.select(move |b| Func::avg(val.build_expr(b)).into()))
66    }
67
68    /// Return the maximum value in a column, this is [None] if there are zero rows.
69    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
70    where
71        T: NumTyp,
72    {
73        let val = val.into_expr().inner;
74        Expr::new(self.select(move |b| Func::max(val.build_expr(b)).into()))
75    }
76
77    /// Return the minimum value in a column, this is [None] if there are zero rows.
78    pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
79    where
80        T: NumTyp,
81    {
82        let val = val.into_expr().inner;
83        Expr::new(self.select(move |b| Func::min(val.build_expr(b)).into()))
84    }
85
86    /// Return the sum of a column.
87    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
88    where
89        T: NumTyp,
90    {
91        let val = val.into_expr().inner;
92        let val = self.select::<T>(move |b| Func::sum(val.build_expr(b)).into());
93
94        Expr::adhoc(move |b| {
95            sea_query::Expr::expr(val.build_expr(b))
96                .if_null(sea_query::Expr::Constant(T::ZERO.into_sea_value()))
97        })
98    }
99
100    /// Return the number of distinct values in a column.
101    pub fn count_distinct<T: EqTyp + 'static>(
102        &self,
103        val: impl IntoExpr<'inner, S, Typ = T>,
104    ) -> Expr<'outer, S, i64> {
105        let val = val.into_expr().inner;
106        let val = self.select::<i64>(move |b| Func::count_distinct(val.build_expr(b)).into());
107        Expr::adhoc(move |b| {
108            sea_query::Expr::expr(val.build_expr(b))
109                .if_null(sea_query::Expr::Constant(0i64.into_sea_value()))
110        })
111    }
112
113    /// Return whether there are any rows.
114    pub fn exists(&self) -> Expr<'outer, S, bool> {
115        let val = self.select::<i64>(|_| Func::count(sea_query::Expr::col(Asterisk)).into());
116        Expr::adhoc(move |b| sea_query::Expr::expr(val.build_expr(b)).is_not_null())
117    }
118}
119
120pub struct Aggr<S, T> {
121    pub(crate) _p2: PhantomData<(S, T)>,
122    pub(crate) select: Rc<SelectStatement>,
123    pub(crate) conds: Vec<DynTypedExpr>,
124    pub(crate) field: MyAlias,
125}
126
127impl<S, T> Clone for Aggr<S, T> {
128    fn clone(&self) -> Self {
129        Self {
130            _p2: PhantomData,
131            select: self.select.clone(),
132            conds: self.conds.clone(),
133            field: self.field,
134        }
135    }
136}
137
138impl<S, T: MyTyp> Typed for Aggr<S, T> {
139    type Typ = T;
140    fn build_expr(&self, b: &mut ValueBuilder) -> sea_query::Expr {
141        sea_query::Expr::col((self.build_table(b), self.field)).into()
142    }
143}
144
145impl<S, T> Aggr<S, T> {
146    fn build_table(&self, b: &mut ValueBuilder) -> MyAlias {
147        let conds = self.conds.iter().map(|expr| (expr.func)(b)).collect();
148        b.get_aggr(self.select.clone(), conds)
149    }
150}
151
152/// Perform an aggregate that returns a single result for each of the current rows.
153///
154/// You can filter the rows in the aggregate based on values from the outer query.
155/// That is the only way to get a different aggregate for each outer row.
156///
157/// ```
158/// # use rust_query::aggregate;
159/// # use rust_query::private::doctest::*;
160/// # rust_query::private::doctest::get_txn(|txn| {
161/// let res = txn.query_one(aggregate(|rows| {
162///     let user = rows.join(User);
163///     rows.count_distinct(user)
164/// }));
165/// assert_eq!(res, 1, "there is one user in the database");
166/// # });
167/// ```
168pub fn aggregate<'outer, S, F, R>(f: F) -> R
169where
170    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
171{
172    let inner = Rows {
173        phantom: PhantomData,
174        ast: Default::default(),
175        _p: PhantomData,
176    };
177    let mut group = Aggregate {
178        query: inner,
179        _p: PhantomData,
180    };
181    f(&mut group)
182}