somni_expr/
lib.rs

1//! # Somni expression evaluation Library
2//!
3//! This library provides tools for evaluating expressions.
4//!
5//! ## Overview
6//!
7//! The expression language includes:
8//!
9//! - Literals: integers, floats, booleans, strings.
10//! - Variables
11//! - A basic set of operators
12//! - Function calls
13//!
14//! The expression language does not include:
15//!
16//! - Control flow (if, loops, etc.)
17//! - Complex data structures (arrays, objects, etc.)
18//! - Defining functions and variables (these are provided by the context)
19//!
20//! ## Operators
21//!
//! The following binary operators are supported, from lowest to highest precedence:
23//!
24//! - `||`: logical OR, short-circuiting
25//! - `&&`: logical AND, short-circuiting
26//! - `<`, `<=`, `>`, `>=`, `==`, `!=`: comparison operators
27//! - `|`: bitwise OR
28//! - `^`: bitwise XOR
29//! - `&`: bitwise AND
30//! - `<<`, `>>`: bitwise shift
31//! - `+`, `-`: addition and subtraction
32//! - `*`, `/`: multiplication and division
33//!
//! Unary operators include:
//!
//! - `!`: logical NOT for booleans, bitwise NOT for integers
//! - `-`: negation
//! - `&`: address of a variable
37//!
38//! For the full specification of the grammar, see the [`parser`] module's documentation.
39//!
40//! ## Numeric types
41//!
42//! The Somni language supports three numeric types:
43//!
//! - Unsigned integers
45//! - Signed integers
46//! - Floats
47//!
48//! By default, the library uses the [`DefaultTypeSet`], which uses `u64`, `i64`, and `f64` for
49//! these types. You can use other type sets like [`TypeSet32`] or [`TypeSet128`] to use
//! 32-bit or 128-bit integers and floats. You need to specify the type set when creating
//! the context, using [`Context::new_with_types`].
52//!
53//! ## Usage
54//!
55//! To evaluate an expression, you need to create a [`Context`] first. You can assign
56//! variables and define functions in this context, and then you can use this context
57//! to evaluate expressions.
58//!
59//! ```rust
60//! use somni_expr::Context;
61//!
62//! let mut context = Context::new();
63//!
//! // Define a variable and some functions
65//! context.add_variable::<u64>("x", 42);
66//! context.add_function("add_one", |x: u64| { x + 1 });
67//! context.add_function("floor", |x: f64| { x.floor() as u64 });
68//!
69//! // Evaluate an expression - we expect it to evaluate
70//! // to a number, which is u64 in the default type set.
71//! let result = context.evaluate::<u64>("add_one(x + floor(1.2))");
72//!
73//! assert_eq!(result, Ok(44));
74//! ```
75#![warn(missing_docs)]
76
77pub mod error;
78#[doc(hidden)]
79pub mod function;
80#[doc(hidden)]
81pub mod string_interner;
82#[doc(hidden)]
83pub mod value;
84
85pub use function::{DynFunction, FunctionCallError};
86pub use value::TypedValue;
87
88use std::{
89    collections::HashMap,
90    fmt::{Debug, Display},
91};
92
93use indexmap::IndexMap;
94use somni_parser::{
95    ast::{self, Expression},
96    lexer::{self, Location},
97    parser,
98};
99
100use crate::{
101    error::MarkInSource,
102    function::ExprFn,
103    string_interner::{StringIndex, StringInterner},
104    value::{Load, MemoryRepr, Store, ValueType},
105};
106
107pub use somni_parser::parser::{DefaultTypeSet, TypeSet128, TypeSet32};
108
/// Sealing module: `Sealed` cannot be named outside this crate, so only the integer
/// types listed here can ever satisfy the `Integer` helper trait's supertrait bound.
mod private {
    /// Marker supertrait that keeps `Integer` closed to downstream implementations.
    pub trait Sealed {}

    // One empty marker impl per supported integer type.
    macro_rules! seal {
        ($($ty:ty),* $(,)?) => {
            $(impl Sealed for $ty {})*
        };
    }

    seal!(u32, u64, u128);
}

