Skip to main content

umbral_core/orm/
aggregate.rs

//! Aggregate functions for `QuerySet::aggregate` / `annotate`.
//!
//! `Aggregate` is a closed enum covering the SQL-standard aggregates the
//! framework supports today — COUNT, SUM, AVG, MAX, MIN. Each variant
//! knows the column it operates on (or `*` for COUNT) and renders to a
//! sea-query `SimpleExpr` at terminal time.
//!
//! ```rust,ignore
//! use umbral::orm::Aggregate;
//!
//! // Single-row aggregate: "how many published posts, and what's the
//! // total view count?"
//! let summary = Post::objects()
//!     .filter(post::PUBLISHED.eq(true))
//!     .aggregate(&[
//!         ("count", Aggregate::count()),
//!         ("views", Aggregate::sum("view_count")),
//!     ])
//!     .await?;
//!
//! // Grouped: "post count per author."
//! let by_author = Post::objects()
//!     .annotate(&["author_id"], &[("count", Aggregate::count())])
//!     .await?;
//! ```
//!
//! StdDev / Variance / window-function aggregates are deferred. Add a
//! new variant when a real consumer surfaces the need.
29
30use sea_query::{Alias, Expr, Func, SimpleExpr};
31
/// A SQL aggregate function over a single column (or `*` for COUNT).
///
/// Built via the named constructors on this type ([`Aggregate::count`],
/// [`Aggregate::count_col`], [`Aggregate::sum`], [`Aggregate::avg`],
/// [`Aggregate::max`], [`Aggregate::min`]) and passed to
/// [`crate::orm::QuerySet::aggregate`] or
/// [`crate::orm::QuerySet::annotate`] paired with an output name.
///
/// Two aggregates compare equal when they are the same function over the
/// same column name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Aggregate {
    /// `COUNT(*)` when the payload is `None`; `COUNT(col)` when it is
    /// `Some(col)`, which skips rows where `col` is NULL.
    Count(Option<String>),
    /// `SUM(col)` — NULL (not 0) when no rows match.
    Sum(String),
    /// `AVG(col)` — NULL when no rows match.
    ///
    /// The SQL result type is backend-dependent: SQLite always yields a
    /// floating-point value, but PostgreSQL yields `numeric` for integer
    /// and `numeric` inputs (and `double precision` only for float inputs).
    Avg(String),
    /// `MAX(col)` — same type as the column; NULL when no rows match.
    Max(String),
    /// `MIN(col)` — same type as the column; NULL when no rows match.
    Min(String),
}
51
52impl Aggregate {
53    /// `COUNT(*)` — every row, including those with NULL columns.
54    pub fn count() -> Self {
55        Aggregate::Count(None)
56    }
57
58    /// `COUNT(col)` — skips rows where `col` is NULL.
59    pub fn count_col(name: impl Into<String>) -> Self {
60        Aggregate::Count(Some(name.into()))
61    }
62
63    /// `SUM(col)`.
64    pub fn sum(name: impl Into<String>) -> Self {
65        Aggregate::Sum(name.into())
66    }
67
68    /// `AVG(col)`.
69    pub fn avg(name: impl Into<String>) -> Self {
70        Aggregate::Avg(name.into())
71    }
72
73    /// `MAX(col)`.
74    pub fn max(name: impl Into<String>) -> Self {
75        Aggregate::Max(name.into())
76    }
77
78    /// `MIN(col)`.
79    pub fn min(name: impl Into<String>) -> Self {
80        Aggregate::Min(name.into())
81    }
82
83    /// Source column for this aggregate, or `None` for `COUNT(*)`.
84    /// Used by the QuerySet terminals to validate against
85    /// `Model::FIELDS` before running any SQL.
86    pub fn source_column(&self) -> Option<&str> {
87        match self {
88            Aggregate::Count(c) => c.as_deref(),
89            Aggregate::Sum(c) | Aggregate::Avg(c) | Aggregate::Max(c) | Aggregate::Min(c) => {
90                Some(c.as_str())
91            }
92        }
93    }
94
95    /// Render to a `sea_query::SimpleExpr` for the SELECT list. Both
96    /// backends accept the same function names for the supported set.
97    pub fn to_simple_expr(&self) -> SimpleExpr {
98        match self {
99            Aggregate::Count(None) => Func::count(Expr::col(sea_query::Asterisk)).into(),
100            Aggregate::Count(Some(col)) => Func::count(Expr::col(Alias::new(col.as_str()))).into(),
101            Aggregate::Sum(col) => Func::sum(Expr::col(Alias::new(col.as_str()))).into(),
102            Aggregate::Avg(col) => Func::avg(Expr::col(Alias::new(col.as_str()))).into(),
103            Aggregate::Max(col) => Func::max(Expr::col(Alias::new(col.as_str()))).into(),
104            Aggregate::Min(col) => Func::min(Expr::col(Alias::new(col.as_str()))).into(),
105        }
106    }
107
108    /// One of `"count"`, `"sum"`, `"avg"`, `"max"`, `"min"` — used by
109    /// the terminal to dispatch row-decoding (COUNT always returns
110    /// i64, AVG always returns f64, SUM/MAX/MIN inherit the source
111    /// column's type).
112    pub fn kind(&self) -> AggregateKind {
113        match self {
114            Aggregate::Count(_) => AggregateKind::Count,
115            Aggregate::Sum(_) => AggregateKind::Sum,
116            Aggregate::Avg(_) => AggregateKind::Avg,
117            Aggregate::Max(_) => AggregateKind::Max,
118            Aggregate::Min(_) => AggregateKind::Min,
119        }
120    }
121}
122
/// Payload-free tag naming which [`Aggregate`] variant a value is.
///
/// It is `Copy` and carries no column name, so terminals can branch on the
/// aggregate's kind without borrowing or cloning any strings.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateKind {
    /// `COUNT(*)` or `COUNT(col)`.
    Count,
    /// `SUM(col)`.
    Sum,
    /// `AVG(col)`.
    Avg,
    /// `MAX(col)`.
    Max,
    /// `MIN(col)`.
    Min,
}