rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use sea_query::{ExprTrait, Func, SelectStatement};
8
9use crate::{
10    Expr,
11    alias::MyAlias,
12    rows::Rows,
13    value::{EqTyp, IntoExpr, MyTyp, NumTyp, Typed, ValueBuilder},
14};
15
16use super::DynTypedExpr;
17
18/// This is the argument type used for [aggregate].
19pub struct Aggregate<'outer, 'inner, S> {
20    pub(crate) query: Rows<'inner, S>,
21    _p: PhantomData<&'inner &'outer ()>,
22}
23
24impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
25    type Target = Rows<'inner, S>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.query
29    }
30}
31
32impl<S> DerefMut for Aggregate<'_, '_, S> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        &mut self.query
35    }
36}
37
38impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
39    fn select<T>(
40        &self,
41        expr: impl 'static + Fn(&mut ValueBuilder) -> sea_query::Expr,
42    ) -> Aggr<S, Option<T>> {
43        let expr = DynTypedExpr(Rc::new(expr));
44        let mut builder = self.query.ast.clone().full();
45        let (select, mut fields) = builder.build_select(true, vec![expr]);
46
47        let conds = builder.forwarded.into_iter().map(|x| x.1.1).collect();
48
49        Aggr {
50            _p2: PhantomData,
51            select: Rc::new(select),
52            field: {
53                debug_assert_eq!(fields.len(), 1);
54                fields.swap_remove(0)
55            },
56            conds,
57        }
58    }
59
60    /// Return the average value in a column, this is [None] if there are zero rows.
61    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
62        let val = val.into_expr().inner;
63        Expr::new(self.select(move |b| Func::avg(val.build_expr(b)).into()))
64    }
65
66    /// Return the maximum value in a column, this is [None] if there are zero rows.
67    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
68    where
69        T: NumTyp,
70    {
71        let val = val.into_expr().inner;
72        Expr::new(self.select(move |b| Func::max(val.build_expr(b)).into()))
73    }
74
75    /// Return the minimum value in a column, this is [None] if there are zero rows.
76    pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
77    where
78        T: NumTyp,
79    {
80        let val = val.into_expr().inner;
81        Expr::new(self.select(move |b| Func::min(val.build_expr(b)).into()))
82    }
83
84    /// Return the sum of a column.
85    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
86    where
87        T: NumTyp,
88    {
89        let val = val.into_expr().inner;
90        let val = self.select::<T>(move |b| Func::sum(val.build_expr(b)).into());
91
92        Expr::adhoc(move |b| {
93            sea_query::Expr::expr(val.build_expr(b))
94                .if_null(sea_query::Expr::Constant(T::ZERO.into_sea_value()))
95        })
96    }
97
98    /// Return the number of distinct values in a column.
99    pub fn count_distinct<T: EqTyp + 'static>(
100        &self,
101        val: impl IntoExpr<'inner, S, Typ = T>,
102    ) -> Expr<'outer, S, i64> {
103        let val = val.into_expr().inner;
104        let val = self.select::<i64>(move |b| Func::count_distinct(val.build_expr(b)).into());
105        Expr::adhoc(move |b| {
106            sea_query::Expr::expr(val.build_expr(b))
107                .if_null(sea_query::Expr::Constant(0i64.into_sea_value()))
108        })
109    }
110
111    /// Return whether there are any rows.
112    pub fn exists(&self) -> Expr<'outer, S, bool> {
113        let val = self.select::<i64>(|_| sea_query::Expr::Constant(1.into_sea_value()));
114        Expr::adhoc(move |b| sea_query::Expr::expr(val.build_expr(b)).is_not_null())
115    }
116}
117
118pub struct Aggr<S, T> {
119    pub(crate) _p2: PhantomData<(S, T)>,
120    pub(crate) select: Rc<SelectStatement>,
121    pub(crate) conds: Vec<DynTypedExpr>,
122    pub(crate) field: MyAlias,
123}
124
125impl<S, T> Clone for Aggr<S, T> {
126    fn clone(&self) -> Self {
127        Self {
128            _p2: PhantomData,
129            select: self.select.clone(),
130            conds: self.conds.clone(),
131            field: self.field,
132        }
133    }
134}
135
136impl<S, T: MyTyp> Typed for Aggr<S, T> {
137    type Typ = T;
138    fn build_expr(&self, b: &mut ValueBuilder) -> sea_query::Expr {
139        sea_query::Expr::col((self.build_table(b), self.field)).into()
140    }
141}
142
143impl<S, T> Aggr<S, T> {
144    fn build_table(&self, b: &mut ValueBuilder) -> MyAlias {
145        let conds = self.conds.iter().map(|expr| (expr.0)(b)).collect();
146        b.get_aggr(self.select.clone(), conds)
147    }
148}
149
150/// Perform an aggregate that returns a single result for each of the current rows.
151///
152/// You can filter the rows in the aggregate based on values from the outer query.
153/// That is the only way to get a different aggregate for each outer row.
154///
155/// ```
156/// # use rust_query::aggregate;
157/// # use rust_query::private::doctest::*;
158/// # rust_query::private::doctest::get_txn(|txn| {
159/// let res = txn.query_one(aggregate(|rows| {
160///     let user = rows.join(User);
161///     rows.count_distinct(user)
162/// }));
163/// assert_eq!(res, 1, "there is one user in the database");
164/// # });
165/// ```
166pub fn aggregate<'outer, S, F, R>(f: F) -> R
167where
168    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
169{
170    let inner = Rows {
171        phantom: PhantomData,
172        ast: Default::default(),
173        _p: PhantomData,
174    };
175    let mut group = Aggregate {
176        query: inner,
177        _p: PhantomData,
178    };
179    f(&mut group)
180}