use private::Sealed;
117
118/// Defines numeric types in expressions.
119pub trait TypeSet: somni_parser::parser::TypeSet + PartialEq
120where
121    Self::Integer: ValueType,
122    Self::Float: ValueType,
123{
124    /// The type of signed integers in this type set.
125    type SignedInteger: ValueType;
126
127    /// Converts an unsigned integer into a signed integer.
128    fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError>;
129}
130
131impl TypeSet for DefaultTypeSet {
132    type SignedInteger = i64;
133
134    fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError> {
135        i64::try_from(v).map_err(|_| OperatorError::RuntimeError)
136    }
137}
138
139impl TypeSet for TypeSet32 {
140    type SignedInteger = i32;
141
142    fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError> {
143        i32::try_from(v).map_err(|_| OperatorError::RuntimeError)
144    }
145}
146
147impl TypeSet for TypeSet128 {
148    type SignedInteger = i128;
149
150    fn to_signed(v: Self::Integer) -> Result<Self::SignedInteger, OperatorError> {
151        i128::try_from(v).map_err(|_| OperatorError::RuntimeError)
152    }
153}
154
155#[doc(hidden)]
156pub trait Integer: ValueType + Sealed {
157    fn from_usize(value: usize) -> Self;
158}
159
160impl Integer for u32 {
161    fn from_usize(value: usize) -> Self {
162        u32::try_from(value).unwrap()
163    }
164}
165impl Integer for u64 {
166    fn from_usize(value: usize) -> Self {
167        u64::try_from(value).unwrap()
168    }
169}
170impl Integer for u128 {
171    fn from_usize(value: usize) -> Self {
172        u128::try_from(value).unwrap()
173    }
174}
175
/// Represents an error that can occur during operator evaluation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum OperatorError {
    /// The operator cannot be applied to the given operand types.
    TypeError,
    /// The operation failed while running, for example because a value could not be
    /// converted to a signed integer.
    RuntimeError,
}

impl Display for OperatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            OperatorError::TypeError => "Type error",
            OperatorError::RuntimeError => "Runtime error",
        })
    }
}

