Skip to main content

simple_agents_workflow/
expressions.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, MutexGuard, OnceLock};
3
4use serde_json::{Number, Value};
5use thiserror::Error;
6
7#[derive(Debug, Error, Clone, PartialEq, Eq)]
8pub enum ExpressionError {
9    #[error("expression is empty")]
10    Empty,
11    #[error("invalid expression '{expression}': {reason}")]
12    Invalid { expression: String, reason: String },
13    #[error("path '{path}' not found in scoped input")]
14    MissingPath { path: String },
15    #[error("expression complexity limit exceeded: {metric}={value}, max={max}")]
16    ComplexityLimitExceeded {
17        metric: &'static str,
18        value: usize,
19        max: usize,
20    },
21}
22
/// Complexity bounds enforced when an expression is compiled.
#[derive(Clone, Copy, Debug)]
pub struct ExpressionLimits {
    /// Upper bound on the length of the trimmed expression text.
    pub max_expression_chars: usize,
    /// Upper bound on the number of nodes counted as operators in the parsed tree.
    pub max_operator_count: usize,
    /// Upper bound on the nesting depth of the parsed tree.
    pub max_depth: usize,
    /// Upper bound on the number of non-empty, dot-separated segments in a path operand.
    pub max_path_segments: usize,
    /// Size at which the parse cache starts evicting an entry before inserting.
    pub max_cache_entries: usize,
}
31
/// Strategy used to parse and evaluate expressions.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ExpressionBackend {
    /// Native Rust parser/evaluator (the default).
    #[default]
    Native,
    /// CEL-compatible backend routed through the same abstraction.
    CelCompatible,
}
41
42impl Default for ExpressionLimits {
43    fn default() -> Self {
44        Self {
45            max_expression_chars: 2048,
46            max_operator_count: 64,
47            max_depth: 24,
48            max_path_segments: 16,
49            max_cache_entries: 512,
50        }
51    }
52}
53
54#[derive(Debug, Clone)]
55enum ExpressionNode {
56    Not(Box<ExpressionNode>),
57    Eq(Operand, Operand),
58    Ne(Operand, Operand),
59    Or(Vec<ExpressionNode>),
60    And(Vec<ExpressionNode>),
61    Truthy(Operand),
62}
63
64#[derive(Debug, Clone)]
65enum Operand {
66    Bool(bool),
67    Null,
68    Number(Number),
69    String(String),
70    Path(String),
71}
72
73#[derive(Debug, Default)]
74pub struct ExpressionEngine {
75    cache: Mutex<HashMap<String, ExpressionNode>>,
76    backend: ExpressionBackend,
77    limits: ExpressionLimits,
78}
79
80impl ExpressionEngine {
81    pub fn new() -> Self {
82        Self::with_limits(ExpressionLimits::default())
83    }
84
85    pub fn with_limits(limits: ExpressionLimits) -> Self {
86        Self::with_backend(ExpressionBackend::Native, limits)
87    }
88
89    pub fn with_backend(backend: ExpressionBackend, limits: ExpressionLimits) -> Self {
90        Self {
91            cache: Mutex::new(HashMap::new()),
92            backend,
93            limits,
94        }
95    }
96
97    pub fn validate(&self, expression: &str) -> Result<(), ExpressionError> {
98        let _ = self.compile(expression)?;
99        Ok(())
100    }
101
102    pub fn evaluate_bool(
103        &self,
104        expression: &str,
105        scoped_input: &Value,
106    ) -> Result<bool, ExpressionError> {
107        let node = self.compile(expression)?;
108        eval_node(&node, scoped_input)
109    }
110
111    fn cache_lock(&self) -> MutexGuard<'_, HashMap<String, ExpressionNode>> {
112        self.cache
113            .lock()
114            .unwrap_or_else(|poisoned| poisoned.into_inner())
115    }
116
117    fn compile(&self, expression: &str) -> Result<ExpressionNode, ExpressionError> {
118        let normalized = expression.trim();
119        if normalized.is_empty() {
120            return Err(ExpressionError::Empty);
121        }
122
123        if normalized.len() > self.limits.max_expression_chars {
124            return Err(ExpressionError::ComplexityLimitExceeded {
125                metric: "chars",
126                value: normalized.len(),
127                max: self.limits.max_expression_chars,
128            });
129        }
130
131        if let Some(cached) = self.cache_lock().get(normalized).cloned() {
132            return Ok(cached);
133        }
134
135        let parsed = match self.backend {
136            ExpressionBackend::Native | ExpressionBackend::CelCompatible => parse_expr(normalized)?,
137        };
138        let op_count = count_operators(&parsed);
139        if op_count > self.limits.max_operator_count {
140            return Err(ExpressionError::ComplexityLimitExceeded {
141                metric: "operators",
142                value: op_count,
143                max: self.limits.max_operator_count,
144            });
145        }
146        let depth = tree_depth(&parsed);
147        if depth > self.limits.max_depth {
148            return Err(ExpressionError::ComplexityLimitExceeded {
149                metric: "depth",
150                value: depth,
151                max: self.limits.max_depth,
152            });
153        }
154        validate_path_segments(&parsed, self.limits.max_path_segments)?;
155
156        let mut cache = self.cache_lock();
157        if cache.len() >= self.limits.max_cache_entries {
158            if let Some(evicted) = cache.keys().next().cloned() {
159                cache.remove(&evicted);
160            }
161        }
162        cache.insert(normalized.to_string(), parsed.clone());
163        Ok(parsed)
164    }
165}
166
167pub fn default_expression_engine() -> &'static ExpressionEngine {
168    static ENGINE: OnceLock<ExpressionEngine> = OnceLock::new();
169    ENGINE.get_or_init(ExpressionEngine::new)
170}
171
172pub fn evaluate_bool(expression: &str, scoped_input: &Value) -> Result<bool, ExpressionError> {
173    default_expression_engine().evaluate_bool(expression, scoped_input)
174}
175
176fn parse_expr(expression: &str) -> Result<ExpressionNode, ExpressionError> {
177    if let Some(parts) = split_top_level(expression, "||") {
178        let mut nodes = Vec::with_capacity(parts.len());
179        for part in parts {
180            nodes.push(parse_expr(part)?);
181        }
182        return Ok(ExpressionNode::Or(nodes));
183    }
184
185    if let Some(parts) = split_top_level(expression, "&&") {
186        let mut nodes = Vec::with_capacity(parts.len());
187        for part in parts {
188            nodes.push(parse_expr(part)?);
189        }
190        return Ok(ExpressionNode::And(nodes));
191    }
192
193    if let Some(inner) = expression.strip_prefix('!') {
194        return Ok(ExpressionNode::Not(Box::new(parse_expr(inner.trim())?)));
195    }
196
197    if let Some((left, right)) = split_once_top_level(expression, "==") {
198        return Ok(ExpressionNode::Eq(
199            parse_operand(left)?,
200            parse_operand(right)?,
201        ));
202    }
203
204    if let Some((left, right)) = split_once_top_level(expression, "!=") {
205        return Ok(ExpressionNode::Ne(
206            parse_operand(left)?,
207            parse_operand(right)?,
208        ));
209    }
210
211    Ok(ExpressionNode::Truthy(parse_operand(expression)?))
212}
213
214fn parse_operand(token: &str) -> Result<Operand, ExpressionError> {
215    let trimmed = token.trim();
216    if trimmed.is_empty() {
217        return Err(ExpressionError::Invalid {
218            expression: token.to_string(),
219            reason: "empty operand".to_string(),
220        });
221    }
222
223    if trimmed.eq_ignore_ascii_case("true") {
224        return Ok(Operand::Bool(true));
225    }
226    if trimmed.eq_ignore_ascii_case("false") {
227        return Ok(Operand::Bool(false));
228    }
229    if trimmed.eq_ignore_ascii_case("null") {
230        return Ok(Operand::Null);
231    }
232
233    if let Some(value) = trimmed.strip_prefix('"').and_then(|v| v.strip_suffix('"')) {
234        return Ok(Operand::String(value.to_string()));
235    }
236    if let Some(value) = trimmed
237        .strip_prefix('\'')
238        .and_then(|v| v.strip_suffix('\''))
239    {
240        return Ok(Operand::String(value.to_string()));
241    }
242
243    if let Ok(value) = trimmed.parse::<i64>() {
244        return Ok(Operand::Number(Number::from(value)));
245    }
246
247    if let Ok(value) = trimmed.parse::<f64>() {
248        if let Some(number) = Number::from_f64(value) {
249            return Ok(Operand::Number(number));
250        }
251    }
252
253    let path = trimmed.strip_prefix("$.").unwrap_or(trimmed);
254    if path.is_empty() {
255        return Err(ExpressionError::Invalid {
256            expression: token.to_string(),
257            reason: "path cannot be '$.'".to_string(),
258        });
259    }
260
261    Ok(Operand::Path(path.to_string()))
262}
263
/// Splits `input` at every occurrence of `delimiter` that is not inside a
/// single- or double-quoted section, trimming each piece.
///
/// Returns `None` when the delimiter never occurs outside quotes, or when
/// `delimiter` is empty (an empty delimiter would otherwise match at the same
/// position forever). Pieces may be empty, e.g. `"a ||"` yields `["a", ""]`.
fn split_top_level<'a>(input: &'a str, delimiter: &'a str) -> Option<Vec<&'a str>> {
    if delimiter.is_empty() {
        return None;
    }

    let bytes = input.as_bytes();
    let delim = delimiter.as_bytes();
    let mut parts = Vec::new();
    let mut start = 0usize;
    let mut idx = 0usize;
    let mut in_single = false;
    let mut in_double = false;

    while idx < bytes.len() {
        // A quote character only toggles its own state when not inside the
        // other kind of quote.
        match bytes[idx] {
            b'\'' if !in_double => in_single = !in_single,
            b'"' if !in_single => in_double = !in_double,
            _ => {}
        }

        if !in_single && !in_double && bytes[idx..].starts_with(delim) {
            // `idx` and `idx + delim.len()` are char boundaries because the
            // matched bytes are a complete UTF-8 string.
            parts.push(input[start..idx].trim());
            idx += delim.len();
            start = idx;
            continue;
        }
        idx += 1;
    }

    if parts.is_empty() {
        return None;
    }
    parts.push(input[start..].trim());
    Some(parts)
}
299
/// Splits `input` at the first occurrence of `delimiter` found outside
/// single- or double-quoted sections and returns the trimmed halves.
///
/// Returns `None` when no such occurrence exists, or when either half of the
/// first occurrence is empty after trimming (later occurrences are not tried).
fn split_once_top_level<'a>(input: &'a str, delimiter: &'a str) -> Option<(&'a str, &'a str)> {
    let haystack = input.as_bytes();
    let needle = delimiter.as_bytes();
    // `Some(q)` while inside a section opened by quote byte `q`.
    let mut open_quote: Option<u8> = None;

    for (pos, &byte) in haystack.iter().enumerate() {
        match (open_quote, byte) {
            (None, b'\'' | b'"') => open_quote = Some(byte),
            (Some(opened), _) if opened == byte => open_quote = None,
            _ => {}
        }

        if open_quote.is_none() && haystack[pos..].starts_with(needle) {
            let lhs = input[..pos].trim();
            let rhs = input[pos + needle.len()..].trim();
            return (!lhs.is_empty() && !rhs.is_empty()).then_some((lhs, rhs));
        }
    }

    None
}
331
332fn eval_node(node: &ExpressionNode, scoped_input: &Value) -> Result<bool, ExpressionError> {
333    match node {
334        ExpressionNode::Not(inner) => Ok(!eval_node(inner, scoped_input)?),
335        ExpressionNode::Eq(left, right) => {
336            Ok(eval_operand(left, scoped_input)? == eval_operand(right, scoped_input)?)
337        }
338        ExpressionNode::Ne(left, right) => {
339            Ok(eval_operand(left, scoped_input)? != eval_operand(right, scoped_input)?)
340        }
341        ExpressionNode::Or(nodes) => {
342            for node in nodes {
343                if eval_node(node, scoped_input)? {
344                    return Ok(true);
345                }
346            }
347            Ok(false)
348        }
349        ExpressionNode::And(nodes) => {
350            for node in nodes {
351                if !eval_node(node, scoped_input)? {
352                    return Ok(false);
353                }
354            }
355            Ok(true)
356        }
357        ExpressionNode::Truthy(operand) => Ok(is_truthy(&eval_operand(operand, scoped_input)?)),
358    }
359}
360
361fn eval_operand(operand: &Operand, scoped_input: &Value) -> Result<Value, ExpressionError> {
362    match operand {
363        Operand::Bool(v) => Ok(Value::Bool(*v)),
364        Operand::Null => Ok(Value::Null),
365        Operand::Number(v) => Ok(Value::Number(v.clone())),
366        Operand::String(v) => Ok(Value::String(v.clone())),
367        Operand::Path(path) => resolve_path(scoped_input, path)
368            .cloned()
369            .ok_or_else(|| ExpressionError::MissingPath { path: path.clone() }),
370    }
371}
372
373fn resolve_path<'a>(root: &'a Value, path: &str) -> Option<&'a Value> {
374    if path.is_empty() {
375        return Some(root);
376    }
377
378    path.split('.')
379        .filter(|segment| !segment.is_empty())
380        .try_fold(root, |current, segment| current.get(segment))
381}
382
383fn is_truthy(value: &Value) -> bool {
384    match value {
385        Value::Bool(value) => *value,
386        Value::Null => false,
387        Value::Number(number) => number.as_f64().is_some_and(|n| n != 0.0),
388        Value::String(value) => !value.is_empty(),
389        Value::Array(values) => !values.is_empty(),
390        Value::Object(values) => !values.is_empty(),
391    }
392}
393
394fn count_operators(node: &ExpressionNode) -> usize {
395    match node {
396        ExpressionNode::Not(inner) => 1 + count_operators(inner),
397        ExpressionNode::Eq(..) | ExpressionNode::Ne(..) | ExpressionNode::Truthy(..) => 1,
398        ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
399            1 + nodes.iter().map(count_operators).sum::<usize>()
400        }
401    }
402}
403
404fn tree_depth(node: &ExpressionNode) -> usize {
405    match node {
406        ExpressionNode::Not(inner) => 1 + tree_depth(inner),
407        ExpressionNode::Eq(..) | ExpressionNode::Ne(..) | ExpressionNode::Truthy(..) => 1,
408        ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
409            1 + nodes.iter().map(tree_depth).max().unwrap_or(0)
410        }
411    }
412}
413
414fn validate_path_segments(
415    node: &ExpressionNode,
416    max_segments: usize,
417) -> Result<(), ExpressionError> {
418    match node {
419        ExpressionNode::Not(inner) => validate_path_segments(inner, max_segments),
420        ExpressionNode::Eq(left, right) | ExpressionNode::Ne(left, right) => {
421            validate_operand_path(left, max_segments)?;
422            validate_operand_path(right, max_segments)
423        }
424        ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
425            for item in nodes {
426                validate_path_segments(item, max_segments)?;
427            }
428            Ok(())
429        }
430        ExpressionNode::Truthy(operand) => validate_operand_path(operand, max_segments),
431    }
432}
433
434fn validate_operand_path(operand: &Operand, max_segments: usize) -> Result<(), ExpressionError> {
435    if let Operand::Path(path) = operand {
436        let segments = path
437            .split('.')
438            .filter(|segment| !segment.is_empty())
439            .count();
440        if segments > max_segments {
441            return Err(ExpressionError::ComplexityLimitExceeded {
442                metric: "path_segments",
443                value: segments,
444                max: max_segments,
445            });
446        }
447    }
448    Ok(())
449}
450
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{ExpressionBackend, ExpressionEngine, ExpressionError, ExpressionLimits};

    /// Truthy paths, `==` and `!=` evaluate against nested input.
    #[test]
    fn supports_truthy_path_and_equality() {
        let engine = ExpressionEngine::new();
        let input = json!({"input": {"approved": true}, "score": 5});

        let truthy = engine
            .evaluate_bool("input.approved", &input)
            .expect("truthy check should pass");
        assert!(truthy);

        let equal = engine
            .evaluate_bool("score == 5", &input)
            .expect("equality check should pass");
        assert!(equal);

        let not_equal = engine
            .evaluate_bool("score != 2", &input)
            .expect("inequality check should pass");
        assert!(not_equal);
    }

    /// `&&`, `||` and `!` combine sub-expressions.
    #[test]
    fn supports_boolean_operators() {
        let engine = ExpressionEngine::new();
        let input = json!({"a": true, "b": false, "n": 1});

        let conjunction = engine
            .evaluate_bool("a && n == 1", &input)
            .expect("and expression should pass");
        assert!(conjunction);

        let disjunction = engine
            .evaluate_bool("b || n == 1", &input)
            .expect("or expression should pass");
        assert!(disjunction);

        let negation = engine
            .evaluate_bool("!b", &input)
            .expect("not expression should pass");
        assert!(negation);
    }

    /// An unresolved path surfaces as `MissingPath`.
    #[test]
    fn reports_missing_path() {
        let engine = ExpressionEngine::new();
        let input = json!({});

        let error = engine
            .evaluate_bool("missing.path", &input)
            .expect_err("missing path should fail");
        assert!(matches!(error, ExpressionError::MissingPath { .. }));
    }

    /// Validating the same expression twice succeeds both times.
    #[test]
    fn validate_uses_parse_cache() {
        let engine = ExpressionEngine::new();
        let expression = "input.ready == true";

        engine
            .validate(expression)
            .expect("first parse should pass");
        engine
            .validate(expression)
            .expect("second parse should hit cache");
    }

    /// A conjunction is deeper than one level and trips `max_depth: 1`.
    #[test]
    fn rejects_expression_when_depth_limit_exceeded() {
        let limits = ExpressionLimits {
            max_depth: 1,
            ..ExpressionLimits::default()
        };
        let engine = ExpressionEngine::with_limits(limits);
        let input = json!({"a": true, "b": true, "c": true});

        let error = engine
            .evaluate_bool("a && b && c", &input)
            .expect_err("depth guard should reject expression");
        assert!(matches!(
            error,
            ExpressionError::ComplexityLimitExceeded {
                metric: "depth",
                ..
            }
        ));
    }

    /// A three-segment path trips `max_path_segments: 2`.
    #[test]
    fn rejects_expression_when_path_segments_limit_exceeded() {
        let limits = ExpressionLimits {
            max_path_segments: 2,
            ..ExpressionLimits::default()
        };
        let engine = ExpressionEngine::with_limits(limits);
        let input = json!({"input": {"deep": {"value": true}}});

        let error = engine
            .evaluate_bool("input.deep.value == true", &input)
            .expect_err("path segment guard should reject expression");
        assert!(matches!(
            error,
            ExpressionError::ComplexityLimitExceeded {
                metric: "path_segments",
                ..
            }
        ));
    }

    /// The CEL-compatible backend evaluates a simple comparison.
    #[test]
    fn supports_cel_compatible_backend_path() {
        let engine = ExpressionEngine::with_backend(
            ExpressionBackend::CelCompatible,
            ExpressionLimits::default(),
        );
        let input = json!({"input": {"ready": true}});

        let result = engine
            .evaluate_bool("input.ready == true", &input)
            .expect("cel-compatible backend should evaluate expression");
        assert!(result);
    }
}
561}