1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
/// Accumulator state carried by an aggregate while values are processed.
///
/// Each variant wraps the running state for one family of aggregates; the
/// state types are defined in this file or in the `analytics` submodule.
#[derive(Debug, Clone)]
pub enum AggregateState {
    /// A plain integer counter.
    Count(i64),
    /// Running sum state (see [`SumState`]).
    Sum(SumState),
    /// Running sum plus count, used to produce an average (see [`AvgState`]).
    Avg(AvgState),
    /// Current minimum or maximum value (see [`MinMaxState`]).
    MinMax(MinMaxState),
    /// Running totals for variance / standard deviation (see [`VarianceState`]).
    Variance(VarianceState),
    /// A list of collected values.
    CollectList(Vec<DataValue>),
    /// State owned by the analytics functions (see `analytics::AnalyticsState`).
    Analytics(analytics::AnalyticsState),
}
24
25#[derive(Debug, Clone)]
27pub struct SumState {
28 pub int_sum: Option<i64>,
29 pub float_sum: Option<f64>,
30 pub has_values: bool,
31}
32
33impl Default for SumState {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl SumState {
40 #[must_use]
41 pub fn new() -> Self {
42 Self {
43 int_sum: None,
44 float_sum: None,
45 has_values: false,
46 }
47 }
48
49 pub fn add(&mut self, value: &DataValue) -> Result<()> {
50 match value {
51 DataValue::Null => Ok(()), DataValue::Integer(n) => {
53 self.has_values = true;
54 if let Some(ref mut sum) = self.int_sum {
55 *sum = sum.saturating_add(*n);
56 } else if let Some(ref mut fsum) = self.float_sum {
57 *fsum += *n as f64;
58 } else {
59 self.int_sum = Some(*n);
60 }
61 Ok(())
62 }
63 DataValue::Float(f) => {
64 self.has_values = true;
65 if let Some(isum) = self.int_sum.take() {
67 self.float_sum = Some(isum as f64 + f);
68 } else if let Some(ref mut fsum) = self.float_sum {
69 *fsum += f;
70 } else {
71 self.float_sum = Some(*f);
72 }
73 Ok(())
74 }
75 _ => Err(anyhow!("Cannot sum non-numeric value")),
76 }
77 }
78
79 #[must_use]
80 pub fn finalize(self) -> DataValue {
81 if !self.has_values {
82 return DataValue::Null;
83 }
84
85 if let Some(fsum) = self.float_sum {
86 DataValue::Float(fsum)
87 } else if let Some(isum) = self.int_sum {
88 DataValue::Integer(isum)
89 } else {
90 DataValue::Null
91 }
92 }
93}
94
95#[derive(Debug, Clone)]
97pub struct AvgState {
98 pub sum: SumState,
99 pub count: i64,
100}
101
102impl Default for AvgState {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl AvgState {
109 #[must_use]
110 pub fn new() -> Self {
111 Self {
112 sum: SumState::new(),
113 count: 0,
114 }
115 }
116
117 pub fn add(&mut self, value: &DataValue) -> Result<()> {
118 if !matches!(value, DataValue::Null) {
119 self.sum.add(value)?;
120 self.count += 1;
121 }
122 Ok(())
123 }
124
125 #[must_use]
126 pub fn finalize(self) -> DataValue {
127 if self.count == 0 {
128 return DataValue::Null;
129 }
130
131 let sum = self.sum.finalize();
132 match sum {
133 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
134 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
135 _ => DataValue::Null,
136 }
137 }
138}
139
140#[derive(Debug, Clone)]
142pub struct MinMaxState {
143 pub is_min: bool,
144 pub current: Option<DataValue>,
145}
146
147impl MinMaxState {
148 #[must_use]
149 pub fn new(is_min: bool) -> Self {
150 Self {
151 is_min,
152 current: None,
153 }
154 }
155
156 pub fn add(&mut self, value: &DataValue) -> Result<()> {
157 if matches!(value, DataValue::Null) {
158 return Ok(());
159 }
160
161 if let Some(ref current) = self.current {
162 let should_update = if self.is_min {
163 value < current
164 } else {
165 value > current
166 };
167
168 if should_update {
169 self.current = Some(value.clone());
170 }
171 } else {
172 self.current = Some(value.clone());
173 }
174
175 Ok(())
176 }
177
178 #[must_use]
179 pub fn finalize(self) -> DataValue {
180 self.current.unwrap_or(DataValue::Null)
181 }
182}
183
184#[derive(Debug, Clone)]
186pub struct VarianceState {
187 pub sum: f64,
188 pub sum_of_squares: f64,
189 pub count: i64,
190}
191
192impl Default for VarianceState {
193 fn default() -> Self {
194 Self::new()
195 }
196}
197
198impl VarianceState {
199 #[must_use]
200 pub fn new() -> Self {
201 Self {
202 sum: 0.0,
203 sum_of_squares: 0.0,
204 count: 0,
205 }
206 }
207
208 pub fn add(&mut self, value: &DataValue) -> Result<()> {
209 match value {
210 DataValue::Null => Ok(()), DataValue::Integer(n) => {
212 let f = *n as f64;
213 self.sum += f;
214 self.sum_of_squares += f * f;
215 self.count += 1;
216 Ok(())
217 }
218 DataValue::Float(f) => {
219 self.sum += f;
220 self.sum_of_squares += f * f;
221 self.count += 1;
222 Ok(())
223 }
224 _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
225 }
226 }
227
228 #[must_use]
229 pub fn variance(&self) -> f64 {
230 if self.count <= 1 {
231 return 0.0;
232 }
233 let mean = self.sum / self.count as f64;
234 (self.sum_of_squares / self.count as f64) - (mean * mean)
235 }
236
237 #[must_use]
238 pub fn stddev(&self) -> f64 {
239 self.variance().sqrt()
240 }
241
242 #[must_use]
243 pub fn finalize_variance(self) -> DataValue {
244 if self.count == 0 {
245 DataValue::Null
246 } else {
247 DataValue::Float(self.variance())
248 }
249 }
250
251 #[must_use]
252 pub fn finalize_stddev(self) -> DataValue {
253 if self.count == 0 {
254 DataValue::Null
255 } else {
256 DataValue::Float(self.stddev())
257 }
258 }
259}
260
/// Interface implemented by every aggregate function stored in
/// [`AggregateRegistry`].
///
/// Implementors must be `Send + Sync`.
pub trait AggregateFunction: Send + Sync {
    /// Name of the function. [`AggregateRegistry::get`] compares this string
    /// against the uppercased lookup name, so it presumably needs to be
    /// uppercase to be found (verify against implementors).
    fn name(&self) -> &str;

    /// Creates the initial accumulator state for this function.
    fn init(&self) -> AggregateState;

    /// Folds `value` into `state`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably returned when `value` or `state`
    /// is not acceptable to the function (verify against implementors).
    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;

    /// Consumes `state` and produces the final aggregate value.
    fn finalize(&self, state: AggregateState) -> DataValue;

    /// Whether the function requires numeric input. Defaults to `false`.
    fn requires_numeric(&self) -> bool {
        false
    }
}
280
281pub struct AggregateRegistry {
283 functions: Vec<Box<dyn AggregateFunction>>,
284}
285
286impl AggregateRegistry {
287 #[must_use]
288 pub fn new() -> Self {
289 use analytics::{
290 CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
291 RankFunction, SumsFunction,
292 };
293 use functions::{
294 AvgFunction, CountFunction, CountStarFunction, MaxFunction, MinFunction,
295 StdDevFunction, SumFunction, VarianceFunction,
296 };
297
298 let functions: Vec<Box<dyn AggregateFunction>> = vec![
299 Box::new(CountFunction),
300 Box::new(CountStarFunction),
301 Box::new(SumFunction),
302 Box::new(AvgFunction),
303 Box::new(MinFunction),
304 Box::new(MaxFunction),
305 Box::new(StdDevFunction),
306 Box::new(VarianceFunction),
307 Box::new(DeltasFunction),
309 Box::new(SumsFunction),
310 Box::new(MavgFunction),
311 Box::new(PctChangeFunction),
312 Box::new(RankFunction),
313 Box::new(CumMaxFunction),
314 Box::new(CumMinFunction),
315 ];
316
317 Self { functions }
318 }
319
320 #[must_use]
321 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
322 let name_upper = name.to_uppercase();
323 self.functions
324 .iter()
325 .find(|f| f.name() == name_upper)
326 .map(std::convert::AsRef::as_ref)
327 }
328
329 #[must_use]
330 pub fn is_aggregate(&self, name: &str) -> bool {
331 self.get(name).is_some() || name.to_uppercase() == "COUNT" }
333}
334
335impl Default for AggregateRegistry {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
343 use crate::recursive_parser::SqlExpression;
344
345 match expr {
346 SqlExpression::FunctionCall { name, args, .. } => {
347 let registry = AggregateRegistry::new();
348 if registry.is_aggregate(name) {
349 return true;
350 }
351 args.iter().any(contains_aggregate)
353 }
354 SqlExpression::BinaryOp { left, right, .. } => {
355 contains_aggregate(left) || contains_aggregate(right)
356 }
357 SqlExpression::Not { expr } => contains_aggregate(expr),
358 SqlExpression::CaseExpression {
359 when_branches,
360 else_branch,
361 } => {
362 when_branches.iter().any(|branch| {
363 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
364 }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
365 }
366 _ => false,
367 }
368}