rust_query/value/aggregate.rs
1use std::{
2 marker::PhantomData,
3 ops::{Deref, DerefMut},
4 rc::Rc,
5};
6
7use sea_query::Func;
8
9use crate::{
10 Expr, IntoExpr,
11 ast::CONST_0,
12 rows::Rows,
13 value::{AdHoc, EqTyp, NumTyp, ValueBuilder},
14};
15
16use super::DynTypedExpr;
17
/// The argument type passed to the closure of [aggregate].
///
/// It dereferences to the inner [Rows], so rows can be joined and filtered like in any
/// other query. Methods such as [Aggregate::avg], [Aggregate::max] and
/// [Aggregate::count_distinct] reduce the current inner rows to a single value. That value
/// is an expression in the outer query (note the `'outer` lifetime of their results).
pub struct Aggregate<'outer, 'inner, S> {
    pub(crate) query: Rows<'inner, S>,
    // `&'inner &'outer ()` is only well-formed when `'outer: 'inner`; the marker carries
    // both lifetimes without storing any data.
    _p: PhantomData<&'inner &'outer ()>,
}
23
24impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
25 type Target = Rows<'inner, S>;
26
27 fn deref(&self) -> &Self::Target {
28 &self.query
29 }
30}
31
32impl<S> DerefMut for Aggregate<'_, '_, S> {
33 fn deref_mut(&mut self) -> &mut Self::Target {
34 &mut self.query
35 }
36}
37
38impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
39 /// This must be used with an aggregating expression.
40 /// otherwise there is a chance that there are multiple rows.
41 fn select<T: EqTyp>(
42 &self,
43 expr: impl 'static + Fn(&mut ValueBuilder) -> sea_query::Expr,
44 ) -> Rc<AdHoc<dyn Fn(&mut ValueBuilder) -> sea_query::Expr, Option<T>>> {
45 let expr = DynTypedExpr::new(expr);
46 let mut builder = self.query.ast.clone().full();
47 let (select, mut fields) = builder.build_select(vec![expr], Vec::new());
48
49 let conds: Vec<_> = builder.forwarded.into_iter().map(|(x, _)| x).collect();
50
51 let select = Rc::new(select);
52 let field = {
53 debug_assert_eq!(fields.len(), 1);
54 fields.swap_remove(0)
55 };
56
57 Expr::<S, _>::adhoc(move |b| {
58 sea_query::Expr::col((b.get_aggr(select.clone(), conds.clone()), field))
59 })
60 .inner
61 }
62
63 /// Return the average value in a column, this is [None] if there are zero rows.
64 ///
65 /// ```
66 /// # use rust_query::private::doctest_aggregate::*;
67 /// # get_txn(|txn| {
68 /// for x in [1, 2, 3] {
69 /// txn.insert_ok(Val { x });
70 /// }
71 /// let (avg1, avg2) = txn.query_one(aggregate(|rows| {
72 /// let val = rows.join(Val);
73 /// let avg1 = rows.avg(val.x.to_f64());
74 /// rows.filter(false); // remove all rows
75 /// let avg2 = rows.avg(val.x.to_f64());
76 /// (avg1, avg2)
77 /// }));
78 /// assert_eq!(avg1, Some(2.0));
79 /// assert_eq!(avg2, None);
80 /// # });
81 /// ```
82 pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
83 let val = val.into_expr().inner;
84 Expr::new(self.select(move |b| Func::avg(val.build_expr(b)).into()))
85 }
86
87 /// Return the maximum value in a column, this is [None] if there are zero rows.
88 ///
89 /// ```
90 /// # use rust_query::private::doctest_aggregate::*;
91 /// # get_txn(|txn| {
92 /// for x in [-100, 10, 42] {
93 /// txn.insert_ok(Val { x });
94 /// }
95 /// let (max1, max2) = txn.query_one(aggregate(|rows| {
96 /// let val = rows.join(Val);
97 /// let max1 = rows.max(&val.x);
98 /// rows.filter(false); // remove all rows
99 /// let max2 = rows.max(&val.x);
100 /// (max1, max2)
101 /// }));
102 /// assert_eq!(max1, Some(42));
103 /// assert_eq!(max2, None);
104 /// # });
105 /// ```
106 pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
107 where
108 T: EqTyp,
109 {
110 let val = val.into_expr().inner;
111 Expr::new(self.select(move |b| Func::max(val.build_expr(b)).into()))
112 }
113
114 /// Return the minimum value in a column, this is [None] if there are zero rows.
115 ///
116 /// ```
117 /// # use rust_query::private::doctest_aggregate::*;
118 /// # get_txn(|txn| {
119 /// for x in [-100, 10, 42] {
120 /// txn.insert_ok(Val { x });
121 /// }
122 /// let (min1, min2) = txn.query_one(aggregate(|rows| {
123 /// let val = rows.join(Val);
124 /// let min1 = rows.min(&val.x);
125 /// rows.filter(false); // remove all rows
126 /// let min2 = rows.min(&val.x);
127 /// (min1, min2)
128 /// }));
129 /// assert_eq!(min1, Some(-100));
130 /// assert_eq!(min2, None);
131 /// # });
132 /// ```
133 pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
134 where
135 T: EqTyp,
136 {
137 let val = val.into_expr().inner;
138 Expr::new(self.select(move |b| Func::min(val.build_expr(b)).into()))
139 }
140
141 /// Return the sum of a column.
142 ///
143 /// ```
144 /// # use rust_query::private::doctest_aggregate::*;
145 /// # get_txn(|txn| {
146 /// for x in [-100, 10, 42] {
147 /// txn.insert_ok(Val { x });
148 /// }
149 /// let (sum1, sum2) = txn.query_one(aggregate(|rows| {
150 /// let val = rows.join(Val);
151 /// let sum1 = rows.sum(&val.x);
152 /// rows.filter(false); // remove all rows
153 /// let sum2 = rows.sum(&val.x);
154 /// (sum1, sum2)
155 /// }));
156 /// assert_eq!(sum1, -48);
157 /// assert_eq!(sum2, 0);
158 /// # });
159 /// ```
160 pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
161 where
162 T: NumTyp,
163 {
164 let val = val.into_expr().inner;
165 let val = self.select::<T>(move |b| Func::sum(val.build_expr(b)).into());
166
167 Expr::adhoc(move |b| {
168 sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(T::ZERO))
169 })
170 }
171
172 /// Return the number of distinct values in a column.
173 ///
174 /// ```
175 /// # use rust_query::private::doctest_aggregate::*;
176 /// # get_txn(|txn| {
177 /// for x in [-100, 10, 42, 10] {
178 /// txn.insert_ok(Val { x });
179 /// }
180 /// let (count1, count2) = txn.query_one(aggregate(|rows| {
181 /// let val = rows.join(Val);
182 /// let count1 = rows.count_distinct(&val.x);
183 /// rows.filter(false); // remove all rows
184 /// let count2 = rows.count_distinct(&val.x);
185 /// (count1, count2)
186 /// }));
187 /// assert_eq!(count1, 3);
188 /// assert_eq!(count2, 0);
189 /// # });
190 /// ```
191 pub fn count_distinct<T: EqTyp + 'static>(
192 &self,
193 val: impl IntoExpr<'inner, S, Typ = T>,
194 ) -> Expr<'outer, S, i64> {
195 let val = val.into_expr().inner;
196 let val = self.select::<i64>(move |b| Func::count_distinct(val.build_expr(b)).into());
197 Expr::adhoc(move |b| {
198 // technically the `if_null` here is only required for correlated sub queries
199 sea_query::Expr::expr(val.build_expr(b)).if_null(sea_query::Expr::Constant(0i64.into()))
200 })
201 }
202
203 /// Return whether there are any rows.
204 ///
205 /// ```
206 /// # use rust_query::private::doctest_aggregate::*;
207 /// # get_txn(|txn| {
208 /// for x in [10, 42, 10] {
209 /// txn.insert_ok(Val { x });
210 /// }
211 /// let (e1, e2) = txn.query_one(aggregate(|rows| {
212 /// rows.join(Val);
213 /// let e1 = rows.exists();
214 /// rows.filter(false); // removes all rows
215 /// let e2 = rows.exists();
216 /// (e1, e2)
217 /// }));
218 /// assert_eq!(e1, true);
219 /// assert_eq!(e2, false);
220 /// # });
221 /// ```
222 pub fn exists(&self) -> Expr<'outer, S, bool> {
223 let zero_expr = Expr::<_, i64>::adhoc(|_| CONST_0);
224 self.count_distinct(zero_expr.clone()).neq(zero_expr)
225 }
226}
227
228/// Perform an aggregate that returns a single result for each of the current rows.
229///
230/// One can filter the rows in the aggregate based on values from the outer query.
231/// See the documentation for [Aggregate] for more information.
232///
233/// ```
234/// # use rust_query::migration::{schema, Config};
235/// # use rust_query::{Database, aggregate};
236/// #[schema(Site)]
237/// pub mod vN {
238/// pub struct Review {
239/// #[index]
240/// pub book: rust_query::TableRow<Book>,
241/// pub rating: f64,
242/// }
243/// pub struct Book {
244/// pub name: String
245/// }
246/// }
247/// use v0::*;
248///
249/// Database::new(Config::open_in_memory()).transaction(|txn| {
250/// let books = txn.query(|rows| {
251/// let book = rows.join(Book);
252/// let rating = aggregate(|aggr| {
253/// let review = aggr.join(Review.book(&book));
254/// // books without reviews will get a rating of 0.0
255/// aggr.avg(&review.rating).unwrap_or(0.0)
256/// });
257/// // top 10 highest rated books
258/// rows.order_by()
259/// .desc(rating)
260/// .into_iter(book)
261/// .take(10)
262/// });
263/// });
264/// ```
265pub fn aggregate<'outer, S, F, R>(f: F) -> R
266where
267 F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
268{
269 let inner = Rows {
270 phantom: PhantomData,
271 ast: Default::default(),
272 _p: PhantomData,
273 };
274 let mut group = Aggregate {
275 query: inner,
276 _p: PhantomData,
277 };
278 f(&mut group)
279}