1use std::collections::HashMap;
6use xerv_core::traits::{Context, Node, NodeFuture, NodeInfo, NodeOutput, Port, PortDirection};
7use xerv_core::types::RelPtr;
8use xerv_core::value::Value;
9
/// Reduction applied to the numeric values collected by an aggregate node.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AggregateOperation {
    /// Sum of all values (the default operation).
    #[default]
    Sum,
    /// Arithmetic mean of all values.
    Average,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values collected.
    Count,
    /// Product of all values.
    Product,
}

impl AggregateOperation {
    /// Reduce `values` to a single number according to this operation.
    ///
    /// Empty input is handled explicitly rather than by the generic formulas:
    /// - `Sum`, `Average`, `Count` and `Product` return `0.0`. `Average` would
    ///   otherwise divide by zero. `Product` deliberately returns `0.0`, not the
    ///   mathematical empty product `1.0`.
    /// - `Min` and `Max` return `NaN`, since an empty set has no extreme value.
    fn apply(&self, values: &[f64]) -> f64 {
        if values.is_empty() {
            return match self {
                Self::Min | Self::Max => f64::NAN,
                Self::Sum | Self::Average | Self::Count | Self::Product => 0.0,
            };
        }

        match self {
            Self::Sum => values.iter().sum(),
            Self::Average => values.iter().sum::<f64>() / values.len() as f64,
            // Folding from +/-infinity means any finite value replaces the seed.
            Self::Min => values.iter().copied().fold(f64::INFINITY, f64::min),
            Self::Max => values.iter().copied().fold(f64::NEG_INFINITY, f64::max),
            Self::Count => values.len() as f64,
            Self::Product => values.iter().product(),
        }
    }
}
49
/// Node that reduces numeric values taken from its input to a single number.
///
/// The result is emitted as a one-key object `{ <output_field>: result }`.
#[derive(Debug)]
pub struct AggregateNode {
    /// Reduction applied to the collected values.
    operation: AggregateOperation,
    /// Where in the input document the values are read from.
    source: AggregateSource,
    /// Key under which the result is written in the output object.
    output_field: String,
}
84
/// Describes where an aggregate node reads its numeric values from.
#[derive(Debug, Clone)]
pub enum AggregateSource {
    /// Elements of the array found at `array_path` in the input.
    ArrayField {
        /// Path of the array within the input, resolved via `Value::get_field`.
        array_path: String,
        /// When set, this field is read from each array element. When `None`,
        /// the elements themselves are used as the numbers.
        value_field: Option<String>,
    },
    /// A fixed list of field paths, each read with `Value::get_f64`.
    Fields(Vec<String>),
}
98
99impl AggregateNode {
100 pub fn array(
102 array_path: impl Into<String>,
103 operation: AggregateOperation,
104 output_field: impl Into<String>,
105 ) -> Self {
106 Self {
107 operation,
108 source: AggregateSource::ArrayField {
109 array_path: array_path.into(),
110 value_field: None,
111 },
112 output_field: output_field.into(),
113 }
114 }
115
116 pub fn array_field(
119 array_path: impl Into<String>,
120 value_field: impl Into<String>,
121 operation: AggregateOperation,
122 output_field: impl Into<String>,
123 ) -> Self {
124 Self {
125 operation,
126 source: AggregateSource::ArrayField {
127 array_path: array_path.into(),
128 value_field: Some(value_field.into()),
129 },
130 output_field: output_field.into(),
131 }
132 }
133
134 pub fn fields(
136 fields: Vec<String>,
137 operation: AggregateOperation,
138 output_field: impl Into<String>,
139 ) -> Self {
140 Self {
141 operation,
142 source: AggregateSource::Fields(fields),
143 output_field: output_field.into(),
144 }
145 }
146
147 pub fn sum(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
149 Self::array(array_path, AggregateOperation::Sum, output_field)
150 }
151
152 pub fn average(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
154 Self::array(array_path, AggregateOperation::Average, output_field)
155 }
156
157 pub fn count(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
159 Self::array(array_path, AggregateOperation::Count, output_field)
160 }
161
162 fn extract_values(&self, input: &Value) -> Vec<f64> {
164 match &self.source {
165 AggregateSource::ArrayField {
166 array_path,
167 value_field,
168 } => {
169 let Some(array_value) = input.get_field(array_path) else {
170 return Vec::new();
171 };
172
173 let Some(array) = array_value.inner().as_array() else {
174 return Vec::new();
175 };
176
177 array
178 .iter()
179 .filter_map(|item| {
180 if let Some(field) = value_field {
181 Value::from(item.clone())
183 .get_field(field)
184 .and_then(|v| v.as_f64())
185 } else {
186 item.as_f64()
188 }
189 })
190 .collect()
191 }
192 AggregateSource::Fields(fields) => fields
193 .iter()
194 .filter_map(|path| input.get_f64(path))
195 .collect(),
196 }
197 }
198}
199
200impl Node for AggregateNode {
201 fn info(&self) -> NodeInfo {
202 NodeInfo::new("std", "aggregate")
203 .with_description("Aggregate numeric values (sum, avg, min, max, count)")
204 .with_inputs(vec![Port::input("Any")])
205 .with_outputs(vec![
206 Port::named("out", PortDirection::Output, "Any")
207 .with_description("Object with aggregation result"),
208 Port::error(),
209 ])
210 }
211
212 fn execute<'a>(&'a self, ctx: Context, inputs: HashMap<String, RelPtr<()>>) -> NodeFuture<'a> {
213 Box::pin(async move {
214 let input = inputs.get("in").copied().unwrap_or_else(RelPtr::null);
215
216 let value = if input.is_null() {
218 Value::null()
219 } else {
220 match ctx.read_bytes(input) {
221 Ok(bytes) => Value::from_bytes(&bytes).unwrap_or_else(|_| Value::null()),
222 Err(_) => Value::null(),
223 }
224 };
225
226 let values = self.extract_values(&value);
228 let result = self.operation.apply(&values);
229
230 tracing::debug!(
231 operation = ?self.operation,
232 value_count = values.len(),
233 result = result,
234 output_field = %self.output_field,
235 "Aggregate: computed result"
236 );
237
238 let output = Value::from(serde_json::json!({
240 &self.output_field: result
241 }));
242
243 let output_bytes = match output.to_bytes() {
245 Ok(bytes) => bytes,
246 Err(e) => {
247 return Ok(NodeOutput::error_with_message(format!(
248 "Failed to serialize output: {}",
249 e
250 )));
251 }
252 };
253
254 let output_ptr = match ctx.write_bytes(&output_bytes) {
255 Ok(ptr) => ptr,
256 Err(e) => {
257 return Ok(NodeOutput::error_with_message(format!(
258 "Failed to write output: {}",
259 e
260 )));
261 }
262 };
263
264 Ok(NodeOutput::out(output_ptr))
265 })
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn aggregate_node_info() {
        let node = AggregateNode::sum("values", "total");
        let info = node.info();

        assert_eq!(info.name, "std::aggregate");
        assert_eq!(info.inputs.len(), 1);
        assert_eq!(info.outputs.len(), 2);
    }

    #[test]
    fn aggregate_sum_array() {
        let node = AggregateNode::sum("numbers", "total");
        let input = Value::from(json!({"numbers": [1, 2, 3, 4, 5]}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 15.0);
    }

    #[test]
    fn aggregate_average_array() {
        let node = AggregateNode::average("scores", "avg");
        let input = Value::from(json!({"scores": [10, 20, 30]}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 20.0);
    }

    #[test]
    fn aggregate_min_max() {
        let input = Value::from(json!({"values": [5, 2, 8, 1, 9]}));

        let min_node = AggregateNode::array("values", AggregateOperation::Min, "min");
        let min_values = min_node.extract_values(&input);
        assert_eq!(min_node.operation.apply(&min_values), 1.0);

        let max_node = AggregateNode::array("values", AggregateOperation::Max, "max");
        let max_values = max_node.extract_values(&input);
        assert_eq!(max_node.operation.apply(&max_values), 9.0);
    }

    #[test]
    fn aggregate_count() {
        let node = AggregateNode::count("items", "count");
        let input = Value::from(json!({"items": [1, 2, 3, 4, 5, 6]}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 6.0);
    }

    #[test]
    fn aggregate_array_of_objects() {
        let node = AggregateNode::array_field("items", "price", AggregateOperation::Sum, "total");

        let input = Value::from(json!({
            "items": [
                {"name": "A", "price": 10.0},
                {"name": "B", "price": 20.0},
                {"name": "C", "price": 30.0}
            ]
        }));

        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 60.0);
    }

    #[test]
    fn aggregate_multiple_fields() {
        let node = AggregateNode::fields(
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            AggregateOperation::Sum,
            "sum",
        );

        let input = Value::from(json!({"a": 10, "b": 20, "c": 30}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 60.0);
    }

    #[test]
    fn aggregate_empty_array() {
        let node = AggregateNode::sum("empty", "total");
        let input = Value::from(json!({"empty": []}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 0.0);
    }

    #[test]
    fn aggregate_missing_field() {
        let node = AggregateNode::sum("missing", "total");
        let input = Value::from(json!({"other": [1, 2, 3]}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 0.0);
    }

    #[test]
    fn aggregate_product() {
        let node = AggregateNode::array("factors", AggregateOperation::Product, "product");
        let input = Value::from(json!({"factors": [2, 3, 4]}));
        let values = node.extract_values(&input);
        let result = node.operation.apply(&values);

        assert_eq!(result, 24.0);
    }

    /// `Min`/`Max` have no meaningful value over an empty set and return NaN.
    #[test]
    fn aggregate_min_max_empty_is_nan() {
        assert!(AggregateOperation::Min.apply(&[]).is_nan());
        assert!(AggregateOperation::Max.apply(&[]).is_nan());
    }

    /// The remaining operations fall back to zero on empty input (including `Product`).
    #[test]
    fn aggregate_empty_defaults_to_zero() {
        for op in [
            AggregateOperation::Average,
            AggregateOperation::Count,
            AggregateOperation::Product,
        ] {
            assert_eq!(op.apply(&[]), 0.0);
        }
    }

    /// Non-numeric array entries are dropped rather than failing the aggregation.
    #[test]
    fn aggregate_skips_non_numeric_items() {
        let node = AggregateNode::sum("values", "total");
        let input = Value::from(json!({"values": [1, "two", null, 2.5]}));
        let values = node.extract_values(&input);

        assert_eq!(values, vec![1.0, 2.5]);
        assert_eq!(node.operation.apply(&values), 3.5);
    }

    /// Field paths that do not resolve to a number are skipped.
    #[test]
    fn aggregate_fields_skips_missing() {
        let node = AggregateNode::fields(
            vec!["a".to_string(), "missing".to_string()],
            AggregateOperation::Count,
            "n",
        );
        let input = Value::from(json!({"a": 10}));
        let values = node.extract_values(&input);

        assert_eq!(node.operation.apply(&values), 1.0);
    }

    /// A path that holds a scalar instead of an array yields no values.
    #[test]
    fn aggregate_non_array_source_is_empty() {
        let node = AggregateNode::sum("scalar", "total");
        let input = Value::from(json!({"scalar": 5}));

        assert!(node.extract_values(&input).is_empty());
    }
}