1use std::collections::HashMap;
2use std::sync::{Mutex, MutexGuard, OnceLock};
3
4use serde_json::{Number, Value};
5use thiserror::Error;
6
7#[derive(Debug, Error, Clone, PartialEq, Eq)]
8pub enum ExpressionError {
9 #[error("expression is empty")]
10 Empty,
11 #[error("invalid expression '{expression}': {reason}")]
12 Invalid { expression: String, reason: String },
13 #[error("path '{path}' not found in scoped input")]
14 MissingPath { path: String },
15 #[error("expression complexity limit exceeded: {metric}={value}, max={max}")]
16 ComplexityLimitExceeded {
17 metric: &'static str,
18 value: usize,
19 max: usize,
20 },
21}
22
/// Upper bounds applied when an [`ExpressionEngine`] compiles an expression.
///
/// Violating one of the expression limits yields
/// [`ExpressionError::ComplexityLimitExceeded`]; `max_cache_entries` only
/// bounds the engine's parse cache and never produces an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExpressionLimits {
    /// Maximum length of the trimmed expression text.
    pub max_expression_chars: usize,
    /// Maximum number of nodes in the parsed tree; logical operators,
    /// comparisons and bare truthiness checks each count as one.
    pub max_operator_count: usize,
    /// Maximum nesting depth of the parsed tree.
    pub max_depth: usize,
    /// Maximum number of non-empty `.`-separated segments in any path operand.
    pub max_path_segments: usize,
    /// Capacity of the engine's compiled-expression cache.
    pub max_cache_entries: usize,
}
31
/// Parser front-end selected for an [`ExpressionEngine`].
///
/// Both variants are currently routed to the same parser in
/// `ExpressionEngine::compile`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ExpressionBackend {
    /// Built-in parser; this is the default backend.
    #[default]
    Native,
    /// CEL-flavoured backend; presently parsed by the native parser too.
    CelCompatible,
}
41
42impl Default for ExpressionLimits {
43 fn default() -> Self {
44 Self {
45 max_expression_chars: 2048,
46 max_operator_count: 64,
47 max_depth: 24,
48 max_path_segments: 16,
49 max_cache_entries: 512,
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
55enum ExpressionNode {
56 Not(Box<ExpressionNode>),
57 Eq(Operand, Operand),
58 Ne(Operand, Operand),
59 Or(Vec<ExpressionNode>),
60 And(Vec<ExpressionNode>),
61 Truthy(Operand),
62}
63
64#[derive(Debug, Clone)]
65enum Operand {
66 Bool(bool),
67 Null,
68 Number(Number),
69 String(String),
70 Path(String),
71}
72
73#[derive(Debug, Default)]
74pub struct ExpressionEngine {
75 cache: Mutex<HashMap<String, ExpressionNode>>,
76 backend: ExpressionBackend,
77 limits: ExpressionLimits,
78}
79
80impl ExpressionEngine {
81 pub fn new() -> Self {
82 Self::with_limits(ExpressionLimits::default())
83 }
84
85 pub fn with_limits(limits: ExpressionLimits) -> Self {
86 Self::with_backend(ExpressionBackend::Native, limits)
87 }
88
89 pub fn with_backend(backend: ExpressionBackend, limits: ExpressionLimits) -> Self {
90 Self {
91 cache: Mutex::new(HashMap::new()),
92 backend,
93 limits,
94 }
95 }
96
97 pub fn validate(&self, expression: &str) -> Result<(), ExpressionError> {
98 let _ = self.compile(expression)?;
99 Ok(())
100 }
101
102 pub fn evaluate_bool(
103 &self,
104 expression: &str,
105 scoped_input: &Value,
106 ) -> Result<bool, ExpressionError> {
107 let node = self.compile(expression)?;
108 eval_node(&node, scoped_input)
109 }
110
111 fn cache_lock(&self) -> MutexGuard<'_, HashMap<String, ExpressionNode>> {
112 self.cache
113 .lock()
114 .unwrap_or_else(|poisoned| poisoned.into_inner())
115 }
116
117 fn compile(&self, expression: &str) -> Result<ExpressionNode, ExpressionError> {
118 let normalized = expression.trim();
119 if normalized.is_empty() {
120 return Err(ExpressionError::Empty);
121 }
122
123 if normalized.len() > self.limits.max_expression_chars {
124 return Err(ExpressionError::ComplexityLimitExceeded {
125 metric: "chars",
126 value: normalized.len(),
127 max: self.limits.max_expression_chars,
128 });
129 }
130
131 if let Some(cached) = self.cache_lock().get(normalized).cloned() {
132 return Ok(cached);
133 }
134
135 let parsed = match self.backend {
136 ExpressionBackend::Native | ExpressionBackend::CelCompatible => parse_expr(normalized)?,
137 };
138 let op_count = count_operators(&parsed);
139 if op_count > self.limits.max_operator_count {
140 return Err(ExpressionError::ComplexityLimitExceeded {
141 metric: "operators",
142 value: op_count,
143 max: self.limits.max_operator_count,
144 });
145 }
146 let depth = tree_depth(&parsed);
147 if depth > self.limits.max_depth {
148 return Err(ExpressionError::ComplexityLimitExceeded {
149 metric: "depth",
150 value: depth,
151 max: self.limits.max_depth,
152 });
153 }
154 validate_path_segments(&parsed, self.limits.max_path_segments)?;
155
156 let mut cache = self.cache_lock();
157 if cache.len() >= self.limits.max_cache_entries {
158 if let Some(evicted) = cache.keys().next().cloned() {
159 cache.remove(&evicted);
160 }
161 }
162 cache.insert(normalized.to_string(), parsed.clone());
163 Ok(parsed)
164 }
165}
166
167pub fn default_expression_engine() -> &'static ExpressionEngine {
168 static ENGINE: OnceLock<ExpressionEngine> = OnceLock::new();
169 ENGINE.get_or_init(ExpressionEngine::new)
170}
171
172pub fn evaluate_bool(expression: &str, scoped_input: &Value) -> Result<bool, ExpressionError> {
173 default_expression_engine().evaluate_bool(expression, scoped_input)
174}
175
176fn parse_expr(expression: &str) -> Result<ExpressionNode, ExpressionError> {
177 if let Some(parts) = split_top_level(expression, "||") {
178 let mut nodes = Vec::with_capacity(parts.len());
179 for part in parts {
180 nodes.push(parse_expr(part)?);
181 }
182 return Ok(ExpressionNode::Or(nodes));
183 }
184
185 if let Some(parts) = split_top_level(expression, "&&") {
186 let mut nodes = Vec::with_capacity(parts.len());
187 for part in parts {
188 nodes.push(parse_expr(part)?);
189 }
190 return Ok(ExpressionNode::And(nodes));
191 }
192
193 if let Some(inner) = expression.strip_prefix('!') {
194 return Ok(ExpressionNode::Not(Box::new(parse_expr(inner.trim())?)));
195 }
196
197 if let Some((left, right)) = split_once_top_level(expression, "==") {
198 return Ok(ExpressionNode::Eq(
199 parse_operand(left)?,
200 parse_operand(right)?,
201 ));
202 }
203
204 if let Some((left, right)) = split_once_top_level(expression, "!=") {
205 return Ok(ExpressionNode::Ne(
206 parse_operand(left)?,
207 parse_operand(right)?,
208 ));
209 }
210
211 Ok(ExpressionNode::Truthy(parse_operand(expression)?))
212}
213
214fn parse_operand(token: &str) -> Result<Operand, ExpressionError> {
215 let trimmed = token.trim();
216 if trimmed.is_empty() {
217 return Err(ExpressionError::Invalid {
218 expression: token.to_string(),
219 reason: "empty operand".to_string(),
220 });
221 }
222
223 if trimmed.eq_ignore_ascii_case("true") {
224 return Ok(Operand::Bool(true));
225 }
226 if trimmed.eq_ignore_ascii_case("false") {
227 return Ok(Operand::Bool(false));
228 }
229 if trimmed.eq_ignore_ascii_case("null") {
230 return Ok(Operand::Null);
231 }
232
233 if let Some(value) = trimmed.strip_prefix('"').and_then(|v| v.strip_suffix('"')) {
234 return Ok(Operand::String(value.to_string()));
235 }
236 if let Some(value) = trimmed
237 .strip_prefix('\'')
238 .and_then(|v| v.strip_suffix('\''))
239 {
240 return Ok(Operand::String(value.to_string()));
241 }
242
243 if let Ok(value) = trimmed.parse::<i64>() {
244 return Ok(Operand::Number(Number::from(value)));
245 }
246
247 if let Ok(value) = trimmed.parse::<f64>() {
248 if let Some(number) = Number::from_f64(value) {
249 return Ok(Operand::Number(number));
250 }
251 }
252
253 let path = trimmed.strip_prefix("$.").unwrap_or(trimmed);
254 if path.is_empty() {
255 return Err(ExpressionError::Invalid {
256 expression: token.to_string(),
257 reason: "path cannot be '$.'".to_string(),
258 });
259 }
260
261 Ok(Operand::Path(path.to_string()))
262}
263
/// Splits `input` on every occurrence of `delimiter` that is not inside a
/// single- or double-quoted literal, trimming each piece.
///
/// Returns `None` when no top-level occurrence exists or `delimiter` is empty.
/// Pieces may be empty (e.g. `"a || "` yields `["a", ""]`); callers decide
/// whether that is an error. Backslash escapes inside quotes are not recognised.
fn split_top_level<'a>(input: &'a str, delimiter: &str) -> Option<Vec<&'a str>> {
    let delim = delimiter.as_bytes();
    // An empty delimiter would match at every index without advancing the
    // cursor, looping forever; there is nothing meaningful to split on.
    if delim.is_empty() {
        return None;
    }

    let bytes = input.as_bytes();
    let mut parts = Vec::new();
    let mut start = 0usize;
    let mut in_single = false;
    let mut in_double = false;
    let mut idx = 0usize;

    while idx < bytes.len() {
        match bytes[idx] {
            b'\'' if !in_double => in_single = !in_single,
            b'"' if !in_single => in_double = !in_double,
            _ => {}
        }

        // A match starts at the leading byte of a UTF-8 `delimiter`, so `idx`
        // and `idx + delim.len()` are always char boundaries for the slices below.
        if !in_single && !in_double && bytes[idx..].starts_with(delim) {
            parts.push(input[start..idx].trim());
            start = idx + delim.len();
            idx = start;
            continue;
        }
        idx += 1;
    }

    if parts.is_empty() {
        return None;
    }
    parts.push(input[start..].trim());
    Some(parts)
}
299
/// Splits `input` at the first occurrence of `delimiter` that is not inside a
/// single- or double-quoted literal, returning both trimmed sides.
///
/// Either side may be empty (e.g. `"== 5"` yields `("", "5")`). This matches
/// `split_top_level` and lets the caller reject the missing operand as invalid
/// instead of silently reparsing the whole text as a path. Returns `None`
/// when there is no top-level occurrence or `delimiter` is empty.
fn split_once_top_level<'a>(input: &'a str, delimiter: &str) -> Option<(&'a str, &'a str)> {
    let delim = delimiter.as_bytes();
    if delim.is_empty() {
        return None;
    }

    let bytes = input.as_bytes();
    let mut in_single = false;
    let mut in_double = false;

    for idx in 0..bytes.len() {
        match bytes[idx] {
            b'\'' if !in_double => in_single = !in_single,
            b'"' if !in_single => in_double = !in_double,
            _ => {}
        }

        if !in_single && !in_double && bytes[idx..].starts_with(delim) {
            let left = input[..idx].trim();
            let right = input[idx + delim.len()..].trim();
            return Some((left, right));
        }
    }

    None
}
331
332fn eval_node(node: &ExpressionNode, scoped_input: &Value) -> Result<bool, ExpressionError> {
333 match node {
334 ExpressionNode::Not(inner) => Ok(!eval_node(inner, scoped_input)?),
335 ExpressionNode::Eq(left, right) => {
336 Ok(eval_operand(left, scoped_input)? == eval_operand(right, scoped_input)?)
337 }
338 ExpressionNode::Ne(left, right) => {
339 Ok(eval_operand(left, scoped_input)? != eval_operand(right, scoped_input)?)
340 }
341 ExpressionNode::Or(nodes) => {
342 for node in nodes {
343 if eval_node(node, scoped_input)? {
344 return Ok(true);
345 }
346 }
347 Ok(false)
348 }
349 ExpressionNode::And(nodes) => {
350 for node in nodes {
351 if !eval_node(node, scoped_input)? {
352 return Ok(false);
353 }
354 }
355 Ok(true)
356 }
357 ExpressionNode::Truthy(operand) => Ok(is_truthy(&eval_operand(operand, scoped_input)?)),
358 }
359}
360
361fn eval_operand(operand: &Operand, scoped_input: &Value) -> Result<Value, ExpressionError> {
362 match operand {
363 Operand::Bool(v) => Ok(Value::Bool(*v)),
364 Operand::Null => Ok(Value::Null),
365 Operand::Number(v) => Ok(Value::Number(v.clone())),
366 Operand::String(v) => Ok(Value::String(v.clone())),
367 Operand::Path(path) => resolve_path(scoped_input, path)
368 .cloned()
369 .ok_or_else(|| ExpressionError::MissingPath { path: path.clone() }),
370 }
371}
372
373fn resolve_path<'a>(root: &'a Value, path: &str) -> Option<&'a Value> {
374 if path.is_empty() {
375 return Some(root);
376 }
377
378 path.split('.')
379 .filter(|segment| !segment.is_empty())
380 .try_fold(root, |current, segment| current.get(segment))
381}
382
383fn is_truthy(value: &Value) -> bool {
384 match value {
385 Value::Bool(value) => *value,
386 Value::Null => false,
387 Value::Number(number) => number.as_f64().is_some_and(|n| n != 0.0),
388 Value::String(value) => !value.is_empty(),
389 Value::Array(values) => !values.is_empty(),
390 Value::Object(values) => !values.is_empty(),
391 }
392}
393
394fn count_operators(node: &ExpressionNode) -> usize {
395 match node {
396 ExpressionNode::Not(inner) => 1 + count_operators(inner),
397 ExpressionNode::Eq(..) | ExpressionNode::Ne(..) | ExpressionNode::Truthy(..) => 1,
398 ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
399 1 + nodes.iter().map(count_operators).sum::<usize>()
400 }
401 }
402}
403
404fn tree_depth(node: &ExpressionNode) -> usize {
405 match node {
406 ExpressionNode::Not(inner) => 1 + tree_depth(inner),
407 ExpressionNode::Eq(..) | ExpressionNode::Ne(..) | ExpressionNode::Truthy(..) => 1,
408 ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
409 1 + nodes.iter().map(tree_depth).max().unwrap_or(0)
410 }
411 }
412}
413
414fn validate_path_segments(
415 node: &ExpressionNode,
416 max_segments: usize,
417) -> Result<(), ExpressionError> {
418 match node {
419 ExpressionNode::Not(inner) => validate_path_segments(inner, max_segments),
420 ExpressionNode::Eq(left, right) | ExpressionNode::Ne(left, right) => {
421 validate_operand_path(left, max_segments)?;
422 validate_operand_path(right, max_segments)
423 }
424 ExpressionNode::Or(nodes) | ExpressionNode::And(nodes) => {
425 for item in nodes {
426 validate_path_segments(item, max_segments)?;
427 }
428 Ok(())
429 }
430 ExpressionNode::Truthy(operand) => validate_operand_path(operand, max_segments),
431 }
432}
433
434fn validate_operand_path(operand: &Operand, max_segments: usize) -> Result<(), ExpressionError> {
435 if let Operand::Path(path) = operand {
436 let segments = path
437 .split('.')
438 .filter(|segment| !segment.is_empty())
439 .count();
440 if segments > max_segments {
441 return Err(ExpressionError::ComplexityLimitExceeded {
442 metric: "path_segments",
443 value: segments,
444 max: max_segments,
445 });
446 }
447 }
448 Ok(())
449}
450
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{ExpressionBackend, ExpressionEngine, ExpressionError, ExpressionLimits};

    #[test]
    fn supports_truthy_path_and_equality() {
        let engine = ExpressionEngine::new();
        let input = json!({"input": {"approved": true}, "score": 5});

        assert!(engine
            .evaluate_bool("input.approved", &input)
            .expect("truthy check should pass"));
        assert!(engine
            .evaluate_bool("score == 5", &input)
            .expect("equality check should pass"));
        assert!(engine
            .evaluate_bool("score != 2", &input)
            .expect("inequality check should pass"));
    }

    #[test]
    fn supports_boolean_operators() {
        let engine = ExpressionEngine::new();
        let input = json!({"a": true, "b": false, "n": 1});

        assert!(engine
            .evaluate_bool("a && n == 1", &input)
            .expect("and expression should pass"));
        assert!(engine
            .evaluate_bool("b || n == 1", &input)
            .expect("or expression should pass"));
        assert!(engine
            .evaluate_bool("!b", &input)
            .expect("not expression should pass"));
    }

    #[test]
    fn reports_missing_path() {
        let engine = ExpressionEngine::new();
        let error = engine
            .evaluate_bool("missing.path", &json!({}))
            .expect_err("missing path should fail");
        assert!(matches!(error, ExpressionError::MissingPath { .. }));
    }

    #[test]
    fn validate_uses_parse_cache() {
        let engine = ExpressionEngine::new();
        engine
            .validate("input.ready == true")
            .expect("first parse should pass");
        engine
            .validate("input.ready == true")
            .expect("second parse should hit cache");
    }

    #[test]
    fn rejects_expression_when_depth_limit_exceeded() {
        let engine = ExpressionEngine::with_limits(ExpressionLimits {
            max_depth: 1,
            ..ExpressionLimits::default()
        });

        let error = engine
            .evaluate_bool("a && b && c", &json!({"a": true, "b": true, "c": true}))
            .expect_err("depth guard should reject expression");
        assert!(matches!(
            error,
            ExpressionError::ComplexityLimitExceeded {
                metric: "depth",
                ..
            }
        ));
    }

    #[test]
    fn rejects_expression_when_path_segments_limit_exceeded() {
        let engine = ExpressionEngine::with_limits(ExpressionLimits {
            max_path_segments: 2,
            ..ExpressionLimits::default()
        });

        let error = engine
            .evaluate_bool(
                "input.deep.value == true",
                &json!({"input": {"deep": {"value": true}}}),
            )
            .expect_err("path segment guard should reject expression");
        assert!(matches!(
            error,
            ExpressionError::ComplexityLimitExceeded {
                metric: "path_segments",
                ..
            }
        ));
    }

    #[test]
    fn supports_cel_compatible_backend_path() {
        let engine = ExpressionEngine::with_backend(
            ExpressionBackend::CelCompatible,
            ExpressionLimits::default(),
        );
        let result = engine
            .evaluate_bool("input.ready == true", &json!({"input": {"ready": true}}))
            .expect("cel-compatible backend should evaluate expression");
        assert!(result);
    }

    #[test]
    fn rejects_blank_expression() {
        let engine = ExpressionEngine::new();
        assert_eq!(engine.validate("   "), Err(ExpressionError::Empty));
    }

    #[test]
    fn rejects_expression_when_char_limit_exceeded() {
        let engine = ExpressionEngine::with_limits(ExpressionLimits {
            max_expression_chars: 4,
            ..ExpressionLimits::default()
        });
        assert_eq!(
            engine.validate("abcdef"),
            Err(ExpressionError::ComplexityLimitExceeded {
                metric: "chars",
                value: 6,
                max: 4,
            })
        );
    }

    #[test]
    fn rejects_expression_when_operator_limit_exceeded() {
        let engine = ExpressionEngine::with_limits(ExpressionLimits {
            max_operator_count: 2,
            ..ExpressionLimits::default()
        });
        // `And` plus two truthiness leaves counts as three operators.
        assert_eq!(
            engine.validate("a && b"),
            Err(ExpressionError::ComplexityLimitExceeded {
                metric: "operators",
                value: 3,
                max: 2,
            })
        );
    }

    #[test]
    fn ignores_delimiters_inside_quoted_literals() {
        let engine = ExpressionEngine::new();
        let input = json!({"name": "a || b", "other": "x && y"});

        assert!(engine
            .evaluate_bool("name == 'a || b'", &input)
            .expect("single-quoted literal should stay intact"));
        assert!(engine
            .evaluate_bool("other == \"x && y\"", &input)
            .expect("double-quoted literal should stay intact"));
    }

    #[test]
    fn reports_dangling_logical_operator_as_invalid() {
        let engine = ExpressionEngine::new();
        let error = engine
            .validate("a &&")
            .expect_err("dangling operator should be rejected");
        assert!(matches!(error, ExpressionError::Invalid { .. }));
    }

    #[test]
    fn supports_root_prefixed_paths() {
        let engine = ExpressionEngine::new();
        let input = json!({"input": {"ready": true}});

        assert!(engine
            .evaluate_bool("$.input.ready", &input)
            .expect("`$.` prefix should be stripped"));
        assert!(matches!(
            engine.validate("$."),
            Err(ExpressionError::Invalid { .. })
        ));
    }

    #[test]
    fn parse_cache_stays_within_entry_limit() {
        let engine = ExpressionEngine::with_limits(ExpressionLimits {
            max_cache_entries: 2,
            ..ExpressionLimits::default()
        });
        for expression in ["a", "b", "c", "d"] {
            engine
                .validate(expression)
                .expect("simple path should compile");
        }
        assert!(engine.cache_lock().len() <= 2);
    }
}