somni_expr/
lib.rs

1//! # Somni expression evaluation Library
2//!
//! This crate implements the expression evaluation subset of the Somni language and VM. The crate
//! can be used on its own to evaluate simple expressions, or even to run complete Somni programs,
//! although more slowly than the Somni VM.
6//!
7//! ## Overview
8//!
//! Expressions are a subset of the Somni language.
10//!
11//! The expression language includes:
12//!
13//! - Literals: integers, floats, booleans, strings.
14//! - Variables
15//! - A basic set of operators
16//! - Function calls
17//!
18//! The expression language does not include:
19//!
20//! - Declaring new variables. You can assign to existing variables.
21//! - Control flow (if, loops, etc.)
22//! - Complex data structures (arrays, objects, etc.)
23//! - Defining functions and variables (these are provided by the context)
24//!
25//! ## Operators
26//!
//! The following binary operators are supported, listed from lowest to highest precedence:
28//!
29//! - `=`: assign a value to an existing variable
30//! - `||`: logical OR, short-circuiting
31//! - `&&`: logical AND, short-circuiting
32//! - `<`, `<=`, `>`, `>=`, `==`, `!=`: comparison operators
33//! - `|`: bitwise OR
34//! - `^`: bitwise XOR
35//! - `&`: bitwise AND
36//! - `<<`, `>>`: bitwise shift
37//! - `+`, `-`: addition and subtraction
38//! - `*`, `/`: multiplication and division
39//!
40//! Unary operators include:
41//! - `&`: taking the address of a variable
42//! - `*`: dereferencing an address to a variable
43//! - `!`: logical NOT
44//! - `-`: negation
45//!
46//! For the full specification of the grammar, see the [`parser`] module's documentation.
47//!
48//! ## Numeric types
49//!
50//! The Somni language supports three numeric types:
51//!
//! - Unsigned integers (`int`)
//! - Signed integers (`signed`)
//! - Floats (`float`)
55//!
56//! By default, the library uses the [`DefaultTypeSet`], which uses `u64`, `i64`, and `f64` for
57//! these types. You can use other type sets like [`TypeSet32`] or [`TypeSet128`] to use
58//! 32-bit or 128-bit integers and floats. You need to specify the type set when creating
59//! the context.
60//!
//! Integer literals can be used as either signed or unsigned integers; their type is inferred
//! from how they are used.
62//!
63//! ## Usage
64//!
65//! To evaluate an expression, you need to create a [`Context`] first. You can assign
66//! variables and define functions in this context, and then you can use this context
67//! to evaluate expressions.
68//!
69//! ```rust
70//! use somni_expr::Context;
71//!
72//! let mut context = Context::new();
73//!
//! // Define a variable and two functions
75//! context.add_variable::<u64>("x", 42);
76//! context.add_function("add_one", |x: u64| { x + 1 });
77//! context.add_function("floor", |x: f64| { x.floor() as u64 });
78//!
79//! // Evaluate an expression - we expect it to evaluate
80//! // to a number, which is u64 in the default type set.
81//! let result = context.evaluate::<u64>("add_one(x + floor(1.2))");
82//!
83//! assert_eq!(result, Ok(44));
84//! ```
85//!
86//! The context may also include a complete Somni program. The program may use the entirety
87//! of the Somni language, not just the expression language.
88//!
89//! ```rust
90//! use somni_expr::Context;
91//!
92//! let mut context = Context::parse("fn double(x: int) -> int { return x * 2; }").unwrap();
93//!
94//! // Evaluate an expression by calling the function defined by the program:
95//! let result = context.evaluate::<u64>("double(4)");
96//!
97//! assert_eq!(result, Ok(8));
98//! ```
99#![warn(missing_docs)]
100
// Expands `$body` once for every listed choice by defining a throw-away `inner` macro
// whose matcher is the given pattern, then invoking it with each choice in turn.
macro_rules! for_each {
    // Every choice is wrapped in parentheses; this lets one matcher bind several
    // metavariables at once, e.g. `(($name:ident, $signed:ty)) in [(A, i64), (B, i32)]`.
    // This arm must stay first: the single-type arm below would otherwise parse a
    // parenthesized choice as a tuple type.
    ($(($matcher:tt) in [$( ($($token:tt)*) ),*] => $body:tt;)*) => {
        $(
            macro_rules! inner { $matcher => $body; }

            $(
                inner!( $($token)* );
            )*
        )*
    };
    // Every choice is a single type, bound by a single-metavariable matcher such as `($t:ty)`.
    ($($matcher:tt in [$($item:ty),*] => $body:tt;)*) => {
        $(
            macro_rules! inner { $matcher => $body; }

            $(
                inner!($item);
            )*
        )*
    };
}
123
124pub mod error;
125pub mod function;
126pub mod value;
127mod visitor;
128
129pub use function::{DynFunction, FunctionCallError};
130pub use value::TypedValue;
131pub use visitor::ExpressionVisitor;
132
133use std::{
134    cell::RefCell,
135    collections::HashMap,
136    fmt::{Debug, Display},
137    rc::Rc,
138};
139
140use somni_parser::{
141    ast::{self, Expression, Function, Item, Program},
142    parser::{self, parse, TypeSet as ParserTypeSet},
143    Location,
144};
145
146use crate::{
147    error::MarkInSource,
148    function::ExprFn,
149    value::{LoadOwned, LoadStore, ValueType},
150};
151
152pub use somni_parser::parser::{DefaultTypeSet, TypeSet128, TypeSet32};
153
154/// Defines the backing types for Somni types.
155///
156/// The [`LoadStore`] and [`LoadOwned`] traits can be used to convert between Rust and Somni types.
157pub trait TypeSet: Sized + Default + Debug + 'static {
158    /// The typeset that will be used to parse source code.
159    type Parser: ParserTypeSet<Integer = Self::Integer, Float = Self::Float>;
160
161    /// The type of unsigned integers in this type set.
162    type Integer: Copy + ValueType<NegateOutput: LoadStore<Self>> + LoadStore<Self>;
163
164    /// The type of signed integers in this type set.
165    type SignedInteger: Copy + ValueType<NegateOutput: LoadStore<Self>> + LoadStore<Self>;
166
167    /// The type of floating point numbers in this type set.
168    type Float: Copy + ValueType<NegateOutput: LoadStore<Self>> + LoadStore<Self>;
169
170    /// The type of a string in this type set.
171    type String: ValueType<NegateOutput: LoadStore<Self>> + LoadStore<Self>;
172
173    /// Converts an unsigned integer into a signed integer.
174    fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError>;
175
176    /// Converts an unsigned integer into a Rust usize.
177    fn to_usize(v: Self::Integer) -> Result<usize, OperatorError>;
178
179    /// Converts the given Rust usize to an integer.
180    fn int_from_usize(v: usize) -> Self::Integer;
181
182    /// Loads a string.
183    fn load_string<'s>(&'s self, str: &'s Self::String) -> &'s str;
184
185    /// Stores a string.
186    fn store_string(&mut self, str: &str) -> Self::String;
187}
188
189for_each! {
190    (($name:ident, $signed:ty)) in [(DefaultTypeSet, i64), (TypeSet32, i32), (TypeSet128, i128)] => {
191        impl TypeSet for $name {
192            type Parser = Self;
193
194            type Integer = <Self::Parser as ParserTypeSet>::Integer;
195            type SignedInteger = $signed;
196            type Float = <Self::Parser as ParserTypeSet>::Float;
197            type String = Box<str>;
198
199            fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError> {
200                <$signed>::try_from(v).map_err(|_| OperatorError::RuntimeError)
201            }
202
203            fn to_usize(v: Self::Integer) -> Result<usize, OperatorError> {
204                usize::try_from(v).map_err(|_| OperatorError::RuntimeError)
205            }
206
207            fn int_from_usize(v: usize) -> Self::Integer {
208                Self::Integer::try_from(v).unwrap()
209            }
210
211            fn load_string<'s>(&'s self, str: &'s Self::String) -> &'s str {
212                str
213            }
214
215            fn store_string(&mut self, str: &str) -> Self::String {
216                str.to_string().into_boxed_str()
217            }
218        }
219    };
220}
221
/// The ways in which applying an operator to its operands can fail.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum OperatorError {
    /// The operator does not accept the given operand types (for example `true + 1`).
    TypeError,
    /// The operand types were accepted, but the operation itself failed (for example a value
    /// did not fit the integer type it had to be converted to).
    RuntimeError,
}
230
231impl Display for OperatorError {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        let message = match self {
234            OperatorError::TypeError => "Type error",
235            OperatorError::RuntimeError => "Runtime error",
236        };
237
238        f.write_str(message)
239    }
240}
241
// Generates `TypedValue::$method(ctx, lhs, rhs)`, which applies the binary operator
// `ValueType::$method` to two values of compatible types and stores the result back into a
// `TypedValue`. Integer literals of undecided signedness (`MaybeSignedInt`) adapt to the other
// operand; any other type combination is a `TypeError`.
macro_rules! dispatch_binary {
    ($method:ident) => {
        pub(crate) fn $method(ctx: &mut T, lhs: Self, rhs: Self) -> Result<Self, OperatorError> {
            let stored = match (lhs, rhs) {
                (Self::Bool(a), Self::Bool(b)) => ValueType::$method(a, b)?.store(ctx),
                (Self::Float(a), Self::Float(b)) => ValueType::$method(a, b)?.store(ctx),
                (Self::String(a), Self::String(b)) => ValueType::$method(a, b)?.store(ctx),
                (Self::SignedInt(a), Self::SignedInt(b)) => ValueType::$method(a, b)?.store(ctx),

                // Unsigned operands, where an undecided literal behaves as unsigned.
                (Self::Int(a), Self::Int(b))
                | (Self::Int(a), Self::MaybeSignedInt(b))
                | (Self::MaybeSignedInt(a), Self::Int(b)) => ValueType::$method(a, b)?.store(ctx),

                // Two undecided literals: an integer result stays undecided.
                (Self::MaybeSignedInt(a), Self::MaybeSignedInt(b)) => {
                    match ValueType::$method(a, b)?.store(ctx) {
                        Self::Int(raw) => Self::MaybeSignedInt(raw),
                        decided => decided,
                    }
                }

                // An undecided literal next to a signed value is converted to signed first.
                (Self::SignedInt(a), Self::MaybeSignedInt(b)) => {
                    ValueType::$method(a, T::to_signed(b)?)?.store(ctx)
                }
                (Self::MaybeSignedInt(a), Self::SignedInt(b)) => {
                    ValueType::$method(T::to_signed(a)?, b)?.store(ctx)
                }

                _ => return Err(OperatorError::TypeError),
            };

            Ok(stored)
        }
    };
}
286
// Generates `TypedValue::$method(ctx, operand)`, which applies the unary operator
// `ValueType::$method` to the payload of `operand` and stores the result back into a
// `TypedValue`. Values without a payload are rejected with a `TypeError`.
macro_rules! dispatch_unary {
    ($method:ident) => {
        pub(crate) fn $method(ctx: &mut T, operand: Self) -> Result<Self, OperatorError> {
            let stored = match operand {
                Self::Bool(v) => ValueType::$method(v)?.store(ctx),
                Self::Int(v) | Self::MaybeSignedInt(v) => ValueType::$method(v)?.store(ctx),
                Self::SignedInt(v) => ValueType::$method(v)?.store(ctx),
                Self::Float(v) => ValueType::$method(v)?.store(ctx),
                Self::String(v) => ValueType::$method(v)?.store(ctx),
                _ => return Err(OperatorError::TypeError),
            };

            Ok(stored)
        }
    };
}
303
304impl<T> TypedValue<T>
305where
306    T: TypeSet,
307{
308    dispatch_binary!(equals);
309    dispatch_binary!(less_than);
310    dispatch_binary!(less_than_or_equal);
311    dispatch_binary!(not_equals);
312    dispatch_binary!(bitwise_or);
313    dispatch_binary!(bitwise_xor);
314    dispatch_binary!(bitwise_and);
315    dispatch_binary!(shift_left);
316    dispatch_binary!(shift_right);
317    dispatch_binary!(add);
318    dispatch_binary!(subtract);
319    dispatch_binary!(multiply);
320    dispatch_binary!(divide);
321    dispatch_binary!(modulo);
322    dispatch_unary!(not);
323    dispatch_unary!(negate);
324}
325
326/// An expression context that provides the necessary environment for evaluating expressions.
327pub trait ExprContext<T = DefaultTypeSet>
328where
329    T: TypeSet,
330{
331    /// Returns a reference to the `TypeSet`.
332    fn type_context(&mut self) -> &mut T;
333
334    /// Attempts to load a variable from the context.
335    fn try_load_variable(&mut self, variable: &str) -> Option<TypedValue<T>>;
336
337    /// Declares a variable in the context.
338    fn declare(&mut self, variable: &str, value: TypedValue<T>);
339
340    /// Assigns a new value to a variable in the context.
341    fn assign_variable(&mut self, variable: &str, value: &TypedValue<T>) -> Result<(), Box<str>>;
342
343    /// Returns a value from the given address.
344    fn at_address(&mut self, address: TypedValue<T>) -> Result<TypedValue<T>, Box<str>>;
345
346    /// Assigns a new value to a variable in the context.
347    fn assign_address(
348        &mut self,
349        address: TypedValue<T>,
350        value: &TypedValue<T>,
351    ) -> Result<(), Box<str>>;
352
353    /// Returns the address of a variable in the context.
354    fn address_of(&mut self, variable: &str) -> TypedValue<T>;
355
356    /// Opens a new scope in the current stack frame.
357    fn open_scope(&mut self);
358
359    /// Closes the last scope in the current stack frame.
360    fn close_scope(&mut self);
361
362    /// Calls a function in the context.
363    fn call_function(
364        &mut self,
365        function_name: &str,
366        args: &[TypedValue<T>],
367    ) -> Result<TypedValue<T>, FunctionCallError>;
368}
369
370/// An error that occurs during evaluation of an expression.
371#[derive(Clone, Debug, PartialEq)]
372pub struct EvalError {
373    /// The error message.
374    pub message: Box<str>,
375    /// The location in the source code where the error occurred.
376    pub location: Location,
377}
378
379impl Display for EvalError {
380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        write!(f, "Evaluation error: {}", self.message)
382    }
383}
384
385/// An error that occurs during evaluation.
386///
387/// Printing this error will show the error message and the location in the source code.
388///
389/// ```rust
390/// use somni_expr::{Context, TypeSet32};
391/// let mut ctx = Context::<TypeSet32>::new_with_types();
392///
393/// let error = ctx.evaluate::<u32>("true + 1").unwrap_err();
394///
395/// println!("{error:?}");
396///
397/// // Output:
398/// //
399/// // Evaluation error
400/// // ---> at line 1 column 1
401/// //   |
402/// // 1 | true + 1
403/// //   | ^^^^^^^^ Failed to evaluate expression: Type error
404/// ```
405#[derive(Clone, PartialEq)]
406pub struct ExpressionError<'s> {
407    error: EvalError,
408    source: &'s str,
409}
410
411impl ExpressionError<'_> {
412    /// Returns the inner [`EvalError`].
413    pub fn into_inner(self) -> EvalError {
414        self.error
415    }
416}
417
418impl Debug for ExpressionError<'_> {
419    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
420        let marked = MarkInSource(
421            self.source,
422            self.error.location,
423            "Evaluation error",
424            &self.error.message,
425        );
426        marked.fmt(f)
427    }
428}
429
/// The static type of a Somni value.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Type {
    /// No value at all, e.g. the result of a function that returns nothing.
    Void,
    /// An integer whose signedness is not decided yet; it may become signed or unsigned.
    MaybeSignedInt,
    /// An unsigned integer (`int`).
    Int,
    /// A signed integer (`signed`).
    SignedInt,
    /// A floating point number (`float`).
    Float,
    /// A boolean (`bool`).
    Bool,
    /// A string (`string`).
    String,
}
448impl Type {
449    fn from_name(source: &str) -> Result<Self, Box<str>> {
450        match source {
451            "int" => Ok(Type::Int),
452            "signed" => Ok(Type::SignedInt),
453            "float" => Ok(Type::Float),
454            "bool" => Ok(Type::Bool),
455            "string" => Ok(Type::String),
456            other => Err(format!("Unknown type `{other}`").into_boxed_str()),
457        }
458    }
459}
460
461impl Display for Type {
462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463        match self {
464            Type::Void => write!(f, "void"),
465            Type::MaybeSignedInt => write!(f, "{{int/signed}}"),
466            Type::Int => write!(f, "int"),
467            Type::SignedInt => write!(f, "signed"),
468            Type::Bool => write!(f, "bool"),
469            Type::String => write!(f, "string"),
470            Type::Float => write!(f, "float"),
471        }
472    }
473}
474
/// Lazy-initialization status of a global variable declared by the loaded program.
enum InitializerState {
    /// Not evaluated yet; holds the index of the global's item in the program.
    Unevaluated(usize),
    /// The initializer has been started. Encountering this state during a lookup means the
    /// initializer (directly or indirectly) refers to its own global, i.e. a cycle.
    Evaluating,
}
482
483struct StackFrame<T: TypeSet> {
484    start_addr: usize,
485    variables: Vec<TypedValue<T>>,
486    scopes: Vec<HashMap<String, usize>>,
487}
488
489impl<T: TypeSet> StackFrame<T> {
490    fn new() -> StackFrame<T> {
491        StackFrame {
492            start_addr: 0,
493            variables: vec![],
494            scopes: vec![HashMap::new()],
495        }
496    }
497
498    fn next_call_frame(&self) -> StackFrame<T> {
499        StackFrame {
500            start_addr: self.start_addr + self.variables.len(),
501            variables: vec![],
502            scopes: vec![HashMap::new()],
503        }
504    }
505
506    fn declare(&mut self, variable: &str, value: TypedValue<T>) -> usize {
507        let index = self.variables.len();
508        self.variables.push(value);
509        self.scopes
510            .last_mut()
511            .unwrap()
512            .insert(variable.to_string(), index);
513        index + self.start_addr
514    }
515
516    fn lookup_index(&self, name: &str) -> Option<usize> {
517        for scope in self.scopes.iter().rev() {
518            if let Some(idx) = scope.get(name) {
519                return Some(*idx);
520            }
521        }
522        None
523    }
524
525    fn store(&mut self, variable: &str, value: &TypedValue<T>) -> bool {
526        if let Some(idx) = self.lookup_index(variable) {
527            self.variables.get_mut(idx).unwrap().clone_from(value);
528            true
529        } else {
530            false
531        }
532    }
533
534    fn lookup_by_address(&mut self, address: usize) -> Result<&mut TypedValue<T>, Box<str>> {
535        self.variables
536            .get_mut(address - self.start_addr)
537            .ok_or_else(|| format!("Invalid address {address}").into_boxed_str())
538    }
539
540    fn lookup_by_name<'s>(&'s mut self, variable: &str) -> Option<(usize, &'s mut TypedValue<T>)> {
541        let index = self.lookup_index(variable)?;
542        let address = index + self.start_addr;
543
544        Some((address, self.variables.get_mut(index).unwrap()))
545    }
546
547    fn open_scope(&mut self) {
548        self.scopes.push(HashMap::new());
549    }
550
551    fn close_scope(&mut self) {
552        self.scopes.pop().unwrap();
553    }
554}
555
556struct ProgramData<'ctx, T: TypeSet> {
557    source: &'ctx str,
558    program: Program<T::Parser>,
559    program_functions: HashMap<&'ctx str, usize>,
560    // User-registered functions
561    functions: RefCell<HashMap<&'ctx str, ExprFn<'ctx, T>>>,
562}
563
564impl<'ctx> Default for Context<'ctx, DefaultTypeSet> {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
570/// The expression context, which holds variables, functions, and other state needed for evaluation.
571pub struct Context<'ctx, T = DefaultTypeSet>
572where
573    T: TypeSet,
574{
575    program: Rc<ProgramData<'ctx, T>>,
576    // Program state
577    // ----
578    /// Variable stack. Element 0 is the global scope.
579    stack: Vec<StackFrame<T>>,
580    // unevaluated globals
581    initializers: HashMap<&'ctx str, InitializerState>,
582    type_context: T,
583}
584
585impl<'ctx> Context<'ctx, DefaultTypeSet> {
586    /// Creates a new context with [default types][DefaultTypeSet].
587    pub fn new() -> Self {
588        Self::new_with_types()
589    }
590
591    /// Loads the given program into a new context with [default types][DefaultTypeSet].
592    pub fn parse(source: &'ctx str) -> Result<Self, ExpressionError<'ctx>> {
593        Self::parse_with_types(source)
594    }
595}
596
/// Tag bit that marks an address as referring to a global variable (in `stack[0]`).
///
/// Global addresses are encoded as `index | GLOBAL_VARIABLE`. The encoded address must
/// survive the round trip through the type set's unsigned integer (`TypeSet::int_from_usize` /
/// `TypeSet::to_usize`). Using the top bit of a 64-bit `usize` made the address of any global
/// unrepresentable in a 32-bit integer type, so `int_from_usize` panicked for 32-bit type sets.
/// The tag therefore lives at bit 31, or at the top bit on targets with a narrower `usize`.
const GLOBAL_VARIABLE: usize = {
    let tag_bit = if usize::BITS < 32 { usize::BITS - 1 } else { 31 };
    1 << tag_bit
};
598
599impl<'ctx, T> Context<'ctx, T>
600where
601    T: TypeSet,
602{
603    /// Creates a new context. The type set must be specified when using this function.
604    ///
605    /// ```rust
606    /// use somni_expr::{Context, TypeSet32};
607    /// let mut ctx = Context::<TypeSet32>::new_with_types();
608    /// ```
609    pub fn new_with_types() -> Self {
610        Self::new_from_program("", Program { items: vec![] })
611    }
612
613    /// Parses the given program into a new context. The type set must be specified when using this function.
614    ///
615    /// ```rust
616    /// use somni_expr::{Context, TypeSet32};
617    /// let mut ctx = Context::<TypeSet32>::parse_with_types("// program source comes here").unwrap();
618    /// ```
619    pub fn parse_with_types(source: &'ctx str) -> Result<Self, ExpressionError<'ctx>> {
620        let program = parse::<T::Parser>(source).map_err(|e| ExpressionError {
621            error: EvalError {
622                message: format!("Failed to parse program: {e}").into_boxed_str(),
623                location: e.location,
624            },
625            source,
626        })?;
627
628        Ok(Self::new_from_program(source, program))
629    }
630
631    /// Loads the given program into a new context.
632    pub fn new_from_program(source: &'ctx str, program: Program<T::Parser>) -> Self {
633        let mut program_functions = HashMap::new();
634        let mut initializers = HashMap::new();
635        // Extract data for O(1) function/initializer lookup
636        for (idx, item) in program.items.iter().enumerate() {
637            match item {
638                ast::Item::Function(function) => {
639                    program_functions.insert(function.name.source(source), idx);
640                }
641                ast::Item::GlobalVariable(global_variable) => {
642                    initializers.insert(
643                        global_variable.identifier.source(source),
644                        InitializerState::Unevaluated(idx),
645                    );
646                }
647                ast::Item::ExternFunction(_) => {}
648            }
649        }
650        Self {
651            program: Rc::new(ProgramData {
652                source,
653                program,
654                program_functions,
655                functions: RefCell::new(HashMap::new()),
656            }),
657            stack: vec![StackFrame::new()],
658            type_context: T::default(),
659            initializers,
660        }
661    }
662
663    fn evaluate_any_function_impl(
664        &mut self,
665        function_name: &Function<T::Parser>,
666        args: &[TypedValue<T>],
667    ) -> Result<TypedValue<T>, EvalError> {
668        let source = self.program.clone().source;
669
670        let stack_frame = self
671            .stack
672            .last()
673            .expect("The global scope must always be present")
674            .next_call_frame();
675        self.stack.push(stack_frame);
676
677        let mut visitor = ExpressionVisitor::<Self, T> {
678            context: self,
679            source,
680            _marker: std::marker::PhantomData,
681        };
682
683        let result = visitor.visit_function(function_name, args);
684
685        self.stack.pop();
686
687        result
688    }
689
690    /// Parses and evaluates an expression and returns the result as a specific value type.
691    ///
692    /// This function will attempt to convert the result of the expression to the specified type `V`.
693    /// If the conversion fails, it will return an `ExpressionError`.
694    ///
695    /// ```rust
696    /// use somni_expr::{Context, TypedValue};
697    ///
698    /// let mut context = Context::new();
699    ///
700    /// assert_eq!(context.evaluate::<u64>("1 + 2"), Ok(3));
701    /// assert_eq!(context.evaluate::<TypedValue>("1 + 2"), Ok(TypedValue::Int(3)));
702    /// ```
703    pub fn evaluate<'s, V>(&'s mut self, source: &'s str) -> Result<V::Output, ExpressionError<'s>>
704    where
705        V: LoadOwned<T>,
706    {
707        let expression =
708            parser::parse_expression::<T::Parser>(source).map_err(|e| ExpressionError {
709                error: EvalError {
710                    message: format!("Parser error: {e}").into_boxed_str(),
711                    location: e.location,
712                },
713                source,
714            })?;
715
716        self.evaluate_parsed::<V>(source, &expression)
717    }
718
719    /// Evaluates a pre-parsed expression and returns the result as a specific value type.
720    ///
721    /// This function will attempt to convert the result of the expression to the specified type `V`.
722    /// If the conversion fails, it will return an `ExpressionError`.
723    ///
724    /// ```rust
725    /// use somni_expr::{Context, TypedValue};
726    ///
727    /// let mut context = Context::new();
728    ///
729    /// let source = "1 + 2";
730    /// let expr = somni_parser::parser::parse_expression(source).unwrap();
731    ///
732    /// assert_eq!(context.evaluate_parsed::<u64>(source, &expr), Ok(3));
733    /// assert_eq!(context.evaluate_parsed::<TypedValue>(source, &expr), Ok(TypedValue::Int(3)));
734    /// ```
735    pub fn evaluate_parsed<'s, V>(
736        &'s mut self,
737        source: &'s str,
738        expression: &Expression<T::Parser>,
739    ) -> Result<V::Output, ExpressionError<'s>>
740    where
741        V: LoadOwned<T>,
742    {
743        self.evaluate_impl::<V>(source, expression)
744            .map_err(|error| ExpressionError { error, source })
745    }
746
747    fn evaluate_impl<V>(
748        &mut self,
749        source: &str,
750        expression: &Expression<T::Parser>,
751    ) -> Result<V::Output, EvalError>
752    where
753        V: LoadOwned<T>,
754    {
755        let mut visitor = ExpressionVisitor::<Self, T> {
756            context: self,
757            source,
758            _marker: std::marker::PhantomData,
759        };
760        let result = visitor.visit_expression(expression)?;
761        let result_ty = result.type_of();
762        V::load_owned(self.type_context(), &result).ok_or_else(|| EvalError {
763            message: format!(
764                "Expression evaluates to {result_ty}, which cannot be converted to {}",
765                std::any::type_name::<V>()
766            )
767            .into_boxed_str(),
768            location: expression.location(),
769        })
770    }
771
772    /// Defines a new variable in the context.
773    ///
774    /// The variable can be any type from the current [`TypeSet`], even [`TypedValue`].
775    ///
776    /// The variable will act as a global variable in the context of the program. Its
777    /// value can be changed by expressions.
778    ///
779    /// ```rust
780    /// use somni_expr::{Context, TypedValue};
781    ///
782    /// let mut context = Context::new();
783    ///
784    /// // Variable does not exist, it can't be assigned:
785    /// assert!(context.evaluate::<()>("counter = 0").is_err());
786    ///
787    /// context.add_variable::<u64>("counter", 0);
788    ///
789    /// // Variable exists now, so we can use it:
790    /// assert_eq!(context.evaluate::<()>("counter = counter + 1"), Ok(()));
791    /// assert_eq!(context.evaluate::<u64>("counter"), Ok(1));
792    /// ```
793    pub fn add_variable<V>(&mut self, name: &'ctx str, value: V)
794    where
795        V: LoadStore<T>,
796    {
797        let stored = value.store(self.type_context());
798        self.stack[0].declare(name, stored);
799    }
800
801    /// Adds a new function to the context.
802    ///
803    /// ```rust
804    /// use somni_expr::{Context, TypedValue};
805    ///
806    /// let mut context = Context::new();
807    ///
808    /// context.add_function("plus_one", |x: u64| x + 1);
809    ///
810    /// assert_eq!(context.evaluate::<u64>("plus_one(2)"), Ok(3));
811    /// ```
812    pub fn add_function<F, A>(&mut self, name: &'ctx str, func: F)
813    where
814        F: DynFunction<A, T> + 'ctx,
815    {
816        self.program
817            .functions
818            .borrow_mut()
819            .insert(name, ExprFn::new(func));
820    }
821
822    fn lookup(&mut self, variable: &str) -> Option<(usize, TypedValue<T>)> {
823        if self.stack.len() > 1 {
824            let frame = self.stack.last_mut().unwrap();
825            if let Some((index, var)) = frame.lookup_by_name(variable) {
826                // Already evaluated / user provided
827                return Some((index, var.clone()));
828            }
829        }
830
831        {
832            let global_frame = &mut self.stack[0];
833            if let Some((index, var)) = global_frame.lookup_by_name(variable) {
834                // Already evaluated / user provided
835                return Some((index | GLOBAL_VARIABLE, var.clone()));
836            }
837        }
838
839        // Mark as "initializing" to detect potential cycles
840        let state = self.initializers.get_mut(variable)?;
841        let InitializerState::Unevaluated(idx) =
842            std::mem::replace(state, InitializerState::Evaluating)
843        else {
844            return None;
845        };
846
847        // Get a reference to the initializer
848        let program = self.program.clone();
849        let Some(Item::GlobalVariable(global)) = program.program.items.get(idx) else {
850            return None;
851        };
852
853        let value = self
854            .evaluate_parsed::<TypedValue<T>>(self.program.source, &global.initializer)
855            .ok()?;
856
857        let global_frame = &mut self.stack[0];
858        let index = global_frame.declare(variable, value.clone());
859
860        Some((index | GLOBAL_VARIABLE, value))
861    }
862
863    fn lookup_address(&mut self, address: TypedValue<T>) -> Result<&mut TypedValue<T>, Box<str>> {
864        let TypedValue::Int(address) = address else {
865            return Err(format!("Expected address, got {address:?}").into_boxed_str());
866        };
867
868        let address = T::to_usize(address)
869            .map_err(|_| format!("Invalid address: {address:?}").into_boxed_str())?;
870
871        if address & GLOBAL_VARIABLE != 0 {
872            return self.stack[0].lookup_by_address(address & !GLOBAL_VARIABLE);
873        }
874
875        for frame in self.stack.iter_mut().rev() {
876            if frame.start_addr <= address {
877                return frame.lookup_by_address(address);
878            }
879        }
880
881        Err(format!("Not a valid memory address: {address}").into_boxed_str())
882    }
883}
884
885impl<T> ExprContext<T> for Context<'_, T>
886where
887    T: TypeSet,
888{
889    fn type_context(&mut self) -> &mut T {
890        &mut self.type_context
891    }
892
893    // TODO: return Result
894    fn try_load_variable(&mut self, variable: &str) -> Option<TypedValue<T>> {
895        self.lookup(variable).map(|(_idx, var)| var)
896    }
897
898    fn address_of(&mut self, variable: &str) -> TypedValue<T> {
899        let address = self
900            .lookup(variable)
901            .map(|(address, _var)| address)
902            .unwrap();
903        TypedValue::Int(T::int_from_usize(address))
904    }
905
906    /// Declares a variable in the context.
907    fn declare(&mut self, variable: &str, value: TypedValue<T>) {
908        self.stack.last_mut().unwrap().declare(variable, value);
909    }
910
911    /// Assigns a new value to a variable in the context.
912    fn assign_variable(&mut self, variable: &str, value: &TypedValue<T>) -> Result<(), Box<str>> {
913        if self.stack.last_mut().unwrap().store(variable, value) {
914            return Ok(());
915        }
916        if self.stack[0].store(variable, value) {
917            return Ok(());
918        }
919
920        Err(format!("Variable not found: {variable}").into_boxed_str())
921    }
922
923    fn at_address(&mut self, address: TypedValue<T>) -> Result<TypedValue<T>, Box<str>> {
924        self.lookup_address(address).cloned()
925    }
926
927    fn assign_address(
928        &mut self,
929        address: TypedValue<T>,
930        value: &TypedValue<T>,
931    ) -> Result<(), Box<str>> {
932        let v = self.lookup_address(address)?;
933        v.clone_from(value);
934        Ok(())
935    }
936
937    fn call_function(
938        &mut self,
939        function_name: &str,
940        args: &[TypedValue<T>],
941    ) -> Result<TypedValue<T>, FunctionCallError> {
942        let program = self.program.clone();
943        let Some(fn_item) = self.program.program_functions.get(function_name) else {
944            // Call out to a Rust function
945            return match program.functions.borrow().get(function_name) {
946                Some(func) => func.call(self.type_context(), args),
947                None => Err(FunctionCallError::FunctionNotFound),
948            };
949        };
950
951        // Call a Somni function
952        let Some(ast::Item::Function(function)) = program.program.items.get(*fn_item) else {
953            return Err(FunctionCallError::FunctionNotFound);
954        };
955        self.evaluate_any_function_impl(function, args)
956            .map_err(|err| {
957                FunctionCallError::Other(
958                    format!(
959                        "{:?}",
960                        ExpressionError {
961                            source: self.program.source,
962                            error: err,
963                        }
964                    )
965                    .into_boxed_str(),
966                )
967            })
968    }
969
970    /// Opens a new scope in the current stack frame.
971    fn open_scope(&mut self) {
972        // TODO: error handling
973        self.stack.last_mut().unwrap().open_scope();
974    }
975
976    /// Closes the last scope in the current stack frame.
977    fn close_scope(&mut self) {
978        // TODO: error handling
979        self.stack.last_mut().unwrap().close_scope();
980    }
981}
982
/// Expands a single macro rule once for every tuple arity from 0 to 10.
///
/// The input has the form `(matcher) => { expansion };`. It becomes the only rule of a helper
/// macro named `inner`, which is then invoked with the identifier lists ``, `V1`, `V1, V2`,
/// …, up to `V1, …, V10`.
#[macro_export]
#[doc(hidden)]
macro_rules! for_all_tuples {
    ($matcher:tt => $expansion:tt;) => {
        macro_rules! inner {
            $matcher => $expansion;
        }

        inner!();
        inner!(V1);
        inner!(V1, V2);
        inner!(V1, V2, V3);
        inner!(V1, V2, V3, V4);
        inner!(V1, V2, V3, V4, V5);
        inner!(V1, V2, V3, V4, V5, V6);
        inner!(V1, V2, V3, V4, V5, V6, V7);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8, V9);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8, V9, V10);
    };
}
1002
#[cfg(test)]
mod test {
    use std::path::Path;

    use super::*;

    /// Returns the plain text of `s` with all ANSI escape sequences removed.
    ///
    /// Error reports are compared as plain text so that styling does not affect assertions.
    fn strip_ansi(s: impl AsRef<str>) -> String {
        use ansi_parser::AnsiParser;
        fn text_block(output: ansi_parser::Output<'_>) -> Option<&str> {
            match output {
                ansi_parser::Output::TextBlock(text) => Some(text),
                _ => None,
            }
        }

        s.as_ref()
            .ansi_parse()
            .filter_map(text_block)
            .collect::<String>()
    }

    #[test]
    fn test_evaluating_exprs() {
        let mut ctx = Context::new();

        ctx.add_variable::<i64>("signed", 30);
        ctx.add_variable::<u64>("value", 30);
        ctx.add_function("func", |v: u64| 2 * v);
        ctx.add_function("func2", |v1: u64, v2: u64| v1 + v2);
        ctx.add_function("five", || "five");
        ctx.add_function("is_five", |num: &str| num == "five");
        ctx.add_function("concatenate", |a: &str, b: &str| format!("{a}{b}"));

        assert_eq!(ctx.evaluate::<bool>("value / 5 == 6"), Ok(true));
        assert_eq!(ctx.evaluate::<bool>("five() == \"five\""), Ok(true));
        assert_eq!(
            ctx.evaluate::<bool>("is_five(five()) != is_five(\"six\")"),
            Ok(true)
        );
        assert_eq!(ctx.evaluate::<u64>("func(20) / 5"), Ok(8));
        assert_eq!(
            ctx.evaluate::<TypedValue>("func(20) / 5"),
            Ok(TypedValue::Int(8))
        );
        assert_eq!(ctx.evaluate::<u64>("func2(20, 20) / 5"), Ok(8));
        assert_eq!(ctx.evaluate::<bool>("true & false"), Ok(false));
        assert_eq!(ctx.evaluate::<bool>("!true"), Ok(false));
        assert_eq!(ctx.evaluate::<bool>("false | false"), Ok(false));
        assert_eq!(ctx.evaluate::<bool>("true ^ true"), Ok(false));
        assert_eq!(ctx.evaluate::<u64>("!0x1111"), Ok(0xFFFF_FFFF_FFFF_EEEE));
        assert_eq!(
            ctx.evaluate::<String>("concatenate(five(), \"six\")"),
            Ok(String::from("fivesix"))
        );
        assert_eq!(ctx.evaluate::<bool>("signed * 2 == 60"), Ok(true));
        assert_eq!(ctx.evaluate::<i64>("*&signed"), Ok(30));
    }

    #[test]
    fn test_context_is_mutable() {
        let mut ctx = Context::new();

        ctx.add_variable::<u64>("value", 30);

        ctx.evaluate::<()>("value = 5").unwrap();
        assert_eq!(ctx.evaluate::<bool>("value == 5"), Ok(true));
    }

    #[test]
    fn test_evaluating_exprs_with_u32() {
        let mut ctx = Context::<TypeSet32>::new_with_types();

        ctx.add_variable::<u32>("value", 30);
        ctx.add_function("func", |v: u32| 2 * v);
        ctx.add_function("func2", |v1: u32, v2: u32| v1 + v2);

        assert_eq!(ctx.evaluate::<bool>("value / 5 == 6"), Ok(true));
        assert_eq!(ctx.evaluate::<u32>("func(20) / 5"), Ok(8));
        assert_eq!(ctx.evaluate::<u32>("func2(20, 20) / 5"), Ok(8));
    }

    #[test]
    fn test_evaluating_exprs_with_u128() {
        let mut ctx = Context::<TypeSet128>::new_with_types();

        ctx.add_variable::<u128>("value", 30);
        ctx.add_function("func", |v: u128| 2 * v);
        ctx.add_function("func2", |v1: u128, v2: u128| v1 + v2);

        assert_eq!(ctx.evaluate::<bool>("value / 5 == 6"), Ok(true));
        assert_eq!(ctx.evaluate::<u128>("func(20) / 5"), Ok(8));
        assert_eq!(ctx.evaluate::<u128>("func2(20, 20) / 5"), Ok(8));
    }

    #[test]
    fn test_evaluate_function() {
        let mut ctx =
            Context::parse("fn multiply_with_global(a: int) -> int { return a * global; }")
                .unwrap();

        ctx.add_variable::<u64>("global", 3);

        assert_eq!(
            ctx.evaluate::<bool>("multiply_with_global(2) == 6"),
            Ok(true)
        );
        assert!(ctx
            .evaluate::<bool>("multiply_with_global(\"2\") == 6")
            .is_err());
    }

    /// Runs every `.sm` file under `../tests/eval` as an expression test.
    ///
    /// - Each `//@ <expr>` line is evaluated against a freshly parsed context and must
    ///   evaluate to `true`.
    /// - `//@+ <expr>` lines reuse the context left behind by the previous expression.
    /// - If `<test_name>/stderr` or `<test_name>/stderr_expr` exists next to the test file,
    ///   evaluation is expected to fail. The error is then compared with `stderr_expr`, or
    ///   written to it when `BLESS=1`.
    /// - Set `TEST_FILTER` to the path of a single test file to run only that file.
    #[test]
    fn run_eval_tests() {
        /// Decides whether `walk` visits `path`.
        fn filter(path: &Path) -> bool {
            // Always descend into directories. Otherwise a `TEST_FILTER` naming a file inside a
            // subdirectory could never be reached, because the directory itself does not match.
            if path.is_dir() {
                return true;
            }

            match std::env::var("TEST_FILTER") {
                // Only run the single test file named by the filter.
                Ok(env) => Path::new(&env) == path,
                // No filter set: run all somni source files.
                Err(_) => path.extension().is_some_and(|ext| ext == "sm"),
            }
        }

        /// Recursively visits `dir`, calling `on_file` for every file accepted by `filter`.
        fn walk(dir: &Path, on_file: &impl Fn(&Path)) {
            for entry in std::fs::read_dir(dir)
                .unwrap_or_else(|_| panic!("Folder not found: {}", dir.display()))
                .flatten()
            {
                let path = entry.path();

                if !filter(&path) {
                    continue;
                }

                if path.is_file() {
                    on_file(&path);
                } else {
                    walk(&path, on_file);
                }
            }
        }

        /// Evaluates every `//@` expression found in the test file at `path`.
        fn run_eval_test(path: &Path) {
            /// Parses `source` and registers the Rust helper functions used by the eval tests.
            fn parse(source: &str) -> Context<'_> {
                let mut context = Context::parse(source).unwrap();

                context.add_function("add_from_rust", |a: u64, b: u64| -> i64 { (a + b) as i64 });
                context.add_function("assert", |a: bool| a); // No-op to test calling Rust functions from expressions
                context.add_function("reverse", |s: &str| s.chars().rev().collect::<String>());

                context
            }

            let test_name = path.file_stem().unwrap();
            let parent = path.parent().unwrap().canonicalize().unwrap();
            let vm_error = parent.join(test_name).join("stderr");
            let expr_error = parent.join(test_name).join("stderr_expr");
            let source = std::fs::read_to_string(path).unwrap();

            let expressions = source
                .lines()
                .filter_map(|line| line.trim().strip_prefix("//@"))
                .collect::<Vec<_>>();

            let mut context = parse(&source);
            let fail_expected = std::fs::exists(&expr_error).unwrap_or(false)
                || std::fs::exists(&vm_error).unwrap_or(false);

            let blessed = std::env::var("BLESS").as_deref() == Ok("1");

            for expression in &expressions {
                // Only the `+` form is trimmed; plain expressions are passed through unchanged
                // because error snapshots record column positions.
                let expression = if let Some(e) = expression.strip_prefix('+') {
                    // `//@+` preserves VM state (like changes to globals)
                    e.trim()
                } else {
                    // `//@` resets VM state (like changes to globals)
                    context = parse(&source);
                    expression
                };
                println!("Running `{expression}`");
                match context.evaluate::<TypedValue>(expression) {
                    Ok(_) if fail_expected => {
                        panic!(
                            "Expected {} to fail evaluating, but it succeeded",
                            path.display()
                        )
                    }
                    Ok(value) => assert_eq!(
                        value,
                        TypedValue::Bool(true),
                        "{}: Expression `{expression}` evaluated to {value:?}",
                        path.display()
                    ),
                    Err(e) if fail_expected => {
                        let error = strip_ansi(format!("{e:?}"));
                        if blessed {
                            std::fs::write(&expr_error, error).unwrap();
                        } else {
                            let expected_error = std::fs::read_to_string(&expr_error).unwrap();
                            pretty_assertions::assert_eq!(strip_ansi(expected_error), error);
                        }
                    }
                    Err(e) => panic!("{}: {e:?}", path.display()),
                };
            }
        }

        walk("../tests/eval".as_ref(), &|path| {
            run_eval_test(path);
        });
    }

    #[test]
    fn test_eval_error() {
        let mut ctx = Context::new();

        ctx.add_function("func", |v1: u64, v2: u64| v1 + v2);

        let err = ctx
            .evaluate::<u64>("func(20, true)")
            .expect_err("Expected expression to return an error");

        pretty_assertions::assert_eq!(
            strip_ansi(format!("\n{err:?}")),
            r#"
Evaluation error
 ---> at line 1 column 10
  |
1 | func(20, true)
  |          ^^^^ func expects argument 1 to be u64, got bool"#,
        );
    }
}