sql_cli/sql/aggregates/
mod.rs1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13#[derive(Debug, Clone)]
15pub enum AggregateState {
16 Count(i64),
17 Sum(SumState),
18 Avg(AvgState),
19 MinMax(MinMaxState),
20 CollectList(Vec<DataValue>),
21 Analytics(analytics::AnalyticsState),
22}
23
/// Running total for a summation that can be tracked either as an integer
/// or as a floating-point number.
#[derive(Clone, Debug)]
pub struct SumState {
    /// Integer running total; populated while the sum is tracked as an integer.
    pub int_sum: Option<i64>,
    /// Floating-point running total; populated while the sum is tracked as a float.
    pub float_sum: Option<f64>,
    /// `true` once at least one non-null numeric value has been added.
    pub has_values: bool,
}
31
32impl SumState {
33 pub fn new() -> Self {
34 Self {
35 int_sum: None,
36 float_sum: None,
37 has_values: false,
38 }
39 }
40
41 pub fn add(&mut self, value: &DataValue) -> Result<()> {
42 match value {
43 DataValue::Null => Ok(()), DataValue::Integer(n) => {
45 self.has_values = true;
46 if let Some(ref mut sum) = self.int_sum {
47 *sum = sum.saturating_add(*n);
48 } else if let Some(ref mut fsum) = self.float_sum {
49 *fsum += *n as f64;
50 } else {
51 self.int_sum = Some(*n);
52 }
53 Ok(())
54 }
55 DataValue::Float(f) => {
56 self.has_values = true;
57 if let Some(isum) = self.int_sum.take() {
59 self.float_sum = Some(isum as f64 + f);
60 } else if let Some(ref mut fsum) = self.float_sum {
61 *fsum += f;
62 } else {
63 self.float_sum = Some(*f);
64 }
65 Ok(())
66 }
67 _ => Err(anyhow!("Cannot sum non-numeric value")),
68 }
69 }
70
71 pub fn finalize(self) -> DataValue {
72 if !self.has_values {
73 return DataValue::Null;
74 }
75
76 if let Some(fsum) = self.float_sum {
77 DataValue::Float(fsum)
78 } else if let Some(isum) = self.int_sum {
79 DataValue::Integer(isum)
80 } else {
81 DataValue::Null
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
88pub struct AvgState {
89 pub sum: SumState,
90 pub count: i64,
91}
92
93impl AvgState {
94 pub fn new() -> Self {
95 Self {
96 sum: SumState::new(),
97 count: 0,
98 }
99 }
100
101 pub fn add(&mut self, value: &DataValue) -> Result<()> {
102 if !matches!(value, DataValue::Null) {
103 self.sum.add(value)?;
104 self.count += 1;
105 }
106 Ok(())
107 }
108
109 pub fn finalize(self) -> DataValue {
110 if self.count == 0 {
111 return DataValue::Null;
112 }
113
114 let sum = self.sum.finalize();
115 match sum {
116 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
117 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
118 _ => DataValue::Null,
119 }
120 }
121}
122
123#[derive(Debug, Clone)]
125pub struct MinMaxState {
126 pub is_min: bool,
127 pub current: Option<DataValue>,
128}
129
130impl MinMaxState {
131 pub fn new(is_min: bool) -> Self {
132 Self {
133 is_min,
134 current: None,
135 }
136 }
137
138 pub fn add(&mut self, value: &DataValue) -> Result<()> {
139 if matches!(value, DataValue::Null) {
140 return Ok(());
141 }
142
143 if let Some(ref current) = self.current {
144 let should_update = if self.is_min {
145 value < current
146 } else {
147 value > current
148 };
149
150 if should_update {
151 self.current = Some(value.clone());
152 }
153 } else {
154 self.current = Some(value.clone());
155 }
156
157 Ok(())
158 }
159
160 pub fn finalize(self) -> DataValue {
161 self.current.unwrap_or(DataValue::Null)
162 }
163}
164
165pub trait AggregateFunction: Send + Sync {
167 fn name(&self) -> &str;
169
170 fn init(&self) -> AggregateState;
172
173 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
175
176 fn finalize(&self, state: AggregateState) -> DataValue;
178
179 fn requires_numeric(&self) -> bool {
181 false
182 }
183}
184
185pub struct AggregateRegistry {
187 functions: Vec<Box<dyn AggregateFunction>>,
188}
189
190impl AggregateRegistry {
191 pub fn new() -> Self {
192 use analytics::*;
193 use functions::*;
194
195 let functions: Vec<Box<dyn AggregateFunction>> = vec![
196 Box::new(CountFunction),
197 Box::new(CountStarFunction),
198 Box::new(SumFunction),
199 Box::new(AvgFunction),
200 Box::new(MinFunction),
201 Box::new(MaxFunction),
202 Box::new(DeltasFunction),
204 Box::new(SumsFunction),
205 Box::new(MavgFunction),
206 Box::new(PctChangeFunction),
207 Box::new(RankFunction),
208 Box::new(CumMaxFunction),
209 Box::new(CumMinFunction),
210 ];
211
212 Self { functions }
213 }
214
215 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
216 let name_upper = name.to_uppercase();
217 self.functions
218 .iter()
219 .find(|f| f.name() == name_upper)
220 .map(|f| f.as_ref())
221 }
222
223 pub fn is_aggregate(&self, name: &str) -> bool {
224 self.get(name).is_some() || name.to_uppercase() == "COUNT" }
226}
227
228impl Default for AggregateRegistry {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
236 use crate::recursive_parser::SqlExpression;
237
238 match expr {
239 SqlExpression::FunctionCall { name, args } => {
240 let registry = AggregateRegistry::new();
241 if registry.is_aggregate(name) {
242 return true;
243 }
244 args.iter().any(contains_aggregate)
246 }
247 SqlExpression::BinaryOp { left, right, .. } => {
248 contains_aggregate(left) || contains_aggregate(right)
249 }
250 SqlExpression::Not { expr } => contains_aggregate(expr),
251 SqlExpression::CaseExpression {
252 when_branches,
253 else_branch,
254 } => {
255 when_branches.iter().any(|branch| {
256 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
257 }) || else_branch
258 .as_ref()
259 .map_or(false, |e| contains_aggregate(e))
260 }
261 _ => false,
262 }
263}