Skip to main content

rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use sea_query::{Asterisk, ExprTrait, Func};
8
9use crate::{
10    Expr, IntoExpr,
11    rows::Rows,
12    value::{AdHoc, EqTyp, NumTyp, ValueBuilder},
13};
14
15use super::DynTypedExpr;
16
17/// This is the argument type used for [aggregate].
18pub struct Aggregate<'outer, 'inner, S> {
19    pub(crate) query: Rows<'inner, S>,
20    _p: PhantomData<&'inner &'outer ()>,
21}
22
23impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
24    type Target = Rows<'inner, S>;
25
26    fn deref(&self) -> &Self::Target {
27        &self.query
28    }
29}
30
31impl<S> DerefMut for Aggregate<'_, '_, S> {
32    fn deref_mut(&mut self) -> &mut Self::Target {
33        &mut self.query
34    }
35}
36
37impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
38    /// This must be used with an aggregating expression.
39    /// otherwise there is a change that there are multiple rows.
40    fn select<T: EqTyp>(
41        &self,
42        expr: impl 'static + Fn(&mut ValueBuilder) -> sea_query::Expr,
43    ) -> Rc<AdHoc<dyn Fn(&mut ValueBuilder) -> sea_query::Expr, Option<T>>> {
44        let expr = DynTypedExpr::new(expr);
45        let mut builder = self.query.ast.clone().full();
46        let (select, mut fields) = builder.build_select(vec![expr], Vec::new());
47
48        let conds: Vec<_> = builder.forwarded.into_iter().map(|(x, _)| x).collect();
49
50        let select = Rc::new(select);
51        let field = {
52            debug_assert_eq!(fields.len(), 1);
53            fields.swap_remove(0)
54        };
55
56        Expr::<S, _>::adhoc(move |b| {
57            sea_query::Expr::col((b.get_aggr(select.clone(), conds.clone()), field))
58        })
59        .inner
60    }
61
62    /// Return the average value in a column, this is [None] if there are zero rows.
63    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
64        let val = val.into_expr().inner;
65        Expr::new(self.select(move |b| Func::avg(val.build_expr(b)).into()))
66    }
67
68    /// Return the maximum value in a column, this is [None] if there are zero rows.
69    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
70    where
71        T: EqTyp,
72    {
73        let val = val.into_expr().inner;
74        Expr::new(self.select(move |b| Func::max(val.build_expr(b)).into()))
75    }
76
77    /// Return the minimum value in a column, this is [None] if there are zero rows.
78    pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
79    where
80        T: EqTyp,
81    {
82        let val = val.into_expr().inner;
83        Expr::new(self.select(move |b| Func::min(val.build_expr(b)).into()))
84    }
85
86    /// Return the sum of a column.
87    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
88    where
89        T: NumTyp,
90    {
91        let val = val.into_expr().inner;
92        let val = self.select::<T>(move |b| Func::sum(val.build_expr(b)).into());
93
94        Expr::adhoc(move |b| {
95            sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(T::ZERO))
96        })
97    }
98
99    /// Return the number of distinct values in a column.
100    pub fn count_distinct<T: EqTyp + 'static>(
101        &self,
102        val: impl IntoExpr<'inner, S, Typ = T>,
103    ) -> Expr<'outer, S, i64> {
104        let val = val.into_expr().inner;
105        let val = self.select::<i64>(move |b| Func::count_distinct(val.build_expr(b)).into());
106        Expr::adhoc(move |b| {
107            sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(0i64.into()))
108        })
109    }
110
111    /// Return whether there are any rows.
112    pub fn exists(&self) -> Expr<'outer, S, bool> {
113        let val = self.select::<i64>(|_| Func::count(sea_query::Expr::col(Asterisk)).into());
114        Expr::adhoc(move |b| sea_query::Expr::expr(val.build_expr(b)).is_not_null())
115    }
116}
117
118/// Perform an aggregate that returns a single result for each of the current rows.
119///
120/// You can filter the rows in the aggregate based on values from the outer query.
121/// That is the only way to get a different aggregate for each outer row.
122///
123/// ```
124/// # use rust_query::aggregate;
125/// # use rust_query::private::doctest::*;
126/// # rust_query::private::doctest::get_txn(|txn| {
127/// let res = txn.query_one(aggregate(|rows| {
128///     let user = rows.join(User);
129///     rows.count_distinct(user)
130/// }));
131/// assert_eq!(res, 1, "there is one user in the database");
132/// # });
133/// ```
134pub fn aggregate<'outer, S, F, R>(f: F) -> R
135where
136    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
137{
138    let inner = Rows {
139        phantom: PhantomData,
140        ast: Default::default(),
141        _p: PhantomData,
142    };
143    let mut group = Aggregate {
144        query: inner,
145        _p: PhantomData,
146    };
147    f(&mut group)
148}