rust_query/value/aggregate.rs
1use std::{
2 marker::PhantomData,
3 ops::{Deref, DerefMut},
4 rc::Rc,
5};
6
7use crate::{
8 Expr, IntoExpr, lower,
9 rows::Rows,
10 value::{EqTyp, NumTyp},
11};
12
13/// This is the argument type used for [aggregate].
14pub struct Aggregate<'outer, 'inner, S> {
15 pub(crate) query: Rows<'inner, S>,
16 _p: PhantomData<&'inner &'outer ()>,
17}
18
19impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
20 type Target = Rows<'inner, S>;
21
22 fn deref(&self) -> &Self::Target {
23 &self.query
24 }
25}
26
27impl<S> DerefMut for Aggregate<'_, '_, S> {
28 fn deref_mut(&mut self) -> &mut Self::Target {
29 &mut self.query
30 }
31}
32
33impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
34 /// This must be used with an aggregating expression.
35 /// otherwise there is a chance that there are multiple rows.
36 fn select_func(&self, agg_func: &'static str, val: Rc<lower::Expr>) -> Rc<lower::Expr> {
37 let expr = Rc::new(lower::Expr::Func(agg_func, Box::new([val])));
38 Rc::new(lower::Expr::AggrIndex(self.ast.clone(), expr))
39 }
40
41 /// Return the average value in a column, this is [None] if there are zero rows.
42 ///
43 /// ```
44 /// # use rust_query::private::doctest_aggregate::*;
45 /// # get_txn(|txn| {
46 /// for x in [1, 2, 3] {
47 /// txn.insert_ok(Val { x });
48 /// }
49 /// let (avg1, avg2) = txn.query_one(aggregate(|rows| {
50 /// let val = rows.join(Val);
51 /// let avg1 = rows.avg(val.x.to_f64());
52 /// rows.filter(false); // remove all rows
53 /// let avg2 = rows.avg(val.x.to_f64());
54 /// (avg1, avg2)
55 /// }));
56 /// assert_eq!(avg1, Some(2.0));
57 /// assert_eq!(avg2, None);
58 /// # });
59 /// ```
60 pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
61 let val = val.into_expr().inner;
62 Expr::new(self.select_func("avg", val))
63 }
64
65 /// Return the maximum value in a column, this is [None] if there are zero rows.
66 ///
67 /// ```
68 /// # use rust_query::private::doctest_aggregate::*;
69 /// # get_txn(|txn| {
70 /// for x in [-100, 10, 42] {
71 /// txn.insert_ok(Val { x });
72 /// }
73 /// let (max1, max2) = txn.query_one(aggregate(|rows| {
74 /// let val = rows.join(Val);
75 /// let max1 = rows.max(&val.x);
76 /// rows.filter(false); // remove all rows
77 /// let max2 = rows.max(&val.x);
78 /// (max1, max2)
79 /// }));
80 /// assert_eq!(max1, Some(42));
81 /// assert_eq!(max2, None);
82 /// # });
83 /// ```
84 pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
85 where
86 T: EqTyp,
87 {
88 let val = val.into_expr().inner;
89 Expr::new(self.select_func("max", val))
90 }
91
92 /// Return the minimum value in a column, this is [None] if there are zero rows.
93 ///
94 /// ```
95 /// # use rust_query::private::doctest_aggregate::*;
96 /// # get_txn(|txn| {
97 /// for x in [-100, 10, 42] {
98 /// txn.insert_ok(Val { x });
99 /// }
100 /// let (min1, min2) = txn.query_one(aggregate(|rows| {
101 /// let val = rows.join(Val);
102 /// let min1 = rows.min(&val.x);
103 /// rows.filter(false); // remove all rows
104 /// let min2 = rows.min(&val.x);
105 /// (min1, min2)
106 /// }));
107 /// assert_eq!(min1, Some(-100));
108 /// assert_eq!(min2, None);
109 /// # });
110 /// ```
111 pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
112 where
113 T: EqTyp,
114 {
115 let val = val.into_expr().inner;
116 Expr::new(self.select_func("min", val))
117 }
118
119 /// Return the sum of a column.
120 ///
121 /// ```
122 /// # use rust_query::private::doctest_aggregate::*;
123 /// # get_txn(|txn| {
124 /// for x in [-100, 10, 42] {
125 /// txn.insert_ok(Val { x });
126 /// }
127 /// let (sum1, sum2) = txn.query_one(aggregate(|rows| {
128 /// let val = rows.join(Val);
129 /// let sum1 = rows.sum(&val.x);
130 /// rows.filter(false); // remove all rows
131 /// let sum2 = rows.sum(&val.x);
132 /// (sum1, sum2)
133 /// }));
134 /// assert_eq!(sum1, -48);
135 /// assert_eq!(sum2, 0);
136 /// # });
137 /// ```
138 pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
139 where
140 T: NumTyp,
141 {
142 let val = val.into_expr().inner;
143 let val = self.select_func("sum", val);
144
145 Expr::adhoc(lower::Expr::Func(
146 "IFNULL",
147 Box::new([val, Rc::new(lower::Expr::Constant(T::ZERO))]),
148 ))
149 }
150
151 /// Return the number of distinct values in a column.
152 ///
153 /// ```
154 /// # use rust_query::private::doctest_aggregate::*;
155 /// # get_txn(|txn| {
156 /// for x in [-100, 10, 42, 10] {
157 /// txn.insert_ok(Val { x });
158 /// }
159 /// let (count1, count2) = txn.query_one(aggregate(|rows| {
160 /// let val = rows.join(Val);
161 /// let count1 = rows.count_distinct(&val.x);
162 /// rows.filter(false); // remove all rows
163 /// let count2 = rows.count_distinct(&val.x);
164 /// (count1, count2)
165 /// }));
166 /// assert_eq!(count1, 3);
167 /// assert_eq!(count2, 0);
168 /// # });
169 /// ```
170 pub fn count_distinct<T: EqTyp + 'static>(
171 &self,
172 val: impl IntoExpr<'inner, S, Typ = T>,
173 ) -> Expr<'outer, S, i64> {
174 let val = val.into_expr().inner;
175 let val = self.select_func("COUNT", Rc::new(lower::Expr::Prefix("DISTINCT ", val)));
176 // technically the `if_null` here is only required for correlated sub queries
177 Expr::adhoc(lower::Expr::Func(
178 "IFNULL",
179 Box::new([val, Rc::new(lower::Expr::Constant(i64::ZERO))]),
180 ))
181 }
182
183 /// Return whether there are any rows.
184 ///
185 /// ```
186 /// # use rust_query::private::doctest_aggregate::*;
187 /// # get_txn(|txn| {
188 /// for x in [10, 42, 10] {
189 /// txn.insert_ok(Val { x });
190 /// }
191 /// let (e1, e2) = txn.query_one(aggregate(|rows| {
192 /// rows.join(Val);
193 /// let e1 = rows.exists();
194 /// rows.filter(false); // removes all rows
195 /// let e2 = rows.exists();
196 /// (e1, e2)
197 /// }));
198 /// assert_eq!(e1, true);
199 /// assert_eq!(e2, false);
200 /// # });
201 /// ```
202 pub fn exists(&self) -> Expr<'outer, S, bool> {
203 let zero_expr = Expr::<_, i64>::adhoc(lower::CONST_0);
204 self.count_distinct(zero_expr.clone()).neq(zero_expr)
205 }
206}
207
208/// Perform an aggregate that returns a single result for each of the current rows.
209///
210/// One can filter the rows in the aggregate based on values from the outer query.
211/// See the documentation for [Aggregate] for more information.
212///
213/// ```
214/// # use rust_query::migration::{schema, Config};
215/// # use rust_query::{Database, aggregate};
216/// #[schema(Site)]
217/// pub mod vN {
218/// pub struct Review {
219/// #[index]
220/// pub book: rust_query::TableRow<Book>,
221/// pub rating: f64,
222/// }
223/// pub struct Book {
224/// pub name: String
225/// }
226/// }
227/// use v0::*;
228///
229/// Database::new(Config::open_in_memory()).transaction(|txn| {
230/// let books = txn.query(|rows| {
231/// let book = rows.join(Book);
232/// let rating = aggregate(|aggr| {
233/// let review = aggr.join(Review.book(&book));
234/// // books without reviews will get a rating of 0.0
235/// aggr.avg(&review.rating).unwrap_or(0.0)
236/// });
237/// // top 10 highest rated books
238/// rows.order_by()
239/// .desc(rating)
240/// .into_iter(book)
241/// .take(10)
242/// });
243/// });
244/// ```
245pub fn aggregate<'outer, S, F, R>(f: F) -> R
246where
247 F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
248{
249 let inner = Rows {
250 phantom: PhantomData,
251 ast: Default::default(),
252 _p: PhantomData,
253 };
254 let mut group = Aggregate {
255 query: inner,
256 _p: PhantomData,
257 };
258 f(&mut group)
259}