xerv_nodes/flow/
switch.rs

//! Switch node (conditional routing).
//!
//! Evaluates a condition against each incoming value and routes the data to
//! either the "true" or the "false" output port, much like an `if`/`else` branch.
5
6use std::collections::HashMap;
7use xerv_core::traits::{Context, Node, NodeFuture, NodeInfo, NodeOutput, Port, PortDirection};
8use xerv_core::types::RelPtr;
9use xerv_core::value::Value;
10
/// A condition that determines routing.
///
/// Field paths are resolved against the node's input value; nested fields
/// use dot notation (e.g. `result.status`).
#[derive(Debug, Clone, PartialEq)]
pub enum SwitchCondition {
    /// Always true (for default routing).
    Always,
    /// True when the field equals `value`.
    FieldEquals { field: String, value: String },
    /// True when the field matches the regex `pattern`.
    FieldMatches { field: String, pattern: String },
    /// True when the field is strictly greater than `threshold`.
    FieldGreaterThan { field: String, threshold: f64 },
    /// True when the field is less than `threshold`.
    FieldLessThan { field: String, threshold: f64 },
    /// Simple comparison expression evaluated against the input value,
    /// e.g. `"${score} > 0.8"`.
    ///
    /// The field may be written as `${path}`, `$.path`, or a bare dotted path.
    /// Supported operators are `==` / `=`, `!=`, `>`, `<`, `>=` and `<=`;
    /// string literals on the right-hand side may be quoted. A bare field
    /// reference with no operator checks whether the field is truthy.
    Expression(String),
}
27
28impl Default for SwitchCondition {
29    fn default() -> Self {
30        Self::Always
31    }
32}
33
34/// Switch node - conditional routing.
35///
36/// Routes incoming data to one of two output ports ("true" or "false")
37/// based on evaluating a condition.
38///
39/// # Ports
40/// - Input: "in" - The data to route
41/// - Output: "true" - Activated when condition is true
42/// - Output: "false" - Activated when condition is false
43///
44/// # Example Configuration
45/// ```yaml
46/// nodes:
47///   check_fraud:
48///     type: std::switch
49///     config:
50///       condition:
51///         type: field_greater_than
52///         field: $.score
53///         threshold: 0.8
54///     inputs:
55///       - from: fraud_model.out -> in
56///     outputs:
57///       true: -> high_risk.in
58///       false: -> low_risk.in
59/// ```
60#[derive(Debug)]
61pub struct SwitchNode {
62    /// The condition to evaluate.
63    condition: SwitchCondition,
64}
65
66impl SwitchNode {
67    /// Create a switch node with the given condition.
68    pub fn new(condition: SwitchCondition) -> Self {
69        Self { condition }
70    }
71
72    /// Create a switch that routes based on field equality.
73    pub fn field_equals(field: impl Into<String>, value: impl Into<String>) -> Self {
74        Self {
75            condition: SwitchCondition::FieldEquals {
76                field: field.into(),
77                value: value.into(),
78            },
79        }
80    }
81
82    /// Create a switch that routes based on a threshold.
83    pub fn threshold(field: impl Into<String>, threshold: f64) -> Self {
84        Self {
85            condition: SwitchCondition::FieldGreaterThan {
86                field: field.into(),
87                threshold,
88            },
89        }
90    }
91
92    /// Create a switch that routes based on a custom expression.
93    pub fn expression(expr: impl Into<String>) -> Self {
94        Self {
95            condition: SwitchCondition::Expression(expr.into()),
96        }
97    }
98
99    /// Evaluate the condition against the input value.
100    ///
101    /// Returns `true` if the condition is satisfied, `false` otherwise.
102    /// If the input is null or the field doesn't exist, returns `false`
103    /// (except for `Always` which always returns `true`).
104    fn evaluate(&self, value: &Value) -> bool {
105        match &self.condition {
106            SwitchCondition::Always => true,
107
108            SwitchCondition::FieldEquals {
109                field,
110                value: expected,
111            } => {
112                let result = value.field_equals(field, expected);
113                tracing::debug!(
114                    field = %field,
115                    expected = %expected,
116                    result = result,
117                    "Evaluated field_equals condition"
118                );
119                result
120            }
121
122            SwitchCondition::FieldMatches { field, pattern } => {
123                let result = value.field_matches(field, pattern);
124                tracing::debug!(
125                    field = %field,
126                    pattern = %pattern,
127                    result = result,
128                    "Evaluated field_matches condition"
129                );
130                result
131            }
132
133            SwitchCondition::FieldGreaterThan { field, threshold } => {
134                let result = value.field_greater_than(field, *threshold);
135                tracing::debug!(
136                    field = %field,
137                    threshold = %threshold,
138                    result = result,
139                    "Evaluated field_greater_than condition"
140                );
141                result
142            }
143
144            SwitchCondition::FieldLessThan { field, threshold } => {
145                let result = value.field_less_than(field, *threshold);
146                tracing::debug!(
147                    field = %field,
148                    threshold = %threshold,
149                    result = result,
150                    "Evaluated field_less_than condition"
151                );
152                result
153            }
154
155            SwitchCondition::Expression(expr) => {
156                // Expression evaluation is a simplified subset:
157                // - "${field} == value" -> field_equals
158                // - "${field} > value" -> field_greater_than
159                // - "${field} < value" -> field_less_than
160                // - "${field}" -> field_is_true (boolean check)
161                let result = self.evaluate_expression(expr, value);
162                tracing::debug!(
163                    expr = %expr,
164                    result = result,
165                    "Evaluated expression condition"
166                );
167                result
168            }
169        }
170    }
171
172    /// Evaluate a simple expression against the input value.
173    ///
174    /// Supports basic patterns:
175    /// - `${field}` -> checks if field is truthy
176    /// - `${field} == "value"` -> string equality
177    /// - `${field} > number` -> numeric greater than
178    /// - `${field} < number` -> numeric less than
179    /// - `${field} >= number` -> numeric greater than or equal
180    /// - `${field} <= number` -> numeric less than or equal
181    fn evaluate_expression(&self, expr: &str, value: &Value) -> bool {
182        let expr = expr.trim();
183
184        // Try to parse comparison expressions
185        if let Some((field, op, rhs)) = self.parse_comparison(expr) {
186            match op {
187                "==" | "=" => {
188                    // String equality (strip quotes from rhs)
189                    let rhs = rhs.trim_matches('"').trim_matches('\'');
190                    value.field_equals(&field, rhs)
191                }
192                "!=" => {
193                    let rhs = rhs.trim_matches('"').trim_matches('\'');
194                    !value.field_equals(&field, rhs)
195                }
196                ">" => {
197                    if let Ok(threshold) = rhs.parse::<f64>() {
198                        value.field_greater_than(&field, threshold)
199                    } else {
200                        false
201                    }
202                }
203                "<" => {
204                    if let Ok(threshold) = rhs.parse::<f64>() {
205                        value.field_less_than(&field, threshold)
206                    } else {
207                        false
208                    }
209                }
210                ">=" => {
211                    if let Ok(threshold) = rhs.parse::<f64>() {
212                        value.get_f64(&field).map_or(false, |v| v >= threshold)
213                    } else {
214                        false
215                    }
216                }
217                "<=" => {
218                    if let Ok(threshold) = rhs.parse::<f64>() {
219                        value.get_f64(&field).map_or(false, |v| v <= threshold)
220                    } else {
221                        false
222                    }
223                }
224                _ => false,
225            }
226        } else if let Some(field) = self.parse_field_ref(expr) {
227            // Just a field reference - check if truthy
228            value.field_is_true(&field)
229        } else {
230            // Unrecognized expression format
231            tracing::warn!(expr = %expr, "Unrecognized expression format");
232            false
233        }
234    }
235
236    /// Parse a comparison expression like "${field} > 0.5"
237    fn parse_comparison<'a>(&self, expr: &'a str) -> Option<(String, &'a str, &'a str)> {
238        // Operators in order of specificity (>= before >)
239        let operators = [">=", "<=", "==", "!=", ">", "<", "="];
240
241        for op in operators {
242            if let Some(pos) = expr.find(op) {
243                let lhs = expr[..pos].trim();
244                let rhs = expr[pos + op.len()..].trim();
245
246                // Extract field from ${field} or $.field syntax
247                if let Some(field) = self.parse_field_ref(lhs) {
248                    return Some((field, op, rhs));
249                }
250            }
251        }
252        None
253    }
254
255    /// Parse a field reference like "${field}" or "$.field"
256    fn parse_field_ref(&self, s: &str) -> Option<String> {
257        let s = s.trim();
258
259        // ${field.path} format
260        if s.starts_with("${") && s.ends_with('}') {
261            return Some(s[2..s.len() - 1].to_string());
262        }
263
264        // $.field.path format
265        if s.starts_with("$.") {
266            return Some(s[2..].to_string());
267        }
268
269        // Plain field name (if it looks like an identifier)
270        if !s.is_empty()
271            && s.chars()
272                .all(|c| c.is_alphanumeric() || c == '_' || c == '.')
273        {
274            return Some(s.to_string());
275        }
276
277        None
278    }
279}
280
281impl Node for SwitchNode {
282    fn info(&self) -> NodeInfo {
283        NodeInfo::new("std", "switch")
284            .with_description("Conditional routing based on expression")
285            .with_inputs(vec![Port::input("Any")])
286            .with_outputs(vec![
287                Port::named("true", PortDirection::Output, "Any"),
288                Port::named("false", PortDirection::Output, "Any"),
289                Port::error(),
290            ])
291    }
292
293    fn execute<'a>(&'a self, ctx: Context, inputs: HashMap<String, RelPtr<()>>) -> NodeFuture<'a> {
294        Box::pin(async move {
295            let input = inputs.get("in").copied().unwrap_or_else(RelPtr::null);
296
297            // Read and parse input data from arena
298            let value = if input.is_null() {
299                Value::null()
300            } else {
301                match ctx.read_bytes(input) {
302                    Ok(bytes) => Value::from_bytes(&bytes).unwrap_or_else(|e| {
303                        tracing::warn!(error = %e, "Failed to parse input as JSON, using null");
304                        Value::null()
305                    }),
306                    Err(e) => {
307                        tracing::warn!(error = %e, "Failed to read input from arena, using null");
308                        Value::null()
309                    }
310                }
311            };
312
313            let result = self.evaluate(&value);
314
315            tracing::debug!(condition_result = result, "Switch evaluated condition");
316
317            if result {
318                Ok(NodeOutput::on_true(input))
319            } else {
320                Ok(NodeOutput::on_false(input))
321            }
322        })
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn switch_node_info() {
        let node = SwitchNode::new(SwitchCondition::Always);
        let info = node.info();

        assert_eq!(info.name, "std::switch");
        assert_eq!(info.inputs.len(), 1);
        assert_eq!(info.inputs[0].name, "in");
        assert_eq!(info.outputs.len(), 3);
        assert_eq!(info.outputs[0].name, "true");
        assert_eq!(info.outputs[1].name, "false");
    }

    #[test]
    fn switch_condition_always() {
        let node = SwitchNode::new(SwitchCondition::Always);
        let value = Value::null();
        assert!(node.evaluate(&value));
    }

    #[test]
    fn switch_default_condition_is_always() {
        assert!(matches!(SwitchCondition::default(), SwitchCondition::Always));
    }

    #[test]
    fn switch_threshold_creation() {
        let node = SwitchNode::threshold("score", 0.8);
        assert!(matches!(
            node.condition,
            SwitchCondition::FieldGreaterThan { threshold, .. } if threshold == 0.8
        ));
    }

    #[test]
    fn switch_field_equals() {
        let node = SwitchNode::field_equals("status", "active");
        let value = Value(json!({"status": "active"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "inactive"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_field_greater_than() {
        let node = SwitchNode::threshold("score", 0.8);

        let value = Value(json!({"score": 0.9}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"score": 0.7}));
        assert!(!node.evaluate(&value));

        let value = Value(json!({"score": 0.8}));
        assert!(!node.evaluate(&value)); // not strictly greater than
    }

    #[test]
    fn switch_field_less_than() {
        let node = SwitchNode::new(SwitchCondition::FieldLessThan {
            field: "temperature".to_string(),
            threshold: 30.0,
        });

        let value = Value(json!({"temperature": 25.0}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"temperature": 35.0}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_field_matches() {
        let node = SwitchNode::new(SwitchCondition::FieldMatches {
            field: "email".to_string(),
            pattern: r"^[\w.+-]+@[\w.-]+\.\w+$".to_string(),
        });

        let value = Value(json!({"email": "test@example.com"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"email": "invalid-email"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_comparison() {
        let node = SwitchNode::expression("${score} > 0.5");

        let value = Value(json!({"score": 0.7}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"score": 0.3}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_equality() {
        let node = SwitchNode::expression("${status} == \"success\"");

        let value = Value(json!({"status": "success"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "failed"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_boolean_field() {
        let node = SwitchNode::expression("${is_valid}");

        let value = Value(json!({"is_valid": true}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"is_valid": false}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_nested_field_access() {
        let node = SwitchNode::field_equals("result.status", "ok");

        let value = Value(json!({"result": {"status": "ok"}}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"result": {"status": "error"}}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_missing_field_returns_false() {
        let node = SwitchNode::field_equals("nonexistent", "value");
        let value = Value(json!({"other": "data"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_null_input_is_false() {
        let value = Value::null();
        assert!(!SwitchNode::field_equals("status", "ok").evaluate(&value));
        assert!(!SwitchNode::threshold("score", 0.0).evaluate(&value));
        assert!(!SwitchNode::expression("${flag}").evaluate(&value));
    }

    #[test]
    fn switch_expression_gte() {
        let node = SwitchNode::expression("${count} >= 10");

        let value = Value(json!({"count": 10}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"count": 15}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"count": 5}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_lte() {
        let node = SwitchNode::expression("${count} <= 10");

        assert!(node.evaluate(&Value(json!({"count": 10}))));
        assert!(node.evaluate(&Value(json!({"count": 5}))));
        assert!(!node.evaluate(&Value(json!({"count": 15}))));
    }

    #[test]
    fn switch_expression_less_than() {
        let node = SwitchNode::expression("${temperature} < 0");

        assert!(node.evaluate(&Value(json!({"temperature": -5.0}))));
        assert!(!node.evaluate(&Value(json!({"temperature": 5.0}))));
    }

    #[test]
    fn switch_expression_not_equals() {
        let node = SwitchNode::expression("${status} != \"error\"");

        let value = Value(json!({"status": "success"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "error"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_dollar_dot_and_bare_field_syntax() {
        for expr in ["$.score > 0.5", "score > 0.5"] {
            let node = SwitchNode::expression(expr);
            assert!(node.evaluate(&Value(json!({"score": 0.7}))), "{}", expr);
            assert!(!node.evaluate(&Value(json!({"score": 0.3}))), "{}", expr);
        }
    }

    #[test]
    fn switch_expression_operator_inside_string_literal() {
        // The ">=" inside the quotes must not be mistaken for the operator.
        let node = SwitchNode::expression("${op} == \">=\"");

        assert!(node.evaluate(&Value(json!({"op": ">="}))));
        assert!(!node.evaluate(&Value(json!({"op": "<"}))));
    }

    #[test]
    fn switch_expression_non_numeric_threshold_is_false() {
        let node = SwitchNode::expression("${score} > high");
        assert!(!node.evaluate(&Value(json!({"score": 0.9}))));
    }

    #[test]
    fn switch_expression_unrecognized_format_is_false() {
        let node = SwitchNode::expression("${score} ~ 0.5");
        assert!(!node.evaluate(&Value(json!({"score": 0.9}))));
    }

    #[test]
    fn switch_expression_missing_field_is_false() {
        for expr in ["${missing} > 1", "${missing} <= 1", "${missing}"] {
            let node = SwitchNode::expression(expr);
            assert!(!node.evaluate(&Value(json!({"other": 1}))), "{}", expr);
        }
    }
}