// Lets callers propagate operator failures through `Box<dyn Error>` and other generic
// error plumbing.
impl std::error::Error for OperatorError {}
195
// Generates a `TypedValue` method named `$method` that applies the binary
// `ValueType::$method` operation to two operands and stores the result back into a
// `TypedValue` through `ctx`.
//
// Supported operand combinations:
// - two operands of the same variant are passed straight through;
// - an integer of undecided signedness (`MaybeSignedInt`) combined with `Int` is used as-is;
// - a `MaybeSignedInt` combined with `SignedInt` is first converted with `T::to_signed`,
//   which can fail with an `OperatorError`.
// Every other combination is an `OperatorError::TypeError`.
macro_rules! dispatch_binary {
    ($method:ident) => {
        pub(crate) fn $method(
            ctx: &mut dyn ExprContext<T>,
            lhs: Self,
            rhs: Self,
        ) -> Result<Self, OperatorError> {
            let result = match (lhs, rhs) {
                (Self::Bool(value), Self::Bool(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                (Self::Int(value), Self::Int(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                (Self::SignedInt(value), Self::SignedInt(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                (Self::MaybeSignedInt(value), Self::MaybeSignedInt(other)) => {
                    // Both operands are still of undecided signedness: an integer result
                    // stays `MaybeSignedInt` so it can adapt to a later operand, while
                    // non-integer results (e.g. booleans from comparisons) pass through.
                    match ValueType::$method(value, other)?.store(ctx) {
                        Self::Int(v) => Self::MaybeSignedInt(v),
                        other => other,
                    }
                }
                (Self::Float(value), Self::Float(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                (Self::String(value), Self::String(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                // Undecided integers share their representation with `Int`.
                (Self::Int(value), Self::MaybeSignedInt(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                (Self::MaybeSignedInt(value), Self::Int(other)) => {
                    ValueType::$method(value, other)?.store(ctx)
                }
                // Undecided integers adopt the signed type; out-of-range values fail here.
                (Self::SignedInt(value), Self::MaybeSignedInt(other)) => {
                    ValueType::$method(value, T::to_signed(other)?)?.store(ctx)
                }
                (Self::MaybeSignedInt(value), Self::SignedInt(other)) => {
                    ValueType::$method(T::to_signed(value)?, other)?.store(ctx)
                }
                _ => return Err(OperatorError::TypeError),
            };

            Ok(result)
        }
    };
}
244
// Generates a `TypedValue` method named `$method` that applies the unary
// `ValueType::$method` operation to one operand and stores the result through `ctx`.
//
// Integers of undecided signedness (`MaybeSignedInt`) are treated like `Int`; any variant
// without a matching arm is an `OperatorError::TypeError`.
macro_rules! dispatch_unary {
    ($method:ident) => {
        pub(crate) fn $method(
            ctx: &mut dyn ExprContext<T>,
            operand: Self,
        ) -> Result<Self, OperatorError> {
            let stored = match operand {
                Self::Bool(v) => ValueType::$method(v)?.store(ctx),
                Self::Int(v) | Self::MaybeSignedInt(v) => ValueType::$method(v)?.store(ctx),
                Self::SignedInt(v) => ValueType::$method(v)?.store(ctx),
                Self::Float(v) => ValueType::$method(v)?.store(ctx),
                Self::String(v) => ValueType::$method(v)?.store(ctx),
                _ => return Err(OperatorError::TypeError),
            };

            Ok(stored)
        }
    };
}
264
265impl<T> TypedValue<T>
266where
267    T: TypeSet,
268    T::Integer: Load<T> + Store<T>,
269    T::Float: Load<T> + Store<T>,
270    T::SignedInteger: Load<T> + Store<T>,
271
272    <T::Integer as ValueType>::NegateOutput: Load<T> + Store<T>,
273    <T::SignedInteger as ValueType>::NegateOutput: Load<T> + Store<T>,
274    <T::Float as ValueType>::NegateOutput: Load<T> + Store<T>,
275{
276    dispatch_binary!(equals);
277    dispatch_binary!(less_than);
278    dispatch_binary!(less_than_or_equal);
279    dispatch_binary!(not_equals);
280    dispatch_binary!(bitwise_or);
281    dispatch_binary!(bitwise_xor);
282    dispatch_binary!(bitwise_and);
283    dispatch_binary!(shift_left);
284    dispatch_binary!(shift_right);
285    dispatch_binary!(add);
286    dispatch_binary!(subtract);
287    dispatch_binary!(multiply);
288    dispatch_binary!(divide);
289    dispatch_unary!(not);
290    dispatch_unary!(negate);
291}
292
293/// An expression context that provides the necessary environment for evaluating expressions.
294pub trait ExprContext<T = DefaultTypeSet>
295where
296    T: TypeSet,
297    T::Integer: ValueType,
298    T::Float: ValueType,
299{
300    /// Implements string interning.
301    fn intern_string(&mut self, s: &str) -> StringIndex;
302
303    /// Loads an interned string.
304    fn load_interned_string(&self, idx: StringIndex) -> &str;
305
306    /// Attempts to load a variable from the context.
307    fn try_load_variable(&self, variable: &str) -> Option<TypedValue<T>>;
308
309    /// Returns the address of a variable in the context.
310    fn address_of(&self, variable: &str) -> TypedValue<T>;
311
312    /// Calls a function in the context.
313    fn call_function(
314        &mut self,
315        function_name: &str,
316        args: &[TypedValue<T>],
317    ) -> Result<TypedValue<T>, FunctionCallError>;
318}
319
/// A visitor that walks a parsed expression tree and evaluates it against an [`ExprContext`].
pub struct ExpressionVisitor<'a, C, T = DefaultTypeSet> {
    /// The context providing variables, functions and interned strings.
    pub context: &'a mut C,
    /// The source code the expression was parsed from; token text is resolved against it.
    pub source: &'a str,
    /// Marker for the type set `T`, which no other field mentions.
    pub _marker: std::marker::PhantomData<T>,
}
329
330/// An error that occurs during evaluation of an expression.
331#[derive(Clone, Debug, PartialEq)]
332pub struct EvalError {
333    /// The error message.
334    pub message: Box<str>,
335    /// The location in the source code where the error occurred.
336    pub location: Location,
337}
338
339impl Display for EvalError {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(f, "Evaluation error: {}", self.message)
342    }
343}
344
345/// An error that occurs during evaluation.
346///
347/// Printing this error will show the error message and the location in the source code.
348///
349/// ```rust
350/// use somni_expr::{Context, TypeSet32};
351/// let mut ctx = Context::<TypeSet32>::new_with_types();
352///
353/// let error = ctx.evaluate::<u32>("true + 1").unwrap_err();
354///
355/// println!("{error:?}");
356///
357/// // Output:
358/// //
359/// // Evaluation error
360/// // ---> at line 1 column 1
361/// //   |
362/// // 1 | true + 1
363/// //   | ^^^^^^^^ Failed to evaluate expression: Type error
364/// ```
365#[derive(Clone, PartialEq)]
366pub struct ExpressionError<'s> {
367    error: EvalError,
368    source: &'s str,
369}
370
371impl ExpressionError<'_> {
372    /// Returns the inner [`EvalError`].
373    pub fn into_inner(self) -> EvalError {
374        self.error
375    }
376}
377
378impl Debug for ExpressionError<'_> {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        let marked = MarkInSource(
381            self.source,
382            self.error.location,
383            "Evaluation error",
384            &self.error.message,
385        );
386        marked.fmt(f)
387    }
388}
389
390impl<'a, C, T> ExpressionVisitor<'a, C, T>
391where
392    C: ExprContext<T>,
393    T: TypeSet,
394    T::Integer: Load<T> + Store<T>,
395    T::Float: Load<T> + Store<T>,
396    T::SignedInteger: Load<T> + Store<T>,
397
398    <T::Integer as ValueType>::NegateOutput: Load<T> + Store<T>,
399    <T::SignedInteger as ValueType>::NegateOutput: Load<T> + Store<T>,
400    <T::Float as ValueType>::NegateOutput: Load<T> + Store<T>,
401{
402    fn visit_variable(&mut self, variable: &lexer::Token) -> Result<TypedValue<T>, EvalError> {
403        let name = variable.source(self.source);
404        self.context.try_load_variable(name).ok_or(EvalError {
405            message: format!("Variable {name} was not found").into_boxed_str(),
406            location: variable.location,
407        })
408    }
409
410    /// Visits an expression and evaluates it, returning the result as a `TypedValue`.
411    pub fn visit_expression(
412        &mut self,
413        expression: &Expression<T>,
414    ) -> Result<TypedValue<T>, EvalError> {
415        let result = match expression {
416            Expression::Variable { variable } => self.visit_variable(variable)?,
417            Expression::Literal { value } => match &value.value {
418                ast::LiteralValue::Integer(value) => TypedValue::<T>::MaybeSignedInt(*value),
419                ast::LiteralValue::Float(value) => TypedValue::<T>::Float(*value),
420                ast::LiteralValue::String(value) => {
421                    TypedValue::<T>::String(self.context.intern_string(value))
422                }
423                ast::LiteralValue::Boolean(value) => TypedValue::<T>::Bool(*value),
424            },
425            Expression::UnaryOperator { name, operand } => match name.source(self.source) {
426                "!" => {
427                    let operand = self.visit_expression(operand)?;
428
429                    match TypedValue::<T>::not(self.context, operand) {
430                        Ok(r) => r,
431                        Err(error) => {
432                            return Err(EvalError {
433                                message: format!("Failed to evaluate expression: {error}")
434                                    .into_boxed_str(),
435                                location: expression.location(),
436                            });
437                        }
438                    }
439                }
440
441                "-" => {
442                    let value = self.visit_expression(operand)?;
443                    let ty = value.type_of();
444                    TypedValue::<T>::negate(self.context, value).map_err(|e| EvalError {
445                        message: format!("Cannot negate {ty}: {e}").into_boxed_str(),
446                        location: operand.location(),
447                    })?
448                }
449                "&" => match operand.as_variable() {
450                    Some(variable) => {
451                        let name = variable.source(self.source);
452                        self.context.address_of(name)
453                    }
454                    None => {
455                        return Err(EvalError {
456                            message: String::from("Cannot take address of non-variable expression")
457                                .into_boxed_str(),
458                            location: operand.location(),
459                        });
460                    }
461                },
462                "*" => {
463                    return Err(EvalError {
464                        message: String::from("Dereference not supported").into_boxed_str(),
465                        location: operand.location(),
466                    });
467                }
468                _ => {
469                    return Err(EvalError {
470                        message: format!("Unknown unary operator: {}", name.source(self.source))
471                            .into_boxed_str(),
472                        location: expression.location(),
473                    });
474                }
475            },
476            Expression::BinaryOperator { name, operands } => {
477                let short_circuiting = ["&&", "||"];
478                let operator = name.source(self.source);
479
480                if short_circuiting.contains(&operator) {
481                    let lhs = self.visit_expression(&operands[0])?;
482                    return match operator {
483                        "&&" if lhs == TypedValue::<T>::Bool(false) => Ok(TypedValue::Bool(false)),
484                        "||" if lhs == TypedValue::<T>::Bool(true) => Ok(TypedValue::Bool(true)),
485                        _ => self.visit_expression(&operands[1]),
486                    };
487                }
488
489                let lhs = self.visit_expression(&operands[0])?;
490                let rhs = self.visit_expression(&operands[1])?;
491                let result = match name.source(self.source) {
492                    "+" => TypedValue::<T>::add(self.context, lhs, rhs),
493                    "-" => TypedValue::<T>::subtract(self.context, lhs, rhs),
494                    "*" => TypedValue::<T>::multiply(self.context, lhs, rhs),
495                    "/" => TypedValue::<T>::divide(self.context, lhs, rhs),
496                    "<" => TypedValue::<T>::less_than(self.context, lhs, rhs),
497                    ">" => TypedValue::<T>::less_than(self.context, rhs, lhs),
498                    "<=" => TypedValue::<T>::less_than_or_equal(self.context, lhs, rhs),
499                    ">=" => TypedValue::<T>::less_than_or_equal(self.context, rhs, lhs),
500                    "==" => TypedValue::<T>::equals(self.context, lhs, rhs),
501                    "!=" => TypedValue::<T>::not_equals(self.context, lhs, rhs),
502                    "|" => TypedValue::<T>::bitwise_or(self.context, lhs, rhs),
503                    "^" => TypedValue::<T>::bitwise_xor(self.context, lhs, rhs),
504                    "&" => TypedValue::<T>::bitwise_and(self.context, lhs, rhs),
505                    "<<" => TypedValue::<T>::shift_left(self.context, lhs, rhs),
506                    ">>" => TypedValue::<T>::shift_right(self.context, lhs, rhs),
507
508                    other => {
509                        return Err(EvalError {
510                            message: format!("Unknown binary operator: {other}").into_boxed_str(),
511                            location: expression.location(),
512                        });
513                    }
514                };
515
516                match result {
517                    Ok(r) => r,
518                    Err(error) => {
519                        return Err(EvalError {
520                            message: format!("Failed to evaluate expression: {error}")
521                                .into_boxed_str(),
522                            location: expression.location(),
523                        });
524                    }
525                }
526            }
527            Expression::FunctionCall { name, arguments } => {
528                let function_name = name.source(self.source);
529                let mut args = Vec::with_capacity(arguments.len());
530                for arg in arguments {
531                    args.push(self.visit_expression(arg)?);
532                }
533
534                match self.context.call_function(function_name, &args) {
535                    Ok(result) => result,
536                    Err(FunctionCallError::IncorrectArgumentCount { expected }) => {
537                        return Err(EvalError {
538                            message: format!(
539                                "{function_name} takes {expected} arguments, {} given",
540                                args.len()
541                            )
542                            .into_boxed_str(),
543                            location: expression.location(),
544                        });
545                    }
546                    Err(FunctionCallError::IncorrectArgumentType { idx, expected }) => {
547                        return Err(EvalError {
548                            message: format!(
549                                "{function_name} expects argument {idx} to be {expected}, got {}",
550                                args[idx].type_of()
551                            )
552                            .into_boxed_str(),
553                            location: arguments[idx].location(),
554                        });
555                    }
556                    Err(FunctionCallError::FunctionNotFound) => {
557                        return Err(EvalError {
558                            message: format!("Function {function_name} is not found")
559                                .into_boxed_str(),
560                            location: expression.location(),
561                        });
562                    }
563                    Err(FunctionCallError::Other(error)) => {
564                        return Err(EvalError {
565                            message: format!("Failed to call {function_name}: {error}")
566                                .into_boxed_str(),
567                            location: expression.location(),
568                        });
569                    }
570                }
571            }
572        };
573
574        Ok(result)
575    }
576}
577
578/// A type in the Somni language.
579#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
580pub enum Type {
581    /// Represents no value, used for e.g. functions that do not return a value.
582    Void,
583    /// Represents integer that may be signed or unsigned.
584    MaybeSignedInt,
585    /// Represents an unsigned integer.
586    Int,
587    /// Represents a signed integer.
588    SignedInt,
589    /// Represents a floating point number.
590    Float,
591    /// Represents a boolean value.
592    Bool,
593    /// Represents a string value.
594    String,
595}
596impl Type {
597    /// Returns the size of the type in bytes.
598    pub fn size_of<T>(&self) -> usize
599    where
600        T: TypeSet,
601        T::Integer: ValueType + MemoryRepr,
602        T::SignedInteger: ValueType + MemoryRepr,
603        T::Float: ValueType + MemoryRepr,
604    {
605        match self {
606            Type::Void => <() as MemoryRepr>::BYTES,
607            Type::Int | Type::MaybeSignedInt => <T::Integer as MemoryRepr>::BYTES,
608            Type::SignedInt => <T::SignedInteger as MemoryRepr>::BYTES,
609            Type::Float => <T::Float as MemoryRepr>::BYTES,
610            Type::Bool => <bool as MemoryRepr>::BYTES,
611            Type::String => <StringIndex as MemoryRepr>::BYTES,
612        }
613    }
614}
615
616impl Display for Type {
617    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
618        match self {
619            Type::Void => write!(f, "void"),
620            Type::MaybeSignedInt => write!(f, "{{int/signed}}"),
621            Type::Int => write!(f, "int"),
622            Type::SignedInt => write!(f, "signed"),
623            Type::Bool => write!(f, "bool"),
624            Type::String => write!(f, "string"),
625            Type::Float => write!(f, "float"),
626        }
627    }
628}
629
/// The expression context, which holds variables, functions, and other state needed for evaluation.
///
/// `'ctx` bounds the lifetime of the registered functions, which may therefore borrow from
/// their environment.
#[derive(Default)]
pub struct Context<'ctx, T = DefaultTypeSet>
where
    T: TypeSet,
    T::Integer: Load<T> + Store<T>,
    T::Float: Load<T> + Store<T>,
    T::SignedInteger: Load<T> + Store<T>,
{
    // Variables in definition order; a variable's index is what `&variable` evaluates to.
    variables: IndexMap<String, TypedValue<T>>,
    // Registered functions, keyed by name.
    functions: HashMap<String, ExprFn<'ctx, T>>,
    // Backing storage for string values referenced by `StringIndex`.
    strings: StringInterner,
    marker: std::marker::PhantomData<T>,
}
644
645impl<T> ExprContext<T> for Context<'_, T>
646where
647    T: TypeSet,
648    T::Integer: Load<T> + Store<T> + Integer,
649    T::Float: Load<T> + Store<T>,
650    T::SignedInteger: Load<T> + Store<T>,
651{
652    fn intern_string(&mut self, s: &str) -> StringIndex {
653        self.strings.intern(s)
654    }
655
656    fn load_interned_string(&self, idx: StringIndex) -> &str {
657        self.strings.lookup(idx)
658    }
659
660    fn try_load_variable(&self, variable: &str) -> Option<TypedValue<T>> {
661        self.variables.get(variable).cloned()
662    }
663
664    fn address_of(&self, variable: &str) -> TypedValue<T> {
665        let index = self.variables.get_index_of(variable).unwrap();
666        TypedValue::Int(<T::Integer as Integer>::from_usize(index))
667    }
668
669    fn call_function(
670        &mut self,
671        function_name: &str,
672        args: &[TypedValue<T>],
673    ) -> Result<TypedValue<T>, FunctionCallError> {
674        match self.functions.remove_entry(function_name) {
675            Some((name, func)) => {
676                let retval = func.call(self, args);
677                self.functions.insert(name, func);
678
679                retval
680            }
681            None => Err(FunctionCallError::FunctionNotFound),
682        }
683    }
684}
685
686impl<'ctx> Context<'ctx, DefaultTypeSet> {
687    /// Creates a new context with [default types][DefaultTypeSet].
688    pub fn new() -> Self {
689        Self::new_with_types()
690    }
691}
692
693impl<'ctx, T> Context<'ctx, T>
694where
695    T: TypeSet,
696    T::Integer: Load<T> + Store<T> + Integer,
697    T::Float: Load<T> + Store<T>,
698    T::SignedInteger: Load<T> + Store<T>,
699
700    <T::Integer as ValueType>::NegateOutput: Load<T> + Store<T>,
701    <T::SignedInteger as ValueType>::NegateOutput: Load<T> + Store<T>,
702    <T::Float as ValueType>::NegateOutput: Load<T> + Store<T>,
703{
704    /// Creates a new context. The type set must be specified when using this function.
705    ///
706    /// ```rust
707    /// use somni_expr::{Context, TypeSet32};
708    /// let mut ctx = Context::<TypeSet32>::new_with_types();
709    /// ```
710    pub fn new_with_types() -> Self {
711        Self {
712            variables: IndexMap::new(),
713            functions: HashMap::new(),
714            strings: StringInterner::new(),
715            marker: std::marker::PhantomData,
716        }
717    }
718
719    fn evaluate_any_impl(&mut self, expression: &str) -> Result<TypedValue<T>, EvalError> {
720        // TODO: we can allow new globals to be defined in the expression, but that would require
721        // storing a copy of the original globals, so that they can be reset?
722        let tokens = match lexer::tokenize(expression).collect::<Result<Vec<_>, _>>() {
723            Ok(tokens) => tokens,
724            Err(e) => {
725                return Err(EvalError {
726                    message: format!("Syntax error: {e}").into_boxed_str(),
727                    location: e.location,
728                });
729            }
730        };
731        let ast = match parser::parse_expression(expression, &tokens) {
732            Ok(ast) => ast,
733            Err(e) => {
734                return Err(EvalError {
735                    message: format!("Parser error: {e}").into_boxed_str(),
736                    location: e.location,
737                });
738            }
739        };
740
741        let mut visitor = ExpressionVisitor::<Self, T> {
742            context: self,
743            source: expression,
744            _marker: std::marker::PhantomData,
745        };
746
747        visitor.visit_expression(&ast)
748    }
749
750    /// Evaluates an expression and returns the result as a [`TypedValue<T>`].
751    pub fn evaluate_any<'s>(
752        &mut self,
753        expression: &'s str,
754    ) -> Result<TypedValue<T>, ExpressionError<'s>> {
755        self.evaluate_any_impl(expression)
756            .map_err(|error| ExpressionError {
757                error,
758                source: expression,
759            })
760    }
761
762    /// Evaluates an expression and returns the result as a specific value type.
763    ///
764    /// This function will attempt to convert the result of the expression to the specified type `V`.
765    /// If the conversion fails, it will return an `ExpressionError`.
766    pub fn evaluate<'s, V>(
767        &'s mut self,
768        expression: &'s str,
769    ) -> Result<V::Output<'s>, ExpressionError<'s>>
770    where
771        V: Load<T>,
772    {
773        let result = self.evaluate_any(expression)?;
774        let result_ty = result.type_of();
775        V::load(self, result).ok_or_else(|| ExpressionError {
776            error: EvalError {
777                message: format!(
778                    "Expression evaluates to {result_ty}, which cannot be converted to {}",
779                    V::TYPE
780                )
781                .into_boxed_str(),
782                location: Location::dummy(),
783            },
784            source: expression,
785        })
786    }
787
788    /// Defines a new variable in the context.
789    pub fn add_variable<V>(&mut self, name: &str, value: V)
790    where
791        V: Store<T>,
792    {
793        let stored = value.store(self);
794        self.variables.insert(name.to_string(), stored);
795    }
796
797    /// Adds a new function to the context.
798    pub fn add_function<F, A>(&mut self, name: &str, func: F)
799    where
800        F: DynFunction<A, T> + 'ctx,
801    {
802        let func = ExprFn::new(func);
803        self.functions.insert(name.to_string(), func);
804    }
805}
806
/// Expands the macro arm `$pat => $code` once for every tuple arity from 0 to 10.
///
/// The arm is defined as a local `inner!` macro, which is then invoked with no arguments,
/// then `V1`, then `V1, V2`, and so on up to `V1, ..., V10`.
#[macro_export]
#[doc(hidden)]
macro_rules! for_all_tuples {
    ($pat:tt => $code:tt;) => {
        macro_rules! inner { $pat => $code; }

        inner!();
        inner!(V1);
        inner!(V1, V2);
        inner!(V1, V2, V3);
        inner!(V1, V2, V3, V4);
        inner!(V1, V2, V3, V4, V5);
        inner!(V1, V2, V3, V4, V5, V6);
        inner!(V1, V2, V3, V4, V5, V6, V7);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8, V9);
        inner!(V1, V2, V3, V4, V5, V6, V7, V8, V9, V10);
    };
}
826
#[cfg(test)]
mod test {
    use super::*;

