1use shape_ast::ast::{Expr, Literal, ObjectEntry};
63use shape_ast::error::{Result, ShapeError};
64use shape_value::ValueWord;
65use std::collections::HashMap;
66use std::sync::Arc;
67
68#[derive(Debug, Clone)]
70pub struct ConstEvaluator {
71 params: HashMap<String, ValueWord>,
74}
75
76impl ConstEvaluator {
77 pub fn new() -> Self {
79 Self {
80 params: HashMap::new(),
81 }
82 }
83
84 pub fn with_params(params: HashMap<String, ValueWord>) -> Self {
86 Self {
87 params: params.into_iter().map(|(k, v)| (k, v)).collect(),
88 }
89 }
90
91 pub fn add_param(&mut self, name: String, value: ValueWord) {
93 self.params.insert(name, value);
94 }
95
96 pub fn add_param_nb(&mut self, name: String, value: ValueWord) {
98 self.params.insert(name, value);
99 }
100
101 pub fn eval(&self, expr: &Expr) -> Result<ValueWord> {
105 Ok(self.eval_nb(expr)?.clone())
106 }
107
108 pub fn eval_as_nb(&self, expr: &Expr) -> Result<ValueWord> {
110 self.eval_nb(expr)
111 }
112
113 fn eval_nb(&self, expr: &Expr) -> Result<ValueWord> {
115 match expr {
116 Expr::Literal(lit, _) => match lit {
118 Literal::Int(i) => Ok(ValueWord::from_f64(*i as f64)),
119 Literal::UInt(u) => Ok(ValueWord::from_native_u64(*u)),
120 Literal::TypedInt(v, _) => Ok(ValueWord::from_i64(*v)),
121 Literal::Number(n) => Ok(ValueWord::from_f64(*n)),
122 Literal::Decimal(d) => {
123 use rust_decimal::prelude::ToPrimitive;
124 Ok(ValueWord::from_f64(d.to_f64().unwrap_or(0.0)))
125 }
126 Literal::String(s) => Ok(ValueWord::from_string(Arc::new(s.clone()))),
127 Literal::FormattedString { value, .. } => {
128 Ok(ValueWord::from_string(Arc::new(value.clone())))
129 }
130 Literal::ContentString { value, .. } => {
131 Ok(ValueWord::from_string(Arc::new(value.clone())))
132 }
133 Literal::Bool(b) => Ok(ValueWord::from_bool(*b)),
134 Literal::None => Ok(ValueWord::none()),
135 Literal::Unit => Ok(ValueWord::unit()),
136 Literal::Timeframe(tf) => Ok(ValueWord::from_timeframe(*tf)),
137 },
138
139 Expr::Object(entries, _) => {
141 let mut pairs: Vec<(String, ValueWord)> = Vec::new();
142 for entry in entries {
143 match entry {
144 ObjectEntry::Field {
145 key,
146 value,
147 type_annotation: _,
148 } => {
149 let val = self.eval_nb(value)?;
150 pairs.push((key.clone(), val));
151 }
152 ObjectEntry::Spread(_) => {
153 return Err(ShapeError::RuntimeError {
154 message: "Object spread (...) not allowed in const context"
155 .to_string(),
156 location: None,
157 });
158 }
159 }
160 }
161 let ref_pairs: Vec<(&str, ValueWord)> =
162 pairs.iter().map(|(k, v)| (k.as_str(), v.clone())).collect();
163 Ok(crate::type_schema::typed_object_from_nb_pairs(&ref_pairs))
164 }
165
166 Expr::Array(elements, _) => {
168 let mut arr = Vec::new();
169 for elem in elements {
170 arr.push(self.eval_nb(elem)?);
171 }
172 Ok(ValueWord::from_array(Arc::new(arr)))
173 }
174
175 Expr::Identifier(name, _span) => {
177 self.params
178 .get(name)
179 .cloned()
180 .ok_or_else(|| ShapeError::RuntimeError {
181 message: format!(
182 "Cannot reference variable '{}' in const context (metadata()). \
183 Only annotation parameters are allowed.",
184 name
185 ),
186 location: None,
187 })
188 }
189
190 Expr::BinaryOp {
192 left,
193 op,
194 right,
195 span: _,
196 } => {
197 let left_val = self.eval_nb(left)?;
198 let right_val = self.eval_nb(right)?;
199
200 use shape_ast::ast::BinaryOp;
201 match op {
202 BinaryOp::Add => self.const_add_nb(left_val, right_val),
204 BinaryOp::Sub => {
205 self.const_arith_nb(left_val, right_val, "subtraction", |a, b| a - b)
206 }
207 BinaryOp::Mul => {
208 self.const_arith_nb(left_val, right_val, "multiplication", |a, b| a * b)
209 }
210 BinaryOp::Div => {
211 let a = left_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
212 message: "Const division only works on numbers".to_string(),
213 location: None,
214 })?;
215 let b = right_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
216 message: "Const division only works on numbers".to_string(),
217 location: None,
218 })?;
219 if b == 0.0 {
220 Err(ShapeError::RuntimeError {
221 message: "Division by zero in const context".to_string(),
222 location: None,
223 })
224 } else {
225 Ok(ValueWord::from_f64(a / b))
226 }
227 }
228 BinaryOp::Mod => {
229 self.const_arith_nb(left_val, right_val, "modulo", |a, b| a % b)
230 }
231
232 BinaryOp::Equal => Ok(ValueWord::from_bool(left_val.vw_equals(&right_val))),
234 BinaryOp::NotEqual => Ok(ValueWord::from_bool(!left_val.vw_equals(&right_val))),
235 BinaryOp::Less => self.const_compare_nb(left_val, right_val, |a, b| a < b),
236 BinaryOp::LessEq => self.const_compare_nb(left_val, right_val, |a, b| a <= b),
237 BinaryOp::Greater => self.const_compare_nb(left_val, right_val, |a, b| a > b),
238 BinaryOp::GreaterEq => {
239 self.const_compare_nb(left_val, right_val, |a, b| a >= b)
240 }
241
242 BinaryOp::And => Ok(ValueWord::from_bool(
244 left_val.is_truthy() && right_val.is_truthy(),
245 )),
246 BinaryOp::Or => Ok(ValueWord::from_bool(
247 left_val.is_truthy() || right_val.is_truthy(),
248 )),
249
250 _ => Err(ShapeError::RuntimeError {
252 message: format!("Binary operator {:?} not allowed in const context", op),
253 location: None,
254 }),
255 }
256 }
257
258 Expr::UnaryOp {
260 op,
261 operand,
262 span: _,
263 } => {
264 let val = self.eval_nb(operand)?;
265 use shape_ast::ast::UnaryOp;
266 match op {
267 UnaryOp::Not => Ok(ValueWord::from_bool(!val.is_truthy())),
268 UnaryOp::Neg => {
269 if let Some(n) = val.as_f64() {
270 Ok(ValueWord::from_f64(-n))
271 } else {
272 Err(ShapeError::RuntimeError {
273 message: "Cannot negate non-number in const context".to_string(),
274 location: None,
275 })
276 }
277 }
278 UnaryOp::BitNot => Err(ShapeError::RuntimeError {
279 message: "Bitwise NOT not allowed in const context".to_string(),
280 location: None,
281 }),
282 }
283 }
284
285 Expr::FunctionCall { .. } => Err(ShapeError::RuntimeError {
287 message: "Function calls are not allowed in const context (metadata())".to_string(),
288 location: None,
289 }),
290
291 Expr::PropertyAccess { .. } => Err(ShapeError::RuntimeError {
292 message:
293 "Property access (obj.field) is not allowed in const context (metadata()). \
294 Cannot access runtime state like ctx.* or fn.*"
295 .to_string(),
296 location: None,
297 }),
298
299 _ => Err(ShapeError::RuntimeError {
300 message: format!(
301 "Expression type not allowed in const context (metadata()): {:?}",
302 expr
303 ),
304 location: None,
305 }),
306 }
307 }
308
309 fn const_add_nb(&self, left: ValueWord, right: ValueWord) -> Result<ValueWord> {
312 if let (Some(a), Some(b)) = (left.as_f64(), right.as_f64()) {
313 return Ok(ValueWord::from_f64(a + b));
314 }
315 if let (Some(a), Some(b)) = (left.as_str(), right.as_str()) {
316 return Ok(ValueWord::from_string(Arc::new(format!("{}{}", a, b))));
317 }
318 Err(ShapeError::RuntimeError {
319 message: "Const addition only works on numbers or strings".to_string(),
320 location: None,
321 })
322 }
323
324 fn const_arith_nb(
325 &self,
326 left: ValueWord,
327 right: ValueWord,
328 op_name: &str,
329 f: fn(f64, f64) -> f64,
330 ) -> Result<ValueWord> {
331 let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
332 message: format!("Const {} only works on numbers", op_name),
333 location: None,
334 })?;
335 let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
336 message: format!("Const {} only works on numbers", op_name),
337 location: None,
338 })?;
339 Ok(ValueWord::from_f64(f(a, b)))
340 }
341
342 fn const_compare_nb(
343 &self,
344 left: ValueWord,
345 right: ValueWord,
346 cmp: fn(f64, f64) -> bool,
347 ) -> Result<ValueWord> {
348 let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
349 message: "Const comparison only works on numbers".to_string(),
350 location: None,
351 })?;
352 let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
353 message: "Const comparison only works on numbers".to_string(),
354 location: None,
355 })?;
356 Ok(ValueWord::from_bool(cmp(a, b)))
357 }
358}
359
360impl Default for ConstEvaluator {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::Span;
    use std::sync::Arc;

    /// Number literal expression with a dummy span.
    fn number(n: f64) -> Expr {
        Expr::Literal(Literal::Number(n), Span::DUMMY)
    }

    /// Plain string literal expression with a dummy span.
    fn text(s: &str) -> Expr {
        Expr::Literal(Literal::String(s.to_string()), Span::DUMMY)
    }

    /// Object field entry without a type annotation.
    fn field(key: &str, value: Expr) -> ObjectEntry {
        ObjectEntry::Field {
            key: key.to_string(),
            value,
            type_annotation: None,
        }
    }

    /// Binary expression with a dummy span.
    fn binary(left: Expr, op: shape_ast::ast::BinaryOp, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
            span: Span::DUMMY,
        }
    }

    /// String value to compare evaluation results against.
    fn string_value(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    #[test]
    fn test_const_number_literal() {
        let evaluator = ConstEvaluator::new();
        let result = evaluator.eval(&number(42.0)).unwrap();
        assert_eq!(result, ValueWord::from_f64(42.0));
    }

    #[test]
    fn test_const_string_literal() {
        let evaluator = ConstEvaluator::new();
        let result = evaluator.eval(&text("hello")).unwrap();
        assert_eq!(result, string_value("hello"));
    }

    #[test]
    fn test_const_formatted_string_literal() {
        let evaluator = ConstEvaluator::new();
        let literal = Literal::FormattedString {
            value: "value: {x}".to_string(),
            mode: shape_ast::ast::InterpolationMode::Braces,
        };
        let expr = Expr::Literal(literal, Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        // The template text comes back untouched.
        assert_eq!(result, string_value("value: {x}"));
    }

    #[test]
    fn test_const_boolean_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Literal(Literal::Bool(true), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_object_literal() {
        let evaluator = ConstEvaluator::new();
        let entries = vec![field("key1", number(42.0)), field("key2", text("value"))];
        let expr = Expr::Object(entries, Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("key1").and_then(|v| v.as_f64()), Some(42.0));
        assert_eq!(obj.get("key2").and_then(|v| v.as_str()), Some("value"));
    }

    #[test]
    fn test_const_array_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Array(vec![number(1.0), number(2.0), number(3.0)], Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();

        let arr = result.as_any_array().expect("Expected array").to_generic();
        assert_eq!(arr.len(), 3);
        assert_eq!(arr[0].as_f64(), Some(1.0));
        assert_eq!(arr[1].as_f64(), Some(2.0));
        assert_eq!(arr[2].as_f64(), Some(3.0));
    }

    #[test]
    fn test_const_arithmetic_add() {
        let evaluator = ConstEvaluator::new();
        let expr = binary(number(2.0), shape_ast::ast::BinaryOp::Add, number(3.0));
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(5.0));
    }

    #[test]
    fn test_const_string_concat() {
        let evaluator = ConstEvaluator::new();
        let expr = binary(text("hello "), shape_ast::ast::BinaryOp::Add, text("world"));
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, string_value("hello world"));
    }

    #[test]
    fn test_const_annotation_param() {
        let mut evaluator = ConstEvaluator::new();
        evaluator.add_param("period".to_string(), ValueWord::from_f64(20.0));

        let expr = Expr::Identifier("period".to_string(), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(20.0));
    }

    #[test]
    fn test_const_nested_object() {
        let evaluator = ConstEvaluator::new();

        // { is_test: true, code_lens: [{ title: "Run", command: "run" }] }
        let lens = Expr::Object(
            vec![field("title", text("Run")), field("command", text("run"))],
            Span::DUMMY,
        );
        let expr = Expr::Object(
            vec![
                field("is_test", Expr::Literal(Literal::Bool(true), Span::DUMMY)),
                field("code_lens", Expr::Array(vec![lens], Span::DUMMY)),
            ],
            Span::DUMMY,
        );
        let result = evaluator.eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("is_test").and_then(|v| v.as_bool()), Some(true));
        let code_lens = obj.get("code_lens").and_then(|v| v.as_any_array());
        assert!(code_lens.is_some());
    }

    #[test]
    fn test_const_function_call_fails() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::FunctionCall {
            name: "foo".to_string(),
            args: vec![],
            named_args: vec![],
            span: Span::DUMMY,
        };
        let result = evaluator.eval(&expr);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("not allowed in const context"));
    }

    #[test]
    fn test_const_undefined_variable_fails() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Identifier("undefined_var".to_string(), Span::DUMMY);
        let result = evaluator.eval(&expr);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("annotation parameters"));
    }
}