Skip to main content

rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use sea_query::Func;
8
9use crate::{
10    Expr, IntoExpr,
11    ast::CONST_0,
12    rows::Rows,
13    value::{AdHoc, EqTyp, NumTyp, ValueBuilder},
14};
15
16use super::DynTypedExpr;
17
18/// This is the argument type used for [aggregate].
19pub struct Aggregate<'outer, 'inner, S> {
20    pub(crate) query: Rows<'inner, S>,
21    _p: PhantomData<&'inner &'outer ()>,
22}
23
24impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
25    type Target = Rows<'inner, S>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.query
29    }
30}
31
32impl<S> DerefMut for Aggregate<'_, '_, S> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        &mut self.query
35    }
36}
37
38impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
39    /// This must be used with an aggregating expression.
40    /// otherwise there is a chance that there are multiple rows.
41    fn select<T: EqTyp>(
42        &self,
43        expr: impl 'static + Fn(&mut ValueBuilder) -> sea_query::Expr,
44    ) -> Rc<AdHoc<dyn Fn(&mut ValueBuilder) -> sea_query::Expr, Option<T>>> {
45        let expr = DynTypedExpr::new(expr);
46        let mut builder = self.query.ast.clone().full();
47        let (select, mut fields) = builder.build_select(vec![expr], Vec::new());
48
49        let conds: Vec<_> = builder.forwarded.into_iter().map(|(x, _)| x).collect();
50
51        let select = Rc::new(select);
52        let field = {
53            debug_assert_eq!(fields.len(), 1);
54            fields.swap_remove(0)
55        };
56
57        Expr::<S, _>::adhoc(move |b| {
58            sea_query::Expr::col((b.get_aggr(select.clone(), conds.clone()), field))
59        })
60        .inner
61    }
62
63    /// Return the average value in a column, this is [None] if there are zero rows.
64    ///
65    /// ```
66    /// # use rust_query::private::doctest_aggregate::*;
67    /// # get_txn(|txn| {
68    /// for x in [1, 2, 3] {
69    ///     txn.insert_ok(Val { x });
70    /// }
71    /// let (avg1, avg2) = txn.query_one(aggregate(|rows| {
72    ///     let val = rows.join(Val);
73    ///     let avg1 = rows.avg(val.x.to_f64());
74    ///     rows.filter(false); // remove all rows
75    ///     let avg2 = rows.avg(val.x.to_f64());
76    ///     (avg1, avg2)
77    /// }));
78    /// assert_eq!(avg1, Some(2.0));
79    /// assert_eq!(avg2, None);
80    /// # });
81    /// ```
82    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
83        let val = val.into_expr().inner;
84        Expr::new(self.select(move |b| Func::avg(val.build_expr(b)).into()))
85    }
86
87    /// Return the maximum value in a column, this is [None] if there are zero rows.
88    ///
89    /// ```
90    /// # use rust_query::private::doctest_aggregate::*;
91    /// # get_txn(|txn| {
92    /// for x in [-100, 10, 42] {
93    ///     txn.insert_ok(Val { x });
94    /// }
95    /// let (max1, max2) = txn.query_one(aggregate(|rows| {
96    ///     let val = rows.join(Val);
97    ///     let max1 = rows.max(&val.x);
98    ///     rows.filter(false); // remove all rows
99    ///     let max2 = rows.max(&val.x);
100    ///     (max1, max2)
101    /// }));
102    /// assert_eq!(max1, Some(42));
103    /// assert_eq!(max2, None);
104    /// # });
105    /// ```
106    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
107    where
108        T: EqTyp,
109    {
110        let val = val.into_expr().inner;
111        Expr::new(self.select(move |b| Func::max(val.build_expr(b)).into()))
112    }
113
114    /// Return the minimum value in a column, this is [None] if there are zero rows.
115    ///
116    /// ```
117    /// # use rust_query::private::doctest_aggregate::*;
118    /// # get_txn(|txn| {
119    /// for x in [-100, 10, 42] {
120    ///     txn.insert_ok(Val { x });
121    /// }
122    /// let (min1, min2) = txn.query_one(aggregate(|rows| {
123    ///     let val = rows.join(Val);
124    ///     let min1 = rows.min(&val.x);
125    ///     rows.filter(false); // remove all rows
126    ///     let min2 = rows.min(&val.x);
127    ///     (min1, min2)
128    /// }));
129    /// assert_eq!(min1, Some(-100));
130    /// assert_eq!(min2, None);
131    /// # });
132    /// ```
133    pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
134    where
135        T: EqTyp,
136    {
137        let val = val.into_expr().inner;
138        Expr::new(self.select(move |b| Func::min(val.build_expr(b)).into()))
139    }
140
141    /// Return the sum of a column.
142    ///
143    /// ```
144    /// # use rust_query::private::doctest_aggregate::*;
145    /// # get_txn(|txn| {
146    /// for x in [-100, 10, 42] {
147    ///     txn.insert_ok(Val { x });
148    /// }
149    /// let (sum1, sum2) = txn.query_one(aggregate(|rows| {
150    ///     let val = rows.join(Val);
151    ///     let sum1 = rows.sum(&val.x);
152    ///     rows.filter(false); // remove all rows
153    ///     let sum2 = rows.sum(&val.x);
154    ///     (sum1, sum2)
155    /// }));
156    /// assert_eq!(sum1, -48);
157    /// assert_eq!(sum2, 0);
158    /// # });
159    /// ```
160    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
161    where
162        T: NumTyp,
163    {
164        let val = val.into_expr().inner;
165        let val = self.select::<T>(move |b| Func::sum(val.build_expr(b)).into());
166
167        Expr::adhoc(move |b| {
168            sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(T::ZERO))
169        })
170    }
171
172    /// Return the number of distinct values in a column.
173    ///
174    /// ```
175    /// # use rust_query::private::doctest_aggregate::*;
176    /// # get_txn(|txn| {
177    /// for x in [-100, 10, 42, 10] {
178    ///     txn.insert_ok(Val { x });
179    /// }
180    /// let (count1, count2) = txn.query_one(aggregate(|rows| {
181    ///     let val = rows.join(Val);
182    ///     let count1 = rows.count_distinct(&val.x);
183    ///     rows.filter(false); // remove all rows
184    ///     let count2 = rows.count_distinct(&val.x);
185    ///     (count1, count2)
186    /// }));
187    /// assert_eq!(count1, 3);
188    /// assert_eq!(count2, 0);
189    /// # });
190    /// ```
191    pub fn count_distinct<T: EqTyp + 'static>(
192        &self,
193        val: impl IntoExpr<'inner, S, Typ = T>,
194    ) -> Expr<'outer, S, i64> {
195        let val = val.into_expr().inner;
196        let val = self.select::<i64>(move |b| Func::count_distinct(val.build_expr(b)).into());
197        Expr::adhoc(move |b| {
198            // technically the `if_null` here is only required for correlated sub queries
199            sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(0i64.into()))
200        })
201    }
202
203    /// Return whether there are any rows.
204    ///
205    /// ```
206    /// # use rust_query::private::doctest_aggregate::*;
207    /// # get_txn(|txn| {
208    /// for x in [10, 42, 10] {
209    ///     txn.insert_ok(Val { x });
210    /// }
211    /// let (e1, e2) = txn.query_one(aggregate(|rows| {
212    ///     rows.join(Val);
213    ///     let e1 = rows.exists();
214    ///     rows.filter(false); // removes all rows
215    ///     let e2 = rows.exists();
216    ///     (e1, e2)
217    /// }));
218    /// assert_eq!(e1, true);
219    /// assert_eq!(e2, false);
220    /// # });
221    /// ```
222    pub fn exists(&self) -> Expr<'outer, S, bool> {
223        let zero_expr = Expr::<_, i64>::adhoc(|_| CONST_0);
224        self.count_distinct(zero_expr.clone()).neq(zero_expr)
225    }
226}
227
228/// Perform an aggregate that returns a single result for each of the current rows.
229///
230/// One can filter the rows in the aggregate based on values from the outer query.
231/// See the documentation for [Aggregate] for more information.
232///
233/// ```
234/// # use rust_query::migration::{schema, Config};
235/// # use rust_query::{Database, aggregate};
236/// #[schema(Site)]
237/// pub mod vN {
238///     pub struct Review {
239///         #[index]
240///         pub book: rust_query::TableRow<Book>,
241///         pub rating: f64,
242///     }
243///     pub struct Book {
244///         pub name: String
245///     }
246/// }
247/// use v0::*;
248///
249/// Database::new(Config::open_in_memory()).transaction(|txn| {
250///     let books = txn.query(|rows| {
251///         let book = rows.join(Book);
252///         let rating = aggregate(|aggr| {
253///             let review = aggr.join(Review.book(&book));
254///             // books without reviews will get a rating of 0.0
255///             aggr.avg(&review.rating).unwrap_or(0.0)
256///         });
257///         // top 10 highest rated books
258///         rows.order_by()
259///             .desc(rating)
260///             .into_iter(book)
261///             .take(10)
262///     });
263/// });
264/// ```
265pub fn aggregate<'outer, S, F, R>(f: F) -> R
266where
267    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
268{
269    let inner = Rows {
270        phantom: PhantomData,
271        ast: Default::default(),
272        _p: PhantomData,
273    };
274    let mut group = Aggregate {
275        query: inner,
276        _p: PhantomData,
277    };
278    f(&mut group)
279}