    /// Returns the plain text of `s`, with every ANSI escape sequence removed.
    fn strip_ansi(s: impl AsRef<str>) -> String {
        use ansi_parser::{AnsiParser, Output};

        fn only_text(chunk: Output<'_>) -> Option<&str> {
            if let Output::TextBlock(text) = chunk {
                Some(text)
            } else {
                None
            }
        }

        s.as_ref().ansi_parse().filter_map(only_text).collect()
    }

    #[test]
    fn test_evaluating_exprs() {
        let mut ctx = Context::new();

        ctx.add_variable::<i64>("signed", 30);
        ctx.add_variable::<u64>("value", 30);
        ctx.add_function("func", |v: u64| 2 * v);
        ctx.add_function("func2", |v1: u64, v2: u64| v1 + v2);
        ctx.add_function("five", || "five");
        ctx.add_function("is_five", |num: &str| num == "five");
        ctx.add_function("concatenate", |a: &str, b: &str| format!("{a}{b}"));

        let boolean_cases: [(&str, bool); 8] = [
            ("value / 5 == 6", true),
            ("five() == \"five\"", true),
            ("is_five(five()) != is_five(\"six\")", true),
            ("true & false", false),
            ("!true", false),
            ("false | false", false),
            ("true ^ true", false),
            ("signed * 2 == 60", true),
        ];
        for (source, expected) in boolean_cases {
            assert_eq!(ctx.evaluate::<bool>(source), Ok(expected), "{source}");
        }

        let integer_cases: [(&str, u64); 3] = [
            ("func(20) / 5", 8),
            ("func2(20, 20) / 5", 8),
            ("!0x1111", 0xFFFF_FFFF_FFFF_EEEE),
        ];
        for (source, expected) in integer_cases {
            assert_eq!(ctx.evaluate::<u64>(source), Ok(expected), "{source}");
        }

        assert_eq!(
            ctx.evaluate::<&str>("concatenate(five(), \"six\")"),
            Ok("fivesix")
        );
    }

    #[test]
    fn test_evaluating_exprs_with_u32() {
        let mut ctx = Context::<TypeSet32>::new_with_types();

        ctx.add_variable::<u32>("value", 30);
        ctx.add_function("func", |v: u32| 2 * v);
        ctx.add_function("func2", |v1: u32, v2: u32| v1 + v2);

        assert_eq!(ctx.evaluate::<bool>("value / 5 == 6"), Ok(true));
        for source in ["func(20) / 5", "func2(20, 20) / 5"] {
            assert_eq!(ctx.evaluate::<u32>(source), Ok(8), "{source}");
        }
    }

    #[test]
    fn test_evaluating_exprs_with_u128() {
        let mut ctx = Context::<TypeSet128>::new_with_types();

        ctx.add_variable::<u128>("value", 30);
        ctx.add_function("func", |v: u128| 2 * v);
        ctx.add_function("func2", |v1: u128, v2: u128| v1 + v2);

        assert_eq!(ctx.evaluate::<bool>("value / 5 == 6"), Ok(true));
        for source in ["func(20) / 5", "func2(20, 20) / 5"] {
            assert_eq!(ctx.evaluate::<u128>(source), Ok(8), "{source}");
        }
    }

    #[test]
    fn test_eval_error() {
        let mut ctx = Context::new();
        ctx.add_function("func", |v1: u64, v2: u64| v1 + v2);

        let error = ctx
            .evaluate::<u64>("func(20, true)")
            .expect_err("Expected expression to return an error");
        let rendered = strip_ansi(format!("\n{error:?}"));

        pretty_assertions::assert_eq!(
            rendered,
            r#"
Evaluation error
 ---> at line 1 column 10
  |
1 | func(20, true)
  |          ^^^^ func expects argument 1 to be int, got bool"#,
        );
    }
}