Skip to main content

shape_runtime/
const_eval.rs

1//! Const evaluation for annotation metadata() handlers
2//!
3//! This module provides compile-time evaluation of Shape expressions.
4//! Only a subset of expressions are allowed (literals, object/array construction,
5//! annotation parameters, and const arithmetic).
6//!
7//! ## Purpose
8//!
9//! Const evaluation enables:
10//! - **LSP to extract metadata without runtime execution** - Show code lenses, hover info
11//! - **Compiler optimizations** based on static metadata (pure functions, cacheable results)
12//! - **Static analysis and documentation generation**
13//!
14//! ## How It Works
15//!
16//! When the LSP encounters an `annotation ... { ... }` definition, it:
17//!
18//! 1. **Parses the annotation definition** (from current file or imports)
19//! 2. **Finds the `metadata()` handler** in the handlers list
20//! 3. **Const-evaluates the handler body** using this module
21//! 4. **Extracts special properties** from the result:
22//!    - `code_lens: [...]` → Creates IDE action buttons
23//!    - `pure: true` → Marks for compiler optimization
24//!    - Custom properties → Stored for user's own tooling
25//!
26//! Example:
27//!
28//! ```shape
29//! annotation strategy() {
30//!     metadata() {
31//!         {
32//!             is_strategy: true,              // Custom metadata
33//!             code_lens: [                    // Special: IDE integration
34//!                 { title: "▶ Run", command: "shape.runBacktest" }
35//!             ]
36//!         }
37//!     }
38//! }
39//! ```
40//!
41//! When a function has `@strategy`, the LSP:
42//! 1. Looks up the `@strategy` annotation definition
43//! 2. Const-evaluates `metadata()` → `{ is_strategy: true, code_lens: [...] }`
44//! 3. Creates a "▶ Run" button above the function
45//!
46//! ## Allowed Constructs
47//!
48//! - Literals: `42`, `"hello"`, `true`, `null`
49//! - Objects: `{ key: value, ... }`
50//! - Arrays: `[1, 2, 3]`
51//! - Annotation parameters (captured in scope)
52//! - Const arithmetic and comparisons: `2 + 2`, `"a" + "b"`, `period > 10`, plus logical and unary operators on const operands
53//!
54//! ## Not Allowed
55//!
56//! - Function calls (runtime dependency)
57//! - Variable references (except annotation parameters)
58//! - `ctx` or `fn` access (runtime state)
59//! - Side effects
60//! - Non-const conditionals
61
62use shape_ast::ast::{Expr, Literal, ObjectEntry};
63use shape_ast::error::{Result, ShapeError};
64use shape_value::ValueWord;
65use std::collections::HashMap;
66use std::sync::Arc;
67
68/// Const evaluator for metadata() handlers
69#[derive(Debug, Clone)]
70pub struct ConstEvaluator {
71    /// Annotation parameters available during evaluation
72    /// Maps parameter name → const value
73    params: HashMap<String, ValueWord>,
74}
75
76impl ConstEvaluator {
77    /// Create a new const evaluator with annotation parameters
78    pub fn new() -> Self {
79        Self {
80            params: HashMap::new(),
81        }
82    }
83
84    /// Create a const evaluator with annotation parameters
85    pub fn with_params(params: HashMap<String, ValueWord>) -> Self {
86        Self {
87            params: params.into_iter().map(|(k, v)| (k, v)).collect(),
88        }
89    }
90
91    /// Add an annotation parameter to the scope
92    pub fn add_param(&mut self, name: String, value: ValueWord) {
93        self.params.insert(name, value);
94    }
95
96    /// Add an annotation parameter to the scope (ValueWord, avoids ValueWord conversion)
97    pub fn add_param_nb(&mut self, name: String, value: ValueWord) {
98        self.params.insert(name, value);
99    }
100
101    /// Evaluate an expression as a const (compile-time) value
102    ///
103    /// Returns an error if the expression uses non-const constructs.
104    pub fn eval(&self, expr: &Expr) -> Result<ValueWord> {
105        Ok(self.eval_nb(expr)?.clone())
106    }
107
108    /// Evaluate an expression as a const ValueWord value (avoids ValueWord materialization)
109    pub fn eval_as_nb(&self, expr: &Expr) -> Result<ValueWord> {
110        self.eval_nb(expr)
111    }
112
113    /// Evaluate an expression as a const ValueWord value
114    fn eval_nb(&self, expr: &Expr) -> Result<ValueWord> {
115        match expr {
116            // Literals are always const
117            Expr::Literal(lit, _) => match lit {
118                Literal::Int(i) => Ok(ValueWord::from_f64(*i as f64)),
119                Literal::UInt(u) => Ok(ValueWord::from_native_u64(*u)),
120                Literal::TypedInt(v, _) => Ok(ValueWord::from_i64(*v)),
121                Literal::Number(n) => Ok(ValueWord::from_f64(*n)),
122                Literal::Decimal(d) => {
123                    use rust_decimal::prelude::ToPrimitive;
124                    Ok(ValueWord::from_f64(d.to_f64().unwrap_or(0.0)))
125                }
126                Literal::String(s) => Ok(ValueWord::from_string(Arc::new(s.clone()))),
127                Literal::FormattedString { value, .. } => {
128                    Ok(ValueWord::from_string(Arc::new(value.clone())))
129                }
130                Literal::ContentString { value, .. } => {
131                    Ok(ValueWord::from_string(Arc::new(value.clone())))
132                }
133                Literal::Bool(b) => Ok(ValueWord::from_bool(*b)),
134                Literal::None => Ok(ValueWord::none()),
135                Literal::Unit => Ok(ValueWord::unit()),
136                Literal::Timeframe(tf) => Ok(ValueWord::from_timeframe(*tf)),
137            },
138
139            // Object literals - recursively evaluate all values
140            Expr::Object(entries, _) => {
141                let mut pairs: Vec<(String, ValueWord)> = Vec::new();
142                for entry in entries {
143                    match entry {
144                        ObjectEntry::Field {
145                            key,
146                            value,
147                            type_annotation: _,
148                        } => {
149                            let val = self.eval_nb(value)?;
150                            pairs.push((key.clone(), val));
151                        }
152                        ObjectEntry::Spread(_) => {
153                            return Err(ShapeError::RuntimeError {
154                                message: "Object spread (...) not allowed in const context"
155                                    .to_string(),
156                                location: None,
157                            });
158                        }
159                    }
160                }
161                let ref_pairs: Vec<(&str, ValueWord)> =
162                    pairs.iter().map(|(k, v)| (k.as_str(), v.clone())).collect();
163                Ok(crate::type_schema::typed_object_from_nb_pairs(&ref_pairs))
164            }
165
166            // Array literals - recursively evaluate all elements
167            Expr::Array(elements, _) => {
168                let mut arr = Vec::new();
169                for elem in elements {
170                    arr.push(self.eval_nb(elem)?);
171                }
172                Ok(ValueWord::from_array(Arc::new(arr)))
173            }
174
175            // Identifiers - only allowed if they're annotation parameters
176            Expr::Identifier(name, _span) => {
177                self.params
178                    .get(name)
179                    .cloned()
180                    .ok_or_else(|| ShapeError::RuntimeError {
181                        message: format!(
182                            "Cannot reference variable '{}' in const context (metadata()). \
183                             Only annotation parameters are allowed.",
184                            name
185                        ),
186                        location: None,
187                    })
188            }
189
190            // Binary operations - only const arithmetic/string concat
191            Expr::BinaryOp {
192                left,
193                op,
194                right,
195                span: _,
196            } => {
197                let left_val = self.eval_nb(left)?;
198                let right_val = self.eval_nb(right)?;
199
200                use shape_ast::ast::BinaryOp;
201                match op {
202                    // Arithmetic
203                    BinaryOp::Add => self.const_add_nb(left_val, right_val),
204                    BinaryOp::Sub => {
205                        self.const_arith_nb(left_val, right_val, "subtraction", |a, b| a - b)
206                    }
207                    BinaryOp::Mul => {
208                        self.const_arith_nb(left_val, right_val, "multiplication", |a, b| a * b)
209                    }
210                    BinaryOp::Div => {
211                        let a = left_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
212                            message: "Const division only works on numbers".to_string(),
213                            location: None,
214                        })?;
215                        let b = right_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
216                            message: "Const division only works on numbers".to_string(),
217                            location: None,
218                        })?;
219                        if b == 0.0 {
220                            Err(ShapeError::RuntimeError {
221                                message: "Division by zero in const context".to_string(),
222                                location: None,
223                            })
224                        } else {
225                            Ok(ValueWord::from_f64(a / b))
226                        }
227                    }
228                    BinaryOp::Mod => {
229                        self.const_arith_nb(left_val, right_val, "modulo", |a, b| a % b)
230                    }
231
232                    // Comparison
233                    BinaryOp::Equal => Ok(ValueWord::from_bool(left_val.vw_equals(&right_val))),
234                    BinaryOp::NotEqual => Ok(ValueWord::from_bool(!left_val.vw_equals(&right_val))),
235                    BinaryOp::Less => self.const_compare_nb(left_val, right_val, |a, b| a < b),
236                    BinaryOp::LessEq => self.const_compare_nb(left_val, right_val, |a, b| a <= b),
237                    BinaryOp::Greater => self.const_compare_nb(left_val, right_val, |a, b| a > b),
238                    BinaryOp::GreaterEq => {
239                        self.const_compare_nb(left_val, right_val, |a, b| a >= b)
240                    }
241
242                    // Logical
243                    BinaryOp::And => Ok(ValueWord::from_bool(
244                        left_val.is_truthy() && right_val.is_truthy(),
245                    )),
246                    BinaryOp::Or => Ok(ValueWord::from_bool(
247                        left_val.is_truthy() || right_val.is_truthy(),
248                    )),
249
250                    // Not allowed in const context
251                    _ => Err(ShapeError::RuntimeError {
252                        message: format!("Binary operator {:?} not allowed in const context", op),
253                        location: None,
254                    }),
255                }
256            }
257
258            // Unary operations
259            Expr::UnaryOp {
260                op,
261                operand,
262                span: _,
263            } => {
264                let val = self.eval_nb(operand)?;
265                use shape_ast::ast::UnaryOp;
266                match op {
267                    UnaryOp::Not => Ok(ValueWord::from_bool(!val.is_truthy())),
268                    UnaryOp::Neg => {
269                        if let Some(n) = val.as_f64() {
270                            Ok(ValueWord::from_f64(-n))
271                        } else {
272                            Err(ShapeError::RuntimeError {
273                                message: "Cannot negate non-number in const context".to_string(),
274                                location: None,
275                            })
276                        }
277                    }
278                    UnaryOp::BitNot => Err(ShapeError::RuntimeError {
279                        message: "Bitwise NOT not allowed in const context".to_string(),
280                        location: None,
281                    }),
282                }
283            }
284
285            // Everything else is not allowed in const context
286            Expr::FunctionCall { .. } => Err(ShapeError::RuntimeError {
287                message: "Function calls are not allowed in const context (metadata())".to_string(),
288                location: None,
289            }),
290
291            Expr::PropertyAccess { .. } => Err(ShapeError::RuntimeError {
292                message:
293                    "Property access (obj.field) is not allowed in const context (metadata()). \
294                         Cannot access runtime state like ctx.* or fn.*"
295                        .to_string(),
296                location: None,
297            }),
298
299            _ => Err(ShapeError::RuntimeError {
300                message: format!(
301                    "Expression type not allowed in const context (metadata()): {:?}",
302                    expr
303                ),
304                location: None,
305            }),
306        }
307    }
308
309    // Const arithmetic operations (ValueWord)
310
311    fn const_add_nb(&self, left: ValueWord, right: ValueWord) -> Result<ValueWord> {
312        if let (Some(a), Some(b)) = (left.as_f64(), right.as_f64()) {
313            return Ok(ValueWord::from_f64(a + b));
314        }
315        if let (Some(a), Some(b)) = (left.as_str(), right.as_str()) {
316            return Ok(ValueWord::from_string(Arc::new(format!("{}{}", a, b))));
317        }
318        Err(ShapeError::RuntimeError {
319            message: "Const addition only works on numbers or strings".to_string(),
320            location: None,
321        })
322    }
323
324    fn const_arith_nb(
325        &self,
326        left: ValueWord,
327        right: ValueWord,
328        op_name: &str,
329        f: fn(f64, f64) -> f64,
330    ) -> Result<ValueWord> {
331        let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
332            message: format!("Const {} only works on numbers", op_name),
333            location: None,
334        })?;
335        let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
336            message: format!("Const {} only works on numbers", op_name),
337            location: None,
338        })?;
339        Ok(ValueWord::from_f64(f(a, b)))
340    }
341
342    fn const_compare_nb(
343        &self,
344        left: ValueWord,
345        right: ValueWord,
346        cmp: fn(f64, f64) -> bool,
347    ) -> Result<ValueWord> {
348        let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
349            message: "Const comparison only works on numbers".to_string(),
350            location: None,
351        })?;
352        let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
353            message: "Const comparison only works on numbers".to_string(),
354            location: None,
355        })?;
356        Ok(ValueWord::from_bool(cmp(a, b)))
357    }
358}
359
360impl Default for ConstEvaluator {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::{BinaryOp, Span};
    use std::sync::Arc;

    /// Number literal expression.
    fn num(n: f64) -> Expr {
        Expr::Literal(Literal::Number(n), Span::DUMMY)
    }

    /// String literal expression.
    fn text(s: &str) -> Expr {
        Expr::Literal(Literal::String(s.to_string()), Span::DUMMY)
    }

    /// Binary operation expression with a dummy span.
    fn binary(left: Expr, op: BinaryOp, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
            span: Span::DUMMY,
        }
    }

    /// Plain `key: value` object entry without a type annotation.
    fn field(key: &str, value: Expr) -> ObjectEntry {
        ObjectEntry::Field {
            key: key.to_string(),
            value,
            type_annotation: None,
        }
    }

    /// Expected const value for a string.
    fn string_value(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    /// Evaluate `expr` with no annotation parameters and return the error message,
    /// panicking if evaluation unexpectedly succeeds.
    fn eval_err(expr: &Expr) -> String {
        ConstEvaluator::new()
            .eval(expr)
            .expect_err("expected const evaluation to fail")
            .to_string()
    }

    #[test]
    fn test_const_number_literal() {
        let result = ConstEvaluator::new().eval(&num(42.0)).unwrap();
        assert_eq!(result, ValueWord::from_f64(42.0));
    }

    #[test]
    fn test_const_string_literal() {
        let result = ConstEvaluator::new().eval(&text("hello")).unwrap();
        assert_eq!(result, string_value("hello"));
    }

    #[test]
    fn test_const_formatted_string_literal() {
        let expr = Expr::Literal(
            Literal::FormattedString {
                value: "value: {x}".to_string(),
                mode: shape_ast::ast::InterpolationMode::Braces,
            },
            Span::DUMMY,
        );
        let result = ConstEvaluator::new().eval(&expr).unwrap();
        // The template is kept verbatim; no interpolation happens in const context.
        assert_eq!(result, string_value("value: {x}"));
    }

    #[test]
    fn test_const_boolean_literal() {
        let expr = Expr::Literal(Literal::Bool(true), Span::DUMMY);
        let result = ConstEvaluator::new().eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_object_literal() {
        let expr = Expr::Object(
            vec![field("key1", num(42.0)), field("key2", text("value"))],
            Span::DUMMY,
        );
        let result = ConstEvaluator::new().eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("key1").and_then(|v| v.as_f64()), Some(42.0));
        assert_eq!(obj.get("key2").and_then(|v| v.as_str()), Some("value"));
    }

    #[test]
    fn test_const_array_literal() {
        let expr = Expr::Array(vec![num(1.0), num(2.0), num(3.0)], Span::DUMMY);
        let result = ConstEvaluator::new().eval(&expr).unwrap();

        let arr = result.as_any_array().expect("Expected array").to_generic();
        assert_eq!(arr.len(), 3);
        assert_eq!(arr[0].as_f64(), Some(1.0));
        assert_eq!(arr[1].as_f64(), Some(2.0));
        assert_eq!(arr[2].as_f64(), Some(3.0));
    }

    #[test]
    fn test_const_arithmetic_add() {
        let expr = binary(num(2.0), BinaryOp::Add, num(3.0));
        let result = ConstEvaluator::new().eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(5.0));
    }

    #[test]
    fn test_const_string_concat() {
        let expr = binary(text("hello "), BinaryOp::Add, text("world"));
        let result = ConstEvaluator::new().eval(&expr).unwrap();
        assert_eq!(result, string_value("hello world"));
    }

    #[test]
    fn test_const_division() {
        let expr = binary(num(10.0), BinaryOp::Div, num(4.0));
        let result = ConstEvaluator::new().eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(2.5));
    }

    #[test]
    fn test_const_division_by_zero_fails() {
        let expr = binary(num(1.0), BinaryOp::Div, num(0.0));
        assert!(eval_err(&expr).contains("Division by zero"));
    }

    #[test]
    fn test_const_numeric_comparison() {
        let evaluator = ConstEvaluator::new();
        let less = binary(num(1.0), BinaryOp::Less, num(2.0));
        let greater_eq = binary(num(1.0), BinaryOp::GreaterEq, num(2.0));
        assert_eq!(evaluator.eval(&less).unwrap(), ValueWord::from_bool(true));
        assert_eq!(
            evaluator.eval(&greater_eq).unwrap(),
            ValueWord::from_bool(false)
        );
    }

    #[test]
    fn test_const_comparison_on_strings_fails() {
        let expr = binary(text("a"), BinaryOp::Less, text("b"));
        assert!(eval_err(&expr).contains("comparison only works on numbers"));
    }

    #[test]
    fn test_const_annotation_param() {
        let mut evaluator = ConstEvaluator::new();
        evaluator.add_param("period".to_string(), ValueWord::from_f64(20.0));

        let expr = Expr::Identifier("period".to_string(), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(20.0));
    }

    #[test]
    fn test_const_annotation_param_arithmetic() {
        let mut evaluator = ConstEvaluator::new();
        evaluator.add_param("period".to_string(), ValueWord::from_f64(20.0));

        let expr = binary(
            Expr::Identifier("period".to_string(), Span::DUMMY),
            BinaryOp::Mul,
            num(2.0),
        );
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(40.0));
    }

    #[test]
    fn test_const_nested_object() {
        let expr = Expr::Object(
            vec![
                field("is_test", Expr::Literal(Literal::Bool(true), Span::DUMMY)),
                field(
                    "code_lens",
                    Expr::Array(
                        vec![Expr::Object(
                            vec![field("title", text("Run")), field("command", text("run"))],
                            Span::DUMMY,
                        )],
                        Span::DUMMY,
                    ),
                ),
            ],
            Span::DUMMY,
        );
        let result = ConstEvaluator::new().eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("is_test").and_then(|v| v.as_bool()), Some(true));
        assert!(obj
            .get("code_lens")
            .and_then(|v| v.as_any_array())
            .is_some());
    }

    #[test]
    fn test_const_function_call_fails() {
        let expr = Expr::FunctionCall {
            name: "foo".to_string(),
            args: vec![],
            named_args: vec![],
            span: Span::DUMMY,
        };
        assert!(eval_err(&expr).contains("not allowed in const context"));
    }

    #[test]
    fn test_const_undefined_variable_fails() {
        let expr = Expr::Identifier("undefined_var".to_string(), Span::DUMMY);
        assert!(eval_err(&expr).contains("annotation parameters"));
    }
}