sql_cli/sql/aggregates/
mod.rs1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod functions;
11
12#[derive(Debug, Clone)]
14pub enum AggregateState {
15 Count(i64),
16 Sum(SumState),
17 Avg(AvgState),
18 MinMax(MinMaxState),
19 CollectList(Vec<DataValue>),
20}
21
22#[derive(Debug, Clone)]
24pub struct SumState {
25 pub int_sum: Option<i64>,
26 pub float_sum: Option<f64>,
27 pub has_values: bool,
28}
29
30impl SumState {
31 pub fn new() -> Self {
32 Self {
33 int_sum: None,
34 float_sum: None,
35 has_values: false,
36 }
37 }
38
39 pub fn add(&mut self, value: &DataValue) -> Result<()> {
40 match value {
41 DataValue::Null => Ok(()), DataValue::Integer(n) => {
43 self.has_values = true;
44 if let Some(ref mut sum) = self.int_sum {
45 *sum = sum.saturating_add(*n);
46 } else if let Some(ref mut fsum) = self.float_sum {
47 *fsum += *n as f64;
48 } else {
49 self.int_sum = Some(*n);
50 }
51 Ok(())
52 }
53 DataValue::Float(f) => {
54 self.has_values = true;
55 if let Some(isum) = self.int_sum.take() {
57 self.float_sum = Some(isum as f64 + f);
58 } else if let Some(ref mut fsum) = self.float_sum {
59 *fsum += f;
60 } else {
61 self.float_sum = Some(*f);
62 }
63 Ok(())
64 }
65 _ => Err(anyhow!("Cannot sum non-numeric value")),
66 }
67 }
68
69 pub fn finalize(self) -> DataValue {
70 if !self.has_values {
71 return DataValue::Null;
72 }
73
74 if let Some(fsum) = self.float_sum {
75 DataValue::Float(fsum)
76 } else if let Some(isum) = self.int_sum {
77 DataValue::Integer(isum)
78 } else {
79 DataValue::Null
80 }
81 }
82}
83
84#[derive(Debug, Clone)]
86pub struct AvgState {
87 pub sum: SumState,
88 pub count: i64,
89}
90
91impl AvgState {
92 pub fn new() -> Self {
93 Self {
94 sum: SumState::new(),
95 count: 0,
96 }
97 }
98
99 pub fn add(&mut self, value: &DataValue) -> Result<()> {
100 if !matches!(value, DataValue::Null) {
101 self.sum.add(value)?;
102 self.count += 1;
103 }
104 Ok(())
105 }
106
107 pub fn finalize(self) -> DataValue {
108 if self.count == 0 {
109 return DataValue::Null;
110 }
111
112 let sum = self.sum.finalize();
113 match sum {
114 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
115 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
116 _ => DataValue::Null,
117 }
118 }
119}
120
121#[derive(Debug, Clone)]
123pub struct MinMaxState {
124 pub is_min: bool,
125 pub current: Option<DataValue>,
126}
127
128impl MinMaxState {
129 pub fn new(is_min: bool) -> Self {
130 Self {
131 is_min,
132 current: None,
133 }
134 }
135
136 pub fn add(&mut self, value: &DataValue) -> Result<()> {
137 if matches!(value, DataValue::Null) {
138 return Ok(());
139 }
140
141 if let Some(ref current) = self.current {
142 let should_update = if self.is_min {
143 value < current
144 } else {
145 value > current
146 };
147
148 if should_update {
149 self.current = Some(value.clone());
150 }
151 } else {
152 self.current = Some(value.clone());
153 }
154
155 Ok(())
156 }
157
158 pub fn finalize(self) -> DataValue {
159 self.current.unwrap_or(DataValue::Null)
160 }
161}
162
163pub trait AggregateFunction: Send + Sync {
165 fn name(&self) -> &str;
167
168 fn init(&self) -> AggregateState;
170
171 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
173
174 fn finalize(&self, state: AggregateState) -> DataValue;
176
177 fn requires_numeric(&self) -> bool {
179 false
180 }
181}
182
183pub struct AggregateRegistry {
185 functions: Vec<Box<dyn AggregateFunction>>,
186}
187
188impl AggregateRegistry {
189 pub fn new() -> Self {
190 use functions::*;
191
192 let functions: Vec<Box<dyn AggregateFunction>> = vec![
193 Box::new(CountFunction),
194 Box::new(CountStarFunction),
195 Box::new(SumFunction),
196 Box::new(AvgFunction),
197 Box::new(MinFunction),
198 Box::new(MaxFunction),
199 ];
200
201 Self { functions }
202 }
203
204 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
205 let name_upper = name.to_uppercase();
206 self.functions
207 .iter()
208 .find(|f| f.name() == name_upper)
209 .map(|f| f.as_ref())
210 }
211
212 pub fn is_aggregate(&self, name: &str) -> bool {
213 self.get(name).is_some() || name.to_uppercase() == "COUNT" }
215}
216
217impl Default for AggregateRegistry {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
225 use crate::recursive_parser::SqlExpression;
226
227 match expr {
228 SqlExpression::FunctionCall { name, args } => {
229 let registry = AggregateRegistry::new();
230 if registry.is_aggregate(name) {
231 return true;
232 }
233 args.iter().any(contains_aggregate)
235 }
236 SqlExpression::BinaryOp { left, right, .. } => {
237 contains_aggregate(left) || contains_aggregate(right)
238 }
239 SqlExpression::Not { expr } => contains_aggregate(expr),
240 SqlExpression::CaseExpression {
241 when_branches,
242 else_branch,
243 } => {
244 when_branches.iter().any(|branch| {
245 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
246 }) || else_branch
247 .as_ref()
248 .map_or(false, |e| contains_aggregate(e))
249 }
250 _ => false,
251 }
252}