1use std::collections::HashMap;
7use xerv_core::traits::{Context, Node, NodeFuture, NodeInfo, NodeOutput, Port, PortDirection};
8use xerv_core::types::RelPtr;
9use xerv_core::value::Value;
10
/// Condition a [`SwitchNode`] evaluates against its input to choose between
/// the `true` and `false` outputs.
///
/// All fields implement `PartialEq`, so conditions can be compared directly
/// (e.g. `assert_eq!(node_condition, SwitchCondition::Always)`).
#[derive(Debug, Clone, PartialEq)]
pub enum SwitchCondition {
    /// Unconditionally routes to `true`.
    Always,
    /// Routes to `true` when `field` equals `value`
    /// (comparison delegated to `Value::field_equals`).
    FieldEquals { field: String, value: String },
    /// Routes to `true` when `field` matches `pattern`, written in
    /// regular-expression syntax (matching delegated to `Value::field_matches`).
    FieldMatches { field: String, pattern: String },
    /// Routes to `true` when the numeric `field` is strictly greater than `threshold`.
    FieldGreaterThan { field: String, threshold: f64 },
    /// Routes to `true` when the numeric `field` is less than `threshold`
    /// (comparison delegated to `Value::field_less_than`).
    FieldLessThan { field: String, threshold: f64 },
    /// A textual expression; see [`SwitchNode::expression`] for the syntax.
    Expression(String),
}
27
28impl Default for SwitchCondition {
29 fn default() -> Self {
30 Self::Always
31 }
32}
33
34#[derive(Debug)]
61pub struct SwitchNode {
62 condition: SwitchCondition,
64}
65
66impl SwitchNode {
67 pub fn new(condition: SwitchCondition) -> Self {
69 Self { condition }
70 }
71
72 pub fn field_equals(field: impl Into<String>, value: impl Into<String>) -> Self {
74 Self {
75 condition: SwitchCondition::FieldEquals {
76 field: field.into(),
77 value: value.into(),
78 },
79 }
80 }
81
82 pub fn threshold(field: impl Into<String>, threshold: f64) -> Self {
84 Self {
85 condition: SwitchCondition::FieldGreaterThan {
86 field: field.into(),
87 threshold,
88 },
89 }
90 }
91
92 pub fn expression(expr: impl Into<String>) -> Self {
94 Self {
95 condition: SwitchCondition::Expression(expr.into()),
96 }
97 }
98
99 fn evaluate(&self, value: &Value) -> bool {
105 match &self.condition {
106 SwitchCondition::Always => true,
107
108 SwitchCondition::FieldEquals {
109 field,
110 value: expected,
111 } => {
112 let result = value.field_equals(field, expected);
113 tracing::debug!(
114 field = %field,
115 expected = %expected,
116 result = result,
117 "Evaluated field_equals condition"
118 );
119 result
120 }
121
122 SwitchCondition::FieldMatches { field, pattern } => {
123 let result = value.field_matches(field, pattern);
124 tracing::debug!(
125 field = %field,
126 pattern = %pattern,
127 result = result,
128 "Evaluated field_matches condition"
129 );
130 result
131 }
132
133 SwitchCondition::FieldGreaterThan { field, threshold } => {
134 let result = value.field_greater_than(field, *threshold);
135 tracing::debug!(
136 field = %field,
137 threshold = %threshold,
138 result = result,
139 "Evaluated field_greater_than condition"
140 );
141 result
142 }
143
144 SwitchCondition::FieldLessThan { field, threshold } => {
145 let result = value.field_less_than(field, *threshold);
146 tracing::debug!(
147 field = %field,
148 threshold = %threshold,
149 result = result,
150 "Evaluated field_less_than condition"
151 );
152 result
153 }
154
155 SwitchCondition::Expression(expr) => {
156 let result = self.evaluate_expression(expr, value);
162 tracing::debug!(
163 expr = %expr,
164 result = result,
165 "Evaluated expression condition"
166 );
167 result
168 }
169 }
170 }
171
172 fn evaluate_expression(&self, expr: &str, value: &Value) -> bool {
182 let expr = expr.trim();
183
184 if let Some((field, op, rhs)) = self.parse_comparison(expr) {
186 match op {
187 "==" | "=" => {
188 let rhs = rhs.trim_matches('"').trim_matches('\'');
190 value.field_equals(&field, rhs)
191 }
192 "!=" => {
193 let rhs = rhs.trim_matches('"').trim_matches('\'');
194 !value.field_equals(&field, rhs)
195 }
196 ">" => {
197 if let Ok(threshold) = rhs.parse::<f64>() {
198 value.field_greater_than(&field, threshold)
199 } else {
200 false
201 }
202 }
203 "<" => {
204 if let Ok(threshold) = rhs.parse::<f64>() {
205 value.field_less_than(&field, threshold)
206 } else {
207 false
208 }
209 }
210 ">=" => {
211 if let Ok(threshold) = rhs.parse::<f64>() {
212 value.get_f64(&field).map_or(false, |v| v >= threshold)
213 } else {
214 false
215 }
216 }
217 "<=" => {
218 if let Ok(threshold) = rhs.parse::<f64>() {
219 value.get_f64(&field).map_or(false, |v| v <= threshold)
220 } else {
221 false
222 }
223 }
224 _ => false,
225 }
226 } else if let Some(field) = self.parse_field_ref(expr) {
227 value.field_is_true(&field)
229 } else {
230 tracing::warn!(expr = %expr, "Unrecognized expression format");
232 false
233 }
234 }
235
236 fn parse_comparison<'a>(&self, expr: &'a str) -> Option<(String, &'a str, &'a str)> {
238 let operators = [">=", "<=", "==", "!=", ">", "<", "="];
240
241 for op in operators {
242 if let Some(pos) = expr.find(op) {
243 let lhs = expr[..pos].trim();
244 let rhs = expr[pos + op.len()..].trim();
245
246 if let Some(field) = self.parse_field_ref(lhs) {
248 return Some((field, op, rhs));
249 }
250 }
251 }
252 None
253 }
254
255 fn parse_field_ref(&self, s: &str) -> Option<String> {
257 let s = s.trim();
258
259 if s.starts_with("${") && s.ends_with('}') {
261 return Some(s[2..s.len() - 1].to_string());
262 }
263
264 if s.starts_with("$.") {
266 return Some(s[2..].to_string());
267 }
268
269 if !s.is_empty()
271 && s.chars()
272 .all(|c| c.is_alphanumeric() || c == '_' || c == '.')
273 {
274 return Some(s.to_string());
275 }
276
277 None
278 }
279}
280
281impl Node for SwitchNode {
282 fn info(&self) -> NodeInfo {
283 NodeInfo::new("std", "switch")
284 .with_description("Conditional routing based on expression")
285 .with_inputs(vec![Port::input("Any")])
286 .with_outputs(vec![
287 Port::named("true", PortDirection::Output, "Any"),
288 Port::named("false", PortDirection::Output, "Any"),
289 Port::error(),
290 ])
291 }
292
293 fn execute<'a>(&'a self, ctx: Context, inputs: HashMap<String, RelPtr<()>>) -> NodeFuture<'a> {
294 Box::pin(async move {
295 let input = inputs.get("in").copied().unwrap_or_else(RelPtr::null);
296
297 let value = if input.is_null() {
299 Value::null()
300 } else {
301 match ctx.read_bytes(input) {
302 Ok(bytes) => Value::from_bytes(&bytes).unwrap_or_else(|e| {
303 tracing::warn!(error = %e, "Failed to parse input as JSON, using null");
304 Value::null()
305 }),
306 Err(e) => {
307 tracing::warn!(error = %e, "Failed to read input from arena, using null");
308 Value::null()
309 }
310 }
311 };
312
313 let result = self.evaluate(&value);
314
315 tracing::debug!(condition_result = result, "Switch evaluated condition");
316
317 if result {
318 Ok(NodeOutput::on_true(input))
319 } else {
320 Ok(NodeOutput::on_false(input))
321 }
322 })
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn switch_node_info() {
        let node = SwitchNode::new(SwitchCondition::Always);
        let info = node.info();

        assert_eq!(info.name, "std::switch");
        assert_eq!(info.inputs.len(), 1);
        assert_eq!(info.inputs[0].name, "in");
        assert_eq!(info.outputs.len(), 3);
        assert_eq!(info.outputs[0].name, "true");
        assert_eq!(info.outputs[1].name, "false");
    }

    #[test]
    fn switch_condition_always() {
        let node = SwitchNode::new(SwitchCondition::Always);
        let value = Value::null();
        assert!(node.evaluate(&value));
    }

    #[test]
    fn switch_condition_default_is_always() {
        assert!(matches!(SwitchCondition::default(), SwitchCondition::Always));
    }

    #[test]
    fn switch_threshold_creation() {
        let node = SwitchNode::threshold("score", 0.8);
        assert!(matches!(
            node.condition,
            SwitchCondition::FieldGreaterThan { threshold, .. } if threshold == 0.8
        ));
    }

    #[test]
    fn switch_field_equals() {
        let node = SwitchNode::field_equals("status", "active");
        let value = Value(json!({"status": "active"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "inactive"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_field_greater_than() {
        let node = SwitchNode::threshold("score", 0.8);

        let value = Value(json!({"score": 0.9}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"score": 0.7}));
        assert!(!node.evaluate(&value));

        // The comparison is strict: a value equal to the threshold is not greater.
        let value = Value(json!({"score": 0.8}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_field_less_than() {
        let node = SwitchNode::new(SwitchCondition::FieldLessThan {
            field: "temperature".to_string(),
            threshold: 30.0,
        });

        let value = Value(json!({"temperature": 25.0}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"temperature": 35.0}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_field_matches() {
        let node = SwitchNode::new(SwitchCondition::FieldMatches {
            field: "email".to_string(),
            pattern: r"^[\w.+-]+@[\w.-]+\.\w+$".to_string(),
        });

        let value = Value(json!({"email": "test@example.com"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"email": "invalid-email"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_comparison() {
        let node = SwitchNode::expression("${score} > 0.5");

        let value = Value(json!({"score": 0.7}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"score": 0.3}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_less_than() {
        let node = SwitchNode::expression("${temperature} < 30");

        let value = Value(json!({"temperature": 25.0}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"temperature": 35.0}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_equality() {
        let node = SwitchNode::expression("${status} == \"success\"");

        let value = Value(json!({"status": "success"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "failed"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_single_equals_with_single_quotes() {
        let node = SwitchNode::expression("${status} = 'ok'");

        let value = Value(json!({"status": "ok"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "error"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_jsonpath_field() {
        let node = SwitchNode::expression("$.status == \"ok\"");

        let value = Value(json!({"status": "ok"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "error"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_boolean_field() {
        let node = SwitchNode::expression("${is_valid}");

        let value = Value(json!({"is_valid": true}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"is_valid": false}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_bare_boolean_field() {
        let node = SwitchNode::expression("is_valid");

        let value = Value(json!({"is_valid": true}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"is_valid": false}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_nested_field_access() {
        let node = SwitchNode::field_equals("result.status", "ok");

        let value = Value(json!({"result": {"status": "ok"}}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"result": {"status": "error"}}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_missing_field_returns_false() {
        let node = SwitchNode::field_equals("nonexistent", "value");
        let value = Value(json!({"other": "data"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_gte() {
        let node = SwitchNode::expression("${count} >= 10");

        let value = Value(json!({"count": 10}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"count": 15}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"count": 5}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_lte() {
        let node = SwitchNode::expression("${count} <= 10");

        let value = Value(json!({"count": 10}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"count": 11}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_not_equals() {
        let node = SwitchNode::expression("${status} != \"error\"");

        let value = Value(json!({"status": "success"}));
        assert!(node.evaluate(&value));

        let value = Value(json!({"status": "error"}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_expression_non_numeric_threshold_is_false() {
        let node = SwitchNode::expression("${score} > high");
        let value = Value(json!({"score": 0.9}));
        assert!(!node.evaluate(&value));
    }

    #[test]
    fn switch_unrecognized_expression_is_false() {
        let node = SwitchNode::expression("not a field!");
        let value = Value(json!({"not": true}));
        assert!(!node.evaluate(&value));
    }
}
487}