xerv_nodes/data/
aggregate.rs

1//! Aggregate node (numeric aggregation).
2//!
3//! Aggregates numeric values from arrays or multiple fields.
4
5use std::collections::HashMap;
6use xerv_core::traits::{Context, Node, NodeFuture, NodeInfo, NodeOutput, Port, PortDirection};
7use xerv_core::types::RelPtr;
8use xerv_core::value::Value;
9
/// Aggregation operation to perform.
///
/// The operation is a plain, fieldless selector: it is cheap to copy and can
/// be compared or used as a hash-map key. [`Sum`](Self::Sum) is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AggregateOperation {
    /// Sum all values.
    #[default]
    Sum,
    /// Calculate the arithmetic mean (sum divided by the number of values).
    Average,
    /// Find the minimum value.
    Min,
    /// Find the maximum value.
    Max,
    /// Count the number of values.
    Count,
    /// Calculate the product of all values.
    Product,
}
27
28impl AggregateOperation {
29    /// Apply the aggregation operation to a list of values.
30    fn apply(&self, values: &[f64]) -> f64 {
31        if values.is_empty() {
32            return match self {
33                Self::Count => 0.0,
34                Self::Sum | Self::Average | Self::Product => 0.0,
35                Self::Min | Self::Max => f64::NAN,
36            };
37        }
38
39        match self {
40            Self::Sum => values.iter().sum(),
41            Self::Average => values.iter().sum::<f64>() / values.len() as f64,
42            Self::Min => values.iter().cloned().fold(f64::INFINITY, f64::min),
43            Self::Max => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
44            Self::Count => values.len() as f64,
45            Self::Product => values.iter().product(),
46        }
47    }
48}
49
50/// Aggregate node - numeric aggregation.
51///
52/// Performs aggregation operations on numeric values from:
53/// - An array field (aggregates all elements)
54/// - Multiple specified fields (aggregates their values)
55///
56/// # Ports
57/// - Input: "in" - Source data with numeric values
58/// - Output: "out" - Object with aggregation result
59/// - Output: "error" - Emitted on errors
60///
61/// # Example Configuration
62/// ```yaml
63/// nodes:
64///   calculate_total:
65///     type: std::aggregate
66///     config:
67///       operation: sum
68///       source: $.items[*].price    # Array field
69///       output_field: total
70///     inputs:
71///       - from: cart.out -> in
72///     outputs:
73///       out: -> checkout.in
74/// ```
75#[derive(Debug)]
76pub struct AggregateNode {
77    /// The aggregation operation to perform.
78    operation: AggregateOperation,
79    /// Source: either an array field path or list of field paths.
80    source: AggregateSource,
81    /// Output field name for the result.
82    output_field: String,
83}
84
/// Source of values for aggregation.
///
/// Two sources are equal when they are the same variant with equal paths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateSource {
    /// Aggregate values from an array field.
    ArrayField {
        /// Path to the array field.
        array_path: String,
        /// Optional sub-field to extract from each element.
        ///
        /// When `None`, each array element itself is used as the number.
        value_field: Option<String>,
    },
    /// Aggregate values from multiple specific fields.
    Fields(Vec<String>),
}
98
99impl AggregateNode {
100    /// Create an aggregate node that operates on an array field.
101    pub fn array(
102        array_path: impl Into<String>,
103        operation: AggregateOperation,
104        output_field: impl Into<String>,
105    ) -> Self {
106        Self {
107            operation,
108            source: AggregateSource::ArrayField {
109                array_path: array_path.into(),
110                value_field: None,
111            },
112            output_field: output_field.into(),
113        }
114    }
115
116    /// Create an aggregate node that operates on an array of objects,
117    /// extracting a specific field from each.
118    pub fn array_field(
119        array_path: impl Into<String>,
120        value_field: impl Into<String>,
121        operation: AggregateOperation,
122        output_field: impl Into<String>,
123    ) -> Self {
124        Self {
125            operation,
126            source: AggregateSource::ArrayField {
127                array_path: array_path.into(),
128                value_field: Some(value_field.into()),
129            },
130            output_field: output_field.into(),
131        }
132    }
133
134    /// Create an aggregate node that operates on multiple fields.
135    pub fn fields(
136        fields: Vec<String>,
137        operation: AggregateOperation,
138        output_field: impl Into<String>,
139    ) -> Self {
140        Self {
141            operation,
142            source: AggregateSource::Fields(fields),
143            output_field: output_field.into(),
144        }
145    }
146
147    /// Create a sum aggregation over an array.
148    pub fn sum(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
149        Self::array(array_path, AggregateOperation::Sum, output_field)
150    }
151
152    /// Create an average aggregation over an array.
153    pub fn average(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
154        Self::array(array_path, AggregateOperation::Average, output_field)
155    }
156
157    /// Create a count aggregation over an array.
158    pub fn count(array_path: impl Into<String>, output_field: impl Into<String>) -> Self {
159        Self::array(array_path, AggregateOperation::Count, output_field)
160    }
161
162    /// Extract numeric values from input based on source configuration.
163    fn extract_values(&self, input: &Value) -> Vec<f64> {
164        match &self.source {
165            AggregateSource::ArrayField {
166                array_path,
167                value_field,
168            } => {
169                let Some(array_value) = input.get_field(array_path) else {
170                    return Vec::new();
171                };
172
173                let Some(array) = array_value.inner().as_array() else {
174                    return Vec::new();
175                };
176
177                array
178                    .iter()
179                    .filter_map(|item| {
180                        if let Some(field) = value_field {
181                            // Extract sub-field from each object
182                            Value::from(item.clone())
183                                .get_field(field)
184                                .and_then(|v| v.as_f64())
185                        } else {
186                            // Use the element directly
187                            item.as_f64()
188                        }
189                    })
190                    .collect()
191            }
192            AggregateSource::Fields(fields) => fields
193                .iter()
194                .filter_map(|path| input.get_f64(path))
195                .collect(),
196        }
197    }
198}
199
200impl Node for AggregateNode {
201    fn info(&self) -> NodeInfo {
202        NodeInfo::new("std", "aggregate")
203            .with_description("Aggregate numeric values (sum, avg, min, max, count)")
204            .with_inputs(vec![Port::input("Any")])
205            .with_outputs(vec![
206                Port::named("out", PortDirection::Output, "Any")
207                    .with_description("Object with aggregation result"),
208                Port::error(),
209            ])
210    }
211
212    fn execute<'a>(&'a self, ctx: Context, inputs: HashMap<String, RelPtr<()>>) -> NodeFuture<'a> {
213        Box::pin(async move {
214            let input = inputs.get("in").copied().unwrap_or_else(RelPtr::null);
215
216            // Read and parse input data
217            let value = if input.is_null() {
218                Value::null()
219            } else {
220                match ctx.read_bytes(input) {
221                    Ok(bytes) => Value::from_bytes(&bytes).unwrap_or_else(|_| Value::null()),
222                    Err(_) => Value::null(),
223                }
224            };
225
226            // Extract values and compute aggregate
227            let values = self.extract_values(&value);
228            let result = self.operation.apply(&values);
229
230            tracing::debug!(
231                operation = ?self.operation,
232                value_count = values.len(),
233                result = result,
234                output_field = %self.output_field,
235                "Aggregate: computed result"
236            );
237
238            // Create output object
239            let output = Value::from(serde_json::json!({
240                &self.output_field: result
241            }));
242
243            // Write output to arena
244            let output_bytes = match output.to_bytes() {
245                Ok(bytes) => bytes,
246                Err(e) => {
247                    return Ok(NodeOutput::error_with_message(format!(
248                        "Failed to serialize output: {}",
249                        e
250                    )));
251                }
252            };
253
254            let output_ptr = match ctx.write_bytes(&output_bytes) {
255                Ok(ptr) => ptr,
256                Err(e) => {
257                    return Ok(NodeOutput::error_with_message(format!(
258                        "Failed to write output: {}",
259                        e
260                    )));
261                }
262            };
263
264            Ok(NodeOutput::out(output_ptr))
265        })
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Run `node` over `data` the way `execute` does (extract, then apply)
    /// and return the aggregated number.
    fn aggregate(node: &AggregateNode, data: serde_json::Value) -> f64 {
        let input = Value::from(data);
        let values = node.extract_values(&input);
        node.operation.apply(&values)
    }

    #[test]
    fn aggregate_node_info() {
        let node = AggregateNode::sum("values", "total");
        let info = node.info();

        assert_eq!(info.name, "std::aggregate");
        assert_eq!(info.inputs.len(), 1);
        assert_eq!(info.outputs.len(), 2);
    }

    #[test]
    fn aggregate_sum_array() {
        let node = AggregateNode::sum("numbers", "total");

        assert_eq!(aggregate(&node, json!({"numbers": [1, 2, 3, 4, 5]})), 15.0);
    }

    #[test]
    fn aggregate_average_array() {
        let node = AggregateNode::average("scores", "avg");

        assert_eq!(aggregate(&node, json!({"scores": [10, 20, 30]})), 20.0);
    }

    #[test]
    fn aggregate_min_max() {
        let min_node = AggregateNode::array("values", AggregateOperation::Min, "min");
        assert_eq!(aggregate(&min_node, json!({"values": [5, 2, 8, 1, 9]})), 1.0);

        let max_node = AggregateNode::array("values", AggregateOperation::Max, "max");
        assert_eq!(aggregate(&max_node, json!({"values": [5, 2, 8, 1, 9]})), 9.0);
    }

    #[test]
    fn aggregate_count() {
        let node = AggregateNode::count("items", "count");

        assert_eq!(aggregate(&node, json!({"items": [1, 2, 3, 4, 5, 6]})), 6.0);
    }

    #[test]
    fn aggregate_array_of_objects() {
        let node = AggregateNode::array_field("items", "price", AggregateOperation::Sum, "total");

        let data = json!({
            "items": [
                {"name": "A", "price": 10.0},
                {"name": "B", "price": 20.0},
                {"name": "C", "price": 30.0}
            ]
        });

        assert_eq!(aggregate(&node, data), 60.0);
    }

    #[test]
    fn aggregate_multiple_fields() {
        let node = AggregateNode::fields(
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            AggregateOperation::Sum,
            "sum",
        );

        assert_eq!(aggregate(&node, json!({"a": 10, "b": 20, "c": 30})), 60.0);
    }

    #[test]
    fn aggregate_empty_array() {
        let node = AggregateNode::sum("empty", "total");

        assert_eq!(aggregate(&node, json!({"empty": []})), 0.0);
    }

    #[test]
    fn aggregate_missing_field() {
        let node = AggregateNode::sum("missing", "total");

        assert_eq!(aggregate(&node, json!({"other": [1, 2, 3]})), 0.0);
    }

    #[test]
    fn aggregate_product() {
        let node = AggregateNode::array("factors", AggregateOperation::Product, "product");

        assert_eq!(aggregate(&node, json!({"factors": [2, 3, 4]})), 24.0);
    }

    // The empty-input branch of `apply` is separate from the main reduction,
    // so exercise it for every operation, not just `Sum`.
    #[test]
    fn aggregate_empty_input_yields_zero() {
        let zero_ops = [
            AggregateOperation::Sum,
            AggregateOperation::Average,
            AggregateOperation::Count,
            AggregateOperation::Product,
        ];

        for op in zero_ops {
            assert_eq!(op.apply(&[]), 0.0, "{:?} on empty input", op);
        }
    }

    #[test]
    fn aggregate_empty_input_min_max_is_nan() {
        assert!(AggregateOperation::Min.apply(&[]).is_nan());
        assert!(AggregateOperation::Max.apply(&[]).is_nan());
    }
}