Skip to main content

wick_core/
lib.rs

//! # wick-core
2//!
3//! Minimal expression language, multiple backends.
4//!
5//! This crate provides a simple expression parser that compiles string expressions
6//! into evaluable ASTs. Variables and functions are provided by the caller—nothing
7//! is hardcoded—making it suitable for user-facing expression inputs, shader
8//! parameter systems, and dynamic computation pipelines.
9//!
10//! ## Design Philosophy
11//!
//! - **Minimal by default**: Core supports arithmetic (including `%` and integer bitwise
//!   operators), variables, and `let` bindings
//! - **Opt-in complexity**: Enable `cond` for conditionals, `func` for function calls
//! - **Lightweight**: Pure Rust with minimal dependencies (`num-traits` for generic
//!   numerics), no allocations during evaluation
15//! - **Backend-agnostic**: AST can be compiled to WGSL, Lua, Cranelift, or evaluated directly
16//!
17//! ## Features
18//!
//! | Feature      | Description |
//! |--------------|-------------|
//! | `introspect` | AST introspection (`free_vars`, etc.) - **enabled by default** |
//! | `cond`       | Conditionals (`if`/`then`/`else`), comparisons (`<`, `<=`, etc.), boolean logic (`and`, `or`, `not`) |
//! | `func`       | Function calls via [`ExprFn`] trait and [`FunctionRegistry`] |
//! | `optimize`   | Enables the `optimize` module |
24//!
25//! ## Syntax Reference
26//!
27//! ### Operators (by precedence, low to high)
28//!
//! | Precedence | Operators | Description |
//! |------------|-----------|-------------|
//! | 0 | `let x = v; body` | Local binding |
//! | 1 | `if c then a else b` | Conditional (requires `cond`) |
//! | 2 | `a or b` | Logical OR, short-circuit (requires `cond`) |
//! | 3 | `a and b` | Logical AND, short-circuit (requires `cond`) |
//! | 4 | `a \| b` | Bitwise OR (integer types) |
//! | 5 | `a & b` | Bitwise AND (integer types) |
//! | 6 | `<` `<=` `>` `>=` `==` `!=` | Comparison (requires `cond`) |
//! | 7 | `a << b`, `a >> b` | Shifts (integer types) |
//! | 8 | `a + b`, `a - b` | Addition, subtraction |
//! | 9 | `a * b`, `a / b`, `a % b` | Multiplication, division, remainder |
//! | 10 | `a ^ b` | Exponentiation (right-associative) |
//! | 11 | `-a`, `~a`, `not a` | Negation, bitwise NOT, logical NOT (`not` requires `cond`) |
//! | 12 | `(a)`, `f(a, b)` | Grouping, function calls (calls require `func`) |
40//!
41//! ### Literals and Identifiers
42//!
//! - **Numbers**: `42`, `3.14`, `.5`, `1.0` (digits and `.` only; exponent notation such as `1e5` is not supported)
44//! - **Variables**: Any identifier (`x`, `time`, `my_var`)
45//! - **Functions**: Identifier followed by parentheses (`sin(x)`, `clamp(x, 0, 1)`)
46//!
47//! ### Boolean Semantics (with `cond` feature)
48//!
49//! - `0.0` is false, any non-zero value is true
50//! - Comparisons and boolean operators return `1.0` (true) or `0.0` (false)
51//! - `and`/`or` use short-circuit evaluation
52//!
53//! ## Examples
54//!
55//! ### Basic Arithmetic
56//!
57//! ```
58//! use wick_core::Expr;
59//! use std::collections::HashMap;
60//!
61//! let expr = Expr::parse("x * 2 + y").unwrap();
62//!
63//! let mut vars = HashMap::new();
64//! vars.insert("x".to_string(), 3.0);
65//! vars.insert("y".to_string(), 1.0);
66//!
67//! # #[cfg(not(feature = "func"))]
68//! let value = expr.eval(&vars).unwrap();
69//! # #[cfg(feature = "func")]
70//! # let value = expr.eval(&vars, &wick_core::FunctionRegistry::new()).unwrap();
71//! assert_eq!(value, 7.0);  // 3 * 2 + 1 = 7
72//! ```
73//!
74//! ### Working with the AST
75//!
76//! ```
77//! use wick_core::{Expr, Ast, BinOp};
78//!
79//! let expr = Expr::parse("a + b * c").unwrap();
80//!
81//! // Inspect the AST structure
82//! match expr.ast() {
83//!     Ast::BinOp(BinOp::Add, left, right) => {
84//!         assert!(matches!(left.as_ref(), Ast::Var(name) if name == "a"));
85//!         assert!(matches!(right.as_ref(), Ast::BinOp(BinOp::Mul, _, _)));
86//!     }
87//!     _ => panic!("unexpected AST structure"),
88//! }
89//! ```
90//!
91//! ### Custom Functions (with `func` feature)
92//!
93#![cfg_attr(feature = "func", doc = "```")]
94#![cfg_attr(not(feature = "func"), doc = "```ignore")]
95//! use wick_core::{Expr, ExprFn, FunctionRegistry, Ast};
96//! use std::collections::HashMap;
97//!
98//! struct Clamp;
99//! impl ExprFn for Clamp {
100//!     fn name(&self) -> &str { "clamp" }
101//!     fn arg_count(&self) -> usize { 3 }
102//!     fn call(&self, args: &[f32]) -> f32 {
103//!         args[0].clamp(args[1], args[2])
104//!     }
105//! }
106//!
107//! let mut registry = FunctionRegistry::new();
108//! registry.register(Clamp);
109//!
110//! let expr = Expr::parse("clamp(x, 0, 1)").unwrap();
111//! let mut vars = HashMap::new();
112//! vars.insert("x".to_string(), 1.5);
113//!
114//! let value = expr.eval(&vars, &registry).unwrap();
115//! assert_eq!(value, 1.0);  // clamped to [0, 1]
116//! ```
117//!
118//! ### Conditionals (with `cond` feature)
119//!
120#![cfg_attr(feature = "cond", doc = "```")]
121#![cfg_attr(not(feature = "cond"), doc = "```ignore")]
122//! use wick_core::Expr;
123//! use std::collections::HashMap;
124//!
125//! let expr = Expr::parse("if x > 0 then x else -x").unwrap();  // absolute value
126//!
127//! let mut vars = HashMap::new();
128//! vars.insert("x".to_string(), -5.0);
129//!
130//! # #[cfg(not(feature = "func"))]
131//! let value = expr.eval(&vars).unwrap();
132//! # #[cfg(feature = "func")]
133//! # let value = expr.eval(&vars, &wick_core::FunctionRegistry::new()).unwrap();
134//! assert_eq!(value, 5.0);
135//! ```
136
137use std::collections::HashMap;
138#[cfg(feature = "introspect")]
139use std::collections::HashSet;
140#[cfg(feature = "func")]
141use std::sync::Arc;
142
143use num_traits::{Num, NumCast, One, Zero};
144use std::ops::Neg;
145
146#[cfg(feature = "optimize")]
147pub mod optimize;
148
149// ============================================================================
150// Numeric Trait
151// ============================================================================
152
153/// Trait for types that can be used as numeric values in expressions.
154///
155/// This is a marker trait that combines the necessary bounds for basic
156/// arithmetic operations. Both float and integer types implement this.
157pub trait Numeric:
158    Num
159    + NumCast
160    + Copy
161    + PartialOrd
162    + Zero
163    + One
164    + Neg<Output = Self>
165    + std::fmt::Debug
166    + Send
167    + Sync
168    + 'static
169{
170    /// Whether this type supports bitwise operations.
171    fn supports_bitwise() -> bool;
172
173    /// Whether this type is a floating-point type.
174    fn is_float() -> bool;
175
176    /// Compute self raised to the power of exp.
177    /// For floats, uses powf. For integers, uses repeated multiplication
178    /// (returns None for negative exponents).
179    fn numeric_pow(self, exp: Self) -> Option<Self>;
180}
181
182impl Numeric for f32 {
183    fn supports_bitwise() -> bool {
184        false
185    }
186    fn is_float() -> bool {
187        true
188    }
189    fn numeric_pow(self, exp: Self) -> Option<Self> {
190        Some(self.powf(exp))
191    }
192}
193
194impl Numeric for f64 {
195    fn supports_bitwise() -> bool {
196        false
197    }
198    fn is_float() -> bool {
199        true
200    }
201    fn numeric_pow(self, exp: Self) -> Option<Self> {
202        Some(self.powf(exp))
203    }
204}
205
206impl Numeric for i32 {
207    fn supports_bitwise() -> bool {
208        true
209    }
210    fn is_float() -> bool {
211        false
212    }
213    fn numeric_pow(self, exp: Self) -> Option<Self> {
214        if exp < 0 {
215            return None;
216        }
217        Some(self.pow(exp as u32))
218    }
219}
220
221impl Numeric for i64 {
222    fn supports_bitwise() -> bool {
223        true
224    }
225    fn is_float() -> bool {
226        false
227    }
228    fn numeric_pow(self, exp: Self) -> Option<Self> {
229        if exp < 0 {
230            return None;
231        }
232        Some(self.pow(exp as u32))
233    }
234}
235
236// Note: u32/u64 don't implement Neg, so they're not included in Numeric.
237// Use i32/i64 for integer vectors/matrices.
238
239// ============================================================================
240// ExprFn trait and registry (func feature)
241// ============================================================================
242
/// A named, fixed-arity function that expressions can call.
///
/// Implement this trait and add the value to a [`FunctionRegistry`].
/// A zero-argument function can be used to expose a constant such as `pi`.
#[cfg(feature = "func")]
pub trait ExprFn: Send + Sync {
    /// The identifier used to call this function (e.g. `"sin"`, `"pi"`).
    fn name(&self) -> &str;

    /// The number of arguments this function expects.
    fn arg_count(&self) -> usize;

    /// Computes the function's value for the given arguments.
    fn call(&self, args: &[f32]) -> f32;

    /// Optionally rewrites a call into an equivalent expression made of
    /// simpler AST nodes.
    ///
    /// Returning `Some` lets backends compile the call without knowing about
    /// this function. The default implementation returns `None`.
    fn decompose(&self, args: &[Ast]) -> Option<Ast> {
        let _ = args;
        None
    }
}
264
/// A name-indexed collection of [`ExprFn`] implementations.
///
/// Functions are stored as `Arc<dyn ExprFn>`, so cloning the registry clones
/// those handles rather than the function objects themselves.
#[cfg(feature = "func")]
#[derive(Clone, Default)]
pub struct FunctionRegistry {
    funcs: HashMap<String, Arc<dyn ExprFn>>,
}

#[cfg(feature = "func")]
impl FunctionRegistry {
    /// Returns a registry that contains no functions.
    pub fn new() -> Self {
        Self {
            funcs: HashMap::new(),
        }
    }

    /// Adds `func`, keyed by the name it reports from [`ExprFn::name`].
    ///
    /// A previously registered function with the same name is replaced.
    pub fn register<F: ExprFn + 'static>(&mut self, func: F) {
        let key = func.name().to_owned();
        let handle: Arc<dyn ExprFn> = Arc::new(func);
        self.funcs.insert(key, handle);
    }

    /// Looks up the function registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&Arc<dyn ExprFn>> {
        self.funcs.get(name)
    }

    /// Iterates over the names of all registered functions, in no particular order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.funcs.keys().map(String::as_str)
    }
}
294
295// ============================================================================
296// Errors
297// ============================================================================
298
/// Error produced when an expression string cannot be parsed.
#[derive(Debug, Clone, PartialEq)]
pub enum ParseError {
    /// A character the lexer could not turn into a token.
    UnexpectedChar(char),
    /// The input ended before the expression was complete.
    UnexpectedEnd,
    /// A token that is not valid at its position; holds the token's `Debug` text.
    UnexpectedToken(String),
    /// A numeric literal that failed to parse; holds the literal's text.
    InvalidNumber(String),
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnexpectedChar(ch) => write!(f, "unexpected character: '{ch}'"),
            Self::UnexpectedEnd => f.write_str("unexpected end of expression"),
            Self::UnexpectedToken(tok) => write!(f, "unexpected token: '{tok}'"),
            Self::InvalidNumber(text) => write!(f, "invalid number: '{text}'"),
        }
    }
}

impl std::error::Error for ParseError {}
320
/// Error produced while evaluating a parsed expression.
#[derive(Debug, Clone, PartialEq)]
pub enum EvalError {
    /// A referenced variable had no value available; holds its name.
    UnknownVariable(String),
    /// A called function was not found; holds its name.
    #[cfg(feature = "func")]
    UnknownFunction(String),
    /// A function was called with the wrong number of arguments.
    #[cfg(feature = "func")]
    WrongArgCount {
        /// Name of the called function.
        func: String,
        /// Number of arguments the function expects.
        expected: usize,
        /// Number of arguments actually supplied.
        got: usize,
    },
    /// Operation not supported for this numeric type (e.g., bitwise ops on floats).
    UnsupportedOperation(String),
}

impl std::fmt::Display for EvalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
            #[cfg(feature = "func")]
            Self::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
            #[cfg(feature = "func")]
            Self::WrongArgCount {
                func,
                expected,
                got,
            } => write!(f, "function '{func}' expects {expected} args, got {got}"),
            Self::UnsupportedOperation(op) => {
                write!(f, "unsupported operation for this numeric type: '{op}'")
            }
        }
    }
}

impl std::error::Error for EvalError {}
363
364// ============================================================================
365// Lexer
366// ============================================================================
367
/// A lexical token. Variants that only make sense with a given feature are
/// compiled in only when that feature is enabled.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    // ---- Atoms ----
    /// Numeric literal.
    Number(f64),
    /// Identifier that is not a keyword.
    Ident(String),

    // ---- Arithmetic operators ----
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `%`
    Percent,

    // ---- Bitwise operators ----
    /// `&`
    Ampersand,
    /// `|`
    Pipe,
    /// `~`
    Tilde,
    /// `<<`
    Shl,
    /// `>>`
    Shr,

    // ---- Comparison operators (cond feature) ----
    /// `<`
    #[cfg(feature = "cond")]
    Lt,
    /// `<=`
    #[cfg(feature = "cond")]
    Le,
    /// `>`
    #[cfg(feature = "cond")]
    Gt,
    /// `>=`
    #[cfg(feature = "cond")]
    Ge,
    /// `==`
    #[cfg(feature = "cond")]
    Eq,
    /// `!=`
    #[cfg(feature = "cond")]
    Ne,

    // ---- Boolean keywords (cond feature) ----
    /// `and`
    #[cfg(feature = "cond")]
    And,
    /// `or`
    #[cfg(feature = "cond")]
    Or,
    /// `not`
    #[cfg(feature = "cond")]
    Not,

    // ---- Conditional keywords (cond feature) ----
    /// `if`
    #[cfg(feature = "cond")]
    If,
    /// `then`
    #[cfg(feature = "cond")]
    Then,
    /// `else`
    #[cfg(feature = "cond")]
    Else,

    // ---- Let bindings ----
    /// `let`
    Let,
    /// `=`
    Assign,

    // ---- Punctuation ----
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,` (only needed for function-call argument lists)
    #[cfg(feature = "func")]
    Comma,
    /// `;`
    Semicolon,

    /// End of input.
    Eof,
}
420
421struct Lexer<'a> {
422    input: &'a str,
423    pos: usize,
424}
425
426impl<'a> Lexer<'a> {
427    fn new(input: &'a str) -> Self {
428        Self { input, pos: 0 }
429    }
430
431    fn peek_char(&self) -> Option<char> {
432        self.input[self.pos..].chars().next()
433    }
434
435    fn next_char(&mut self) -> Option<char> {
436        let c = self.peek_char()?;
437        self.pos += c.len_utf8();
438        Some(c)
439    }
440
441    fn skip_whitespace(&mut self) {
442        while let Some(c) = self.peek_char() {
443            if c.is_whitespace() {
444                self.next_char();
445            } else {
446                break;
447            }
448        }
449    }
450
451    fn read_number(&mut self) -> Result<f64, ParseError> {
452        let start = self.pos;
453        while let Some(c) = self.peek_char() {
454            if c.is_ascii_digit() || c == '.' {
455                self.next_char();
456            } else {
457                break;
458            }
459        }
460        let s = &self.input[start..self.pos];
461        s.parse()
462            .map_err(|_| ParseError::InvalidNumber(s.to_string()))
463    }
464
465    fn read_ident(&mut self) -> String {
466        let start = self.pos;
467        while let Some(c) = self.peek_char() {
468            if c.is_alphanumeric() || c == '_' {
469                self.next_char();
470            } else {
471                break;
472            }
473        }
474        self.input[start..self.pos].to_string()
475    }
476
477    fn next_token(&mut self) -> Result<Token, ParseError> {
478        self.skip_whitespace();
479
480        let Some(c) = self.peek_char() else {
481            return Ok(Token::Eof);
482        };
483
484        match c {
485            '+' => {
486                self.next_char();
487                Ok(Token::Plus)
488            }
489            '-' => {
490                self.next_char();
491                Ok(Token::Minus)
492            }
493            '*' => {
494                self.next_char();
495                Ok(Token::Star)
496            }
497            '/' => {
498                self.next_char();
499                Ok(Token::Slash)
500            }
501            '^' => {
502                self.next_char();
503                Ok(Token::Caret)
504            }
505            '%' => {
506                self.next_char();
507                Ok(Token::Percent)
508            }
509            '&' => {
510                self.next_char();
511                Ok(Token::Ampersand)
512            }
513            '|' => {
514                self.next_char();
515                Ok(Token::Pipe)
516            }
517            '~' => {
518                self.next_char();
519                Ok(Token::Tilde)
520            }
521            '(' => {
522                self.next_char();
523                Ok(Token::LParen)
524            }
525            ')' => {
526                self.next_char();
527                Ok(Token::RParen)
528            }
529            #[cfg(feature = "func")]
530            ',' => {
531                self.next_char();
532                Ok(Token::Comma)
533            }
534            ';' => {
535                self.next_char();
536                Ok(Token::Semicolon)
537            }
538            '<' => {
539                self.next_char();
540                if self.peek_char() == Some('<') {
541                    self.next_char();
542                    Ok(Token::Shl)
543                } else {
544                    #[cfg(feature = "cond")]
545                    {
546                        if self.peek_char() == Some('=') {
547                            self.next_char();
548                            Ok(Token::Le)
549                        } else {
550                            Ok(Token::Lt)
551                        }
552                    }
553                    #[cfg(not(feature = "cond"))]
554                    Err(ParseError::UnexpectedChar('<'))
555                }
556            }
557            '>' => {
558                self.next_char();
559                if self.peek_char() == Some('>') {
560                    self.next_char();
561                    Ok(Token::Shr)
562                } else {
563                    #[cfg(feature = "cond")]
564                    {
565                        if self.peek_char() == Some('=') {
566                            self.next_char();
567                            Ok(Token::Ge)
568                        } else {
569                            Ok(Token::Gt)
570                        }
571                    }
572                    #[cfg(not(feature = "cond"))]
573                    Err(ParseError::UnexpectedChar('>'))
574                }
575            }
576            '=' => {
577                self.next_char();
578                if self.peek_char() == Some('=') {
579                    self.next_char();
580                    #[cfg(feature = "cond")]
581                    {
582                        Ok(Token::Eq)
583                    }
584                    #[cfg(not(feature = "cond"))]
585                    {
586                        Err(ParseError::UnexpectedChar('='))
587                    }
588                } else {
589                    Ok(Token::Assign)
590                }
591            }
592            #[cfg(feature = "cond")]
593            '!' => {
594                self.next_char();
595                if self.peek_char() == Some('=') {
596                    self.next_char();
597                    Ok(Token::Ne)
598                } else {
599                    Err(ParseError::UnexpectedChar('!'))
600                }
601            }
602            '0'..='9' | '.' => Ok(Token::Number(self.read_number()?)),
603            'a'..='z' | 'A'..='Z' | '_' => {
604                let ident = self.read_ident();
605                // Check for keywords
606                if ident == "let" {
607                    return Ok(Token::Let);
608                }
609                #[cfg(feature = "cond")]
610                match ident.as_str() {
611                    "and" => return Ok(Token::And),
612                    "or" => return Ok(Token::Or),
613                    "not" => return Ok(Token::Not),
614                    "if" => return Ok(Token::If),
615                    "then" => return Ok(Token::Then),
616                    "else" => return Ok(Token::Else),
617                    _ => {}
618                }
619                Ok(Token::Ident(ident))
620            }
621            _ => Err(ParseError::UnexpectedChar(c)),
622        }
623    }
624}
625
626// ============================================================================
627// AST
628// ============================================================================
629
630/// Abstract syntax tree node for expressions.
631///
632/// The AST represents the structure of a parsed expression. Use [`Expr::ast()`]
633/// to access the AST after parsing.
634///
635/// # Variants
636///
637/// The available variants depend on enabled features:
638/// - **Always**: `Num`, `Var`, `BinOp`, `UnaryOp`
639/// - **With `func`**: `Call`
640/// - **With `cond`**: `Compare`, `And`, `Or`, `If`
641///
642/// # Example
643///
644/// ```
645/// use wick_core::{Expr, Ast, BinOp};
646///
647/// let expr = Expr::parse("2 + 3").unwrap();
648/// match expr.ast() {
649///     Ast::BinOp(BinOp::Add, left, right) => {
650///         assert!(matches!(left.as_ref(), Ast::Num(2.0)));
651///         assert!(matches!(right.as_ref(), Ast::Num(3.0)));
652///     }
653///     _ => panic!("expected addition"),
654/// }
655/// ```
656#[derive(Debug, Clone, PartialEq)]
657pub enum Ast {
658    /// Numeric literal (e.g., `42`, `3.14`).
659    Num(f64),
660    /// Variable reference, resolved at evaluation time.
661    Var(String),
662    /// Binary operation: `left op right`.
663    BinOp(BinOp, Box<Ast>, Box<Ast>),
664    /// Unary operation: `op operand`.
665    UnaryOp(UnaryOp, Box<Ast>),
666    /// Function call: `name(arg1, arg2, ...)`.
667    #[cfg(feature = "func")]
668    Call(String, Vec<Ast>),
669    /// Comparison: `left op right`, evaluates to `0.0` or `1.0`.
670    #[cfg(feature = "cond")]
671    Compare(CompareOp, Box<Ast>, Box<Ast>),
672    /// Logical AND with short-circuit evaluation.
673    #[cfg(feature = "cond")]
674    And(Box<Ast>, Box<Ast>),
675    /// Logical OR with short-circuit evaluation.
676    #[cfg(feature = "cond")]
677    Or(Box<Ast>, Box<Ast>),
678    /// Conditional: `if condition then then_expr else else_expr`.
679    #[cfg(feature = "cond")]
680    If(Box<Ast>, Box<Ast>, Box<Ast>),
681    /// Local binding: `let name = value; body`.
682    Let {
683        name: String,
684        value: Box<Ast>,
685        body: Box<Ast>,
686    },
687}
688
/// Arithmetic and bitwise binary operators.
///
/// Appears as the first field of [`Ast::BinOp`]. The bitwise variants
/// (`BitAnd`, `BitOr`, `Shl`, `Shr`) are only meaningful for numeric types
/// whose [`Numeric::supports_bitwise`] returns `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// `a + b`: addition.
    Add,
    /// `a - b`: subtraction.
    Sub,
    /// `a * b`: multiplication.
    Mul,
    /// `a / b`: division.
    Div,
    /// `a ^ b`: exponentiation; groups to the right (`a ^ b ^ c` is `a ^ (b ^ c)`).
    Pow,
    /// `a % b`: remainder/modulo.
    Rem,
    /// `a & b`: bitwise AND.
    BitAnd,
    /// `a | b`: bitwise OR.
    BitOr,
    /// `a << b`: left shift.
    Shl,
    /// `a >> b`: right shift.
    Shr,
}
715
/// Relational operators, available with the `cond` feature.
///
/// Appears as the first field of [`Ast::Compare`]. Every comparison
/// evaluates to `1.0` when it holds and `0.0` otherwise.
#[cfg(feature = "cond")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// `a < b`
    Lt,
    /// `a <= b`
    Le,
    /// `a > b`
    Gt,
    /// `a >= b`
    Ge,
    /// `a == b`
    Eq,
    /// `a != b`
    Ne,
}
736
/// Prefix (unary) operators.
///
/// Appears as the first field of [`Ast::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// `-x`: arithmetic negation.
    Neg,
    /// `not x`: logical negation (requires the `cond` feature).
    /// Yields `1.0` when the operand is `0.0`, and `0.0` otherwise.
    #[cfg(feature = "cond")]
    Not,
    /// `~x`: bitwise complement.
    BitNot,
}
751
752// ============================================================================
753// AST Display (produces parseable expressions)
754// ============================================================================
755
756impl std::fmt::Display for Ast {
757    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
758        match self {
759            Ast::Num(n) => {
760                if n.is_nan() {
761                    write!(f, "(0.0 / 0.0)") // NaN
762                } else if n.is_infinite() {
763                    if *n > 0.0 {
764                        write!(f, "(1.0 / 0.0)") // +Inf
765                    } else {
766                        write!(f, "(-1.0 / 0.0)") // -Inf
767                    }
768                } else {
769                    write!(f, "{}", n)
770                }
771            }
772            Ast::Var(name) => write!(f, "{}", name),
773            Ast::BinOp(op, left, right) => {
774                write!(f, "({} {} {})", left, op, right)
775            }
776            Ast::UnaryOp(op, inner) => {
777                write!(f, "({}{})", op, inner)
778            }
779            #[cfg(feature = "func")]
780            Ast::Call(name, args) => {
781                write!(f, "{}(", name)?;
782                for (i, arg) in args.iter().enumerate() {
783                    if i > 0 {
784                        write!(f, ", ")?;
785                    }
786                    write!(f, "{}", arg)?;
787                }
788                write!(f, ")")
789            }
790            #[cfg(feature = "cond")]
791            Ast::Compare(op, left, right) => {
792                write!(f, "({} {} {})", left, op, right)
793            }
794            #[cfg(feature = "cond")]
795            Ast::And(left, right) => {
796                write!(f, "({} and {})", left, right)
797            }
798            #[cfg(feature = "cond")]
799            Ast::Or(left, right) => {
800                write!(f, "({} or {})", left, right)
801            }
802            #[cfg(feature = "cond")]
803            Ast::If(cond, then_expr, else_expr) => {
804                write!(f, "(if {} then {} else {})", cond, then_expr, else_expr)
805            }
806            Ast::Let { name, value, body } => {
807                write!(f, "(let {} = {}; {})", name, value, body)
808            }
809        }
810    }
811}
812
813impl std::fmt::Display for BinOp {
814    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
815        match self {
816            BinOp::Add => write!(f, "+"),
817            BinOp::Sub => write!(f, "-"),
818            BinOp::Mul => write!(f, "*"),
819            BinOp::Div => write!(f, "/"),
820            BinOp::Pow => write!(f, "^"),
821            BinOp::Rem => write!(f, "%"),
822            BinOp::BitAnd => write!(f, "&"),
823            BinOp::BitOr => write!(f, "|"),
824            BinOp::Shl => write!(f, "<<"),
825            BinOp::Shr => write!(f, ">>"),
826        }
827    }
828}
829
830impl std::fmt::Display for UnaryOp {
831    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
832        match self {
833            UnaryOp::Neg => write!(f, "-"),
834            #[cfg(feature = "cond")]
835            UnaryOp::Not => write!(f, "not "),
836            UnaryOp::BitNot => write!(f, "~"),
837        }
838    }
839}
840
#[cfg(feature = "cond")]
impl std::fmt::Display for CompareOp {
    /// Writes the comparison's source symbol, as accepted by the lexer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Eq => "==",
            Self::Ne => "!=",
        };
        f.write_str(symbol)
    }
}
854
855// ============================================================================
856// AST Introspection (introspect feature)
857// ============================================================================
858
#[cfg(feature = "introspect")]
impl Ast {
    /// Returns the set of free variables referenced in this AST node.
    ///
    /// Walks the whole tree and collects every variable name that is not
    /// bound by an enclosing `let`. A `let` binding's name is excluded from
    /// its body, but variables used in the bound value are still reported.
    /// Function names are not variables and are never included.
    ///
    /// # Example
    ///
    /// ```
    /// use wick_core::Expr;
    ///
    /// let expr = Expr::parse("let a = x; a + y * z").unwrap();
    /// let vars = expr.ast().free_vars();
    /// assert!(vars.contains("x"));
    /// assert!(vars.contains("y"));
    /// assert!(vars.contains("z"));
    /// assert!(!vars.contains("a")); // bound by `let`
    /// ```
    pub fn free_vars(&self) -> HashSet<&str> {
        let mut vars = HashSet::new();
        self.collect_vars(&mut vars);
        vars
    }

    /// Adds the free variables of `self` to `vars`.
    fn collect_vars<'a>(&'a self, vars: &mut HashSet<&'a str>) {
        match self {
            Ast::Num(_) => {}
            Ast::Var(name) => {
                vars.insert(name.as_str());
            }
            Ast::BinOp(_, left, right) => {
                left.collect_vars(vars);
                right.collect_vars(vars);
            }
            Ast::UnaryOp(_, inner) => inner.collect_vars(vars),
            #[cfg(feature = "func")]
            Ast::Call(_, args) => {
                for arg in args {
                    arg.collect_vars(vars);
                }
            }
            #[cfg(feature = "cond")]
            Ast::Compare(_, left, right) | Ast::And(left, right) | Ast::Or(left, right) => {
                left.collect_vars(vars);
                right.collect_vars(vars);
            }
            #[cfg(feature = "cond")]
            Ast::If(cond, then_expr, else_expr) => {
                cond.collect_vars(vars);
                then_expr.collect_vars(vars);
                else_expr.collect_vars(vars);
            }
            Ast::Let { name, value, body } => {
                // `name` is not yet bound inside `value`, so its free
                // variables count as-is.
                value.collect_vars(vars);
                // Collect the body separately so removing the bound name does
                // not also drop an outer, genuinely free occurrence already in
                // `vars`.
                let mut body_vars = HashSet::new();
                body.collect_vars(&mut body_vars);
                body_vars.remove(name.as_str());
                vars.extend(body_vars);
            }
        }
    }
}
936
937// ============================================================================
938// Parser
939// ============================================================================
940
941struct Parser<'a> {
942    lexer: Lexer<'a>,
943    current: Token,
944}
945
946impl<'a> Parser<'a> {
947    fn new(input: &'a str) -> Result<Self, ParseError> {
948        let mut lexer = Lexer::new(input);
949        let current = lexer.next_token()?;
950        Ok(Self { lexer, current })
951    }
952
953    fn advance(&mut self) -> Result<(), ParseError> {
954        self.current = self.lexer.next_token()?;
955        Ok(())
956    }
957
958    fn expect(&mut self, expected: Token) -> Result<(), ParseError> {
959        if self.current == expected {
960            self.advance()
961        } else {
962            Err(ParseError::UnexpectedToken(format!("{:?}", self.current)))
963        }
964    }
965
966    // Precedence (low to high):
967    // 0. let bindings (lowest)
968    // 1. if/then/else (cond feature)
969    // 2. or (cond feature, keyword)
970    // 3. and (cond feature, keyword)
971    // 4. bit_or (|)
972    // 5. bit_and (&)
973    // 6. comparison (<, <=, >, >=, ==, !=) (cond feature)
974    // 7. shift (<<, >>)
975    // 8. add/sub (+, -)
976    // 9. mul/div/rem (*, /, %)
977    // 10. power (^)
978    // 11. unary (-, ~, not)
979    // 12. primary
980
981    fn parse_expr(&mut self) -> Result<Ast, ParseError> {
982        self.parse_let()
983    }
984
985    fn parse_let(&mut self) -> Result<Ast, ParseError> {
986        if self.current == Token::Let {
987            self.advance()?;
988            // Expect identifier
989            let name = match &self.current {
990                Token::Ident(s) => s.clone(),
991                _ => return Err(ParseError::UnexpectedToken(format!("{:?}", self.current))),
992            };
993            self.advance()?;
994            // Expect =
995            self.expect(Token::Assign)?;
996            // Parse value expression (at next precedence level, not including let)
997            let value = self.parse_non_let()?;
998            // Expect ;
999            self.expect(Token::Semicolon)?;
1000            // Parse body (can be another let)
1001            let body = self.parse_let()?;
1002            Ok(Ast::Let {
1003                name,
1004                value: Box::new(value),
1005                body: Box::new(body),
1006            })
1007        } else {
1008            self.parse_non_let()
1009        }
1010    }
1011
1012    fn parse_non_let(&mut self) -> Result<Ast, ParseError> {
1013        #[cfg(feature = "cond")]
1014        {
1015            self.parse_if()
1016        }
1017        #[cfg(not(feature = "cond"))]
1018        {
1019            self.parse_bit_or()
1020        }
1021    }
1022
1023    #[cfg(feature = "cond")]
1024    fn parse_if(&mut self) -> Result<Ast, ParseError> {
1025        if self.current == Token::If {
1026            self.advance()?;
1027            let cond = self.parse_or()?;
1028            self.expect(Token::Then)?;
1029            let then_expr = self.parse_if()?; // Allow nested if in then branch
1030            self.expect(Token::Else)?;
1031            let else_expr = self.parse_if()?; // Right associative for chained if/else
1032            Ok(Ast::If(
1033                Box::new(cond),
1034                Box::new(then_expr),
1035                Box::new(else_expr),
1036            ))
1037        } else {
1038            self.parse_or()
1039        }
1040    }
1041
1042    #[cfg(feature = "cond")]
1043    fn parse_or(&mut self) -> Result<Ast, ParseError> {
1044        let mut left = self.parse_and()?;
1045
1046        while self.current == Token::Or {
1047            self.advance()?;
1048            let right = self.parse_and()?;
1049            left = Ast::Or(Box::new(left), Box::new(right));
1050        }
1051
1052        Ok(left)
1053    }
1054
1055    #[cfg(feature = "cond")]
1056    fn parse_and(&mut self) -> Result<Ast, ParseError> {
1057        let mut left = self.parse_bit_or()?;
1058
1059        while self.current == Token::And {
1060            self.advance()?;
1061            let right = self.parse_bit_or()?;
1062            left = Ast::And(Box::new(left), Box::new(right));
1063        }
1064
1065        Ok(left)
1066    }
1067
1068    fn parse_bit_or(&mut self) -> Result<Ast, ParseError> {
1069        let mut left = self.parse_bit_and()?;
1070
1071        while self.current == Token::Pipe {
1072            self.advance()?;
1073            let right = self.parse_bit_and()?;
1074            left = Ast::BinOp(BinOp::BitOr, Box::new(left), Box::new(right));
1075        }
1076
1077        Ok(left)
1078    }
1079
1080    fn parse_bit_and(&mut self) -> Result<Ast, ParseError> {
1081        #[cfg(feature = "cond")]
1082        let mut left = self.parse_compare()?;
1083        #[cfg(not(feature = "cond"))]
1084        let mut left = self.parse_shift()?;
1085
1086        while self.current == Token::Ampersand {
1087            self.advance()?;
1088            #[cfg(feature = "cond")]
1089            let right = self.parse_compare()?;
1090            #[cfg(not(feature = "cond"))]
1091            let right = self.parse_shift()?;
1092            left = Ast::BinOp(BinOp::BitAnd, Box::new(left), Box::new(right));
1093        }
1094
1095        Ok(left)
1096    }
1097
1098    #[cfg(feature = "cond")]
1099    fn parse_compare(&mut self) -> Result<Ast, ParseError> {
1100        let left = self.parse_shift()?;
1101
1102        let op = match &self.current {
1103            Token::Lt => Some(CompareOp::Lt),
1104            Token::Le => Some(CompareOp::Le),
1105            Token::Gt => Some(CompareOp::Gt),
1106            Token::Ge => Some(CompareOp::Ge),
1107            Token::Eq => Some(CompareOp::Eq),
1108            Token::Ne => Some(CompareOp::Ne),
1109            _ => None,
1110        };
1111
1112        if let Some(op) = op {
1113            self.advance()?;
1114            let right = self.parse_shift()?;
1115            Ok(Ast::Compare(op, Box::new(left), Box::new(right)))
1116        } else {
1117            Ok(left)
1118        }
1119    }
1120
1121    fn parse_shift(&mut self) -> Result<Ast, ParseError> {
1122        let mut left = self.parse_add_sub()?;
1123
1124        loop {
1125            match &self.current {
1126                Token::Shl => {
1127                    self.advance()?;
1128                    let right = self.parse_add_sub()?;
1129                    left = Ast::BinOp(BinOp::Shl, Box::new(left), Box::new(right));
1130                }
1131                Token::Shr => {
1132                    self.advance()?;
1133                    let right = self.parse_add_sub()?;
1134                    left = Ast::BinOp(BinOp::Shr, Box::new(left), Box::new(right));
1135                }
1136                _ => break,
1137            }
1138        }
1139
1140        Ok(left)
1141    }
1142
1143    fn parse_add_sub(&mut self) -> Result<Ast, ParseError> {
1144        let mut left = self.parse_mul_div()?;
1145
1146        loop {
1147            match &self.current {
1148                Token::Plus => {
1149                    self.advance()?;
1150                    let right = self.parse_mul_div()?;
1151                    left = Ast::BinOp(BinOp::Add, Box::new(left), Box::new(right));
1152                }
1153                Token::Minus => {
1154                    self.advance()?;
1155                    let right = self.parse_mul_div()?;
1156                    left = Ast::BinOp(BinOp::Sub, Box::new(left), Box::new(right));
1157                }
1158                _ => break,
1159            }
1160        }
1161
1162        Ok(left)
1163    }
1164
1165    fn parse_mul_div(&mut self) -> Result<Ast, ParseError> {
1166        let mut left = self.parse_power()?;
1167
1168        loop {
1169            match &self.current {
1170                Token::Star => {
1171                    self.advance()?;
1172                    let right = self.parse_power()?;
1173                    left = Ast::BinOp(BinOp::Mul, Box::new(left), Box::new(right));
1174                }
1175                Token::Slash => {
1176                    self.advance()?;
1177                    let right = self.parse_power()?;
1178                    left = Ast::BinOp(BinOp::Div, Box::new(left), Box::new(right));
1179                }
1180                Token::Percent => {
1181                    self.advance()?;
1182                    let right = self.parse_power()?;
1183                    left = Ast::BinOp(BinOp::Rem, Box::new(left), Box::new(right));
1184                }
1185                _ => break,
1186            }
1187        }
1188
1189        Ok(left)
1190    }
1191
1192    fn parse_power(&mut self) -> Result<Ast, ParseError> {
1193        let base = self.parse_unary()?;
1194
1195        if self.current == Token::Caret {
1196            self.advance()?;
1197            let exp = self.parse_power()?; // Right associative
1198            Ok(Ast::BinOp(BinOp::Pow, Box::new(base), Box::new(exp)))
1199        } else {
1200            Ok(base)
1201        }
1202    }
1203
1204    fn parse_unary(&mut self) -> Result<Ast, ParseError> {
1205        match &self.current {
1206            Token::Minus => {
1207                self.advance()?;
1208                let inner = self.parse_unary()?;
1209                Ok(Ast::UnaryOp(UnaryOp::Neg, Box::new(inner)))
1210            }
1211            Token::Tilde => {
1212                self.advance()?;
1213                let inner = self.parse_unary()?;
1214                Ok(Ast::UnaryOp(UnaryOp::BitNot, Box::new(inner)))
1215            }
1216            #[cfg(feature = "cond")]
1217            Token::Not => {
1218                self.advance()?;
1219                let inner = self.parse_unary()?;
1220                Ok(Ast::UnaryOp(UnaryOp::Not, Box::new(inner)))
1221            }
1222            _ => self.parse_primary(),
1223        }
1224    }
1225
1226    fn parse_primary(&mut self) -> Result<Ast, ParseError> {
1227        match &self.current {
1228            Token::Number(n) => {
1229                let n = *n;
1230                self.advance()?;
1231                Ok(Ast::Num(n))
1232            }
1233            Token::Ident(name) => {
1234                let name = name.clone();
1235                self.advance()?;
1236
1237                // Check if it's a function call (func feature)
1238                #[cfg(feature = "func")]
1239                if self.current == Token::LParen {
1240                    self.advance()?;
1241                    let mut args = Vec::new();
1242                    if self.current != Token::RParen {
1243                        args.push(self.parse_expr()?);
1244                        while self.current == Token::Comma {
1245                            self.advance()?;
1246                            args.push(self.parse_expr()?);
1247                        }
1248                    }
1249                    self.expect(Token::RParen)?;
1250                    return Ok(Ast::Call(name, args));
1251                }
1252
1253                // It's a variable
1254                Ok(Ast::Var(name))
1255            }
1256            Token::LParen => {
1257                self.advance()?;
1258                let inner = self.parse_expr()?;
1259                self.expect(Token::RParen)?;
1260                Ok(inner)
1261            }
1262            Token::Eof => Err(ParseError::UnexpectedEnd),
1263            _ => Err(ParseError::UnexpectedToken(format!("{:?}", self.current))),
1264        }
1265    }
1266}
1267
1268// ============================================================================
1269// Expression
1270// ============================================================================
1271
/// A parsed expression that can be evaluated or inspected.
///
/// `Expr` is the main entry point for the expression language. Parse a string
/// with [`Expr::parse()`], then either evaluate it with [`Expr::eval()`] or
/// inspect the AST with [`Expr::ast()`]. Parsing happens once; the same `Expr`
/// can then be evaluated any number of times with different variable values.
///
/// # Example
///
/// ```
/// use wick_core::Expr;
/// use std::collections::HashMap;
///
/// // Parse an expression
/// let expr = Expr::parse("x^2 + 2*x + 1").unwrap();
///
/// // Evaluate with different variable values
/// let mut vars = HashMap::new();
/// vars.insert("x".to_string(), 3.0);
/// # #[cfg(not(feature = "func"))]
/// assert_eq!(expr.eval(&vars).unwrap(), 16.0);  // 9 + 6 + 1
/// # #[cfg(feature = "func")]
/// # assert_eq!(expr.eval(&vars, &wick_core::FunctionRegistry::new()).unwrap(), 16.0);
///
/// vars.insert("x".to_string(), 0.0);
/// # #[cfg(not(feature = "func"))]
/// assert_eq!(expr.eval(&vars).unwrap(), 1.0);   // 0 + 0 + 1
/// # #[cfg(feature = "func")]
/// # assert_eq!(expr.eval(&vars, &wick_core::FunctionRegistry::new()).unwrap(), 1.0);
/// ```
#[derive(Debug, Clone)]
pub struct Expr {
    /// Root of the syntax tree produced by [`Expr::parse()`].
    ast: Ast,
}
1305
1306impl Expr {
1307    /// Parses an expression from a string.
1308    ///
1309    /// # Errors
1310    ///
1311    /// Returns [`ParseError`] if the input is not a valid expression:
1312    /// - [`ParseError::UnexpectedChar`] for invalid characters
1313    /// - [`ParseError::UnexpectedEnd`] for incomplete expressions
1314    /// - [`ParseError::UnexpectedToken`] for syntax errors
1315    /// - [`ParseError::InvalidNumber`] for malformed numeric literals
1316    ///
1317    /// # Example
1318    ///
1319    /// ```
1320    /// use wick_core::{Expr, ParseError};
1321    ///
1322    /// // Valid expression
1323    /// assert!(Expr::parse("1 + 2").is_ok());
1324    ///
1325    /// // Invalid: unexpected character
1326    /// assert!(matches!(Expr::parse("1 @ 2"), Err(ParseError::UnexpectedChar('@'))));
1327    ///
1328    /// // Invalid: incomplete expression
1329    /// assert!(matches!(Expr::parse("1 +"), Err(ParseError::UnexpectedEnd)));
1330    /// ```
1331    pub fn parse(input: &str) -> Result<Self, ParseError> {
1332        let mut parser = Parser::new(input)?;
1333        let ast = parser.parse_expr()?;
1334        if parser.current != Token::Eof {
1335            return Err(ParseError::UnexpectedToken(format!("{:?}", parser.current)));
1336        }
1337        Ok(Self { ast })
1338    }
1339
1340    /// Returns a reference to the parsed AST.
1341    ///
1342    /// Use this to inspect the expression structure or to compile it to
1343    /// a different target (WGSL, Lua, etc.).
1344    pub fn ast(&self) -> &Ast {
1345        &self.ast
1346    }
1347
1348    /// Returns the set of free variables referenced in the expression.
1349    ///
1350    /// This is useful for determining which variables need to be provided
1351    /// at evaluation time, or for building dependency graphs.
1352    ///
1353    /// # Example
1354    ///
1355    /// ```
1356    /// use wick_core::Expr;
1357    ///
1358    /// let expr = Expr::parse("x * 2 + y").unwrap();
1359    /// let vars = expr.free_vars();
1360    /// assert!(vars.contains("x"));
1361    /// assert!(vars.contains("y"));
1362    /// assert_eq!(vars.len(), 2);
1363    /// ```
1364    #[cfg(feature = "introspect")]
1365    pub fn free_vars(&self) -> HashSet<&str> {
1366        self.ast.free_vars()
1367    }
1368
1369    /// Evaluates the expression with the given variables and function registry.
1370    ///
1371    /// # Errors
1372    ///
1373    /// Returns [`EvalError`] if evaluation fails:
1374    /// - [`EvalError::UnknownVariable`] if a variable is not in `vars`
1375    /// - [`EvalError::UnknownFunction`] if a function is not in `funcs`
1376    /// - [`EvalError::WrongArgCount`] if a function is called with wrong arity
1377    #[cfg(feature = "func")]
1378    pub fn eval(
1379        &self,
1380        vars: &HashMap<String, f32>,
1381        funcs: &FunctionRegistry,
1382    ) -> Result<f32, EvalError> {
1383        eval_ast(&self.ast, vars, funcs)
1384    }
1385
1386    /// Evaluates the expression with the given variables.
1387    ///
1388    /// This version is available when the `func` feature is disabled.
1389    ///
1390    /// # Errors
1391    ///
1392    /// Returns [`EvalError::UnknownVariable`] if a variable is not in `vars`.
1393    #[cfg(not(feature = "func"))]
1394    pub fn eval(&self, vars: &HashMap<String, f32>) -> Result<f32, EvalError> {
1395        eval_ast(&self.ast, vars)
1396    }
1397}
1398
/// Recursively evaluates `ast`, reading variables from `vars` and resolving
/// function calls through `funcs`.
///
/// Truth values are plain `f32`s: comparisons, `and`, `or` and `not` produce
/// `1.0` or `0.0`, and every value other than `0.0` is treated as true.
/// `and`/`or` only evaluate their right operand when the left one does not
/// already decide the result. A `let` evaluates its body against a copy of
/// `vars` extended with the new binding, so inner bindings shadow outer ones.
/// Bitwise operators are rejected with [`EvalError::UnsupportedOperation`].
#[cfg(feature = "func")]
fn eval_ast(
    ast: &Ast,
    vars: &HashMap<String, f32>,
    funcs: &FunctionRegistry,
) -> Result<f32, EvalError> {
    /// Encodes a boolean as the language's `1.0` / `0.0` truth value.
    #[cfg(feature = "cond")]
    fn flag(b: bool) -> f32 {
        if b {
            1.0
        } else {
            0.0
        }
    }

    // Recurse into a child node within the same variable scope.
    let eval = |node: &Ast| eval_ast(node, vars, funcs);

    match ast {
        Ast::Num(n) => Ok(*n as f32),
        Ast::Var(name) => match vars.get(name) {
            Some(&value) => Ok(value),
            None => Err(EvalError::UnknownVariable(name.clone())),
        },
        Ast::BinOp(op, lhs, rhs) => {
            // Both operands are evaluated (left first) before the operator is
            // inspected, so an operand error takes priority over an
            // unsupported operator.
            let (a, b) = (eval(lhs)?, eval(rhs)?);
            let symbol = match op {
                BinOp::Add => return Ok(a + b),
                BinOp::Sub => return Ok(a - b),
                BinOp::Mul => return Ok(a * b),
                BinOp::Div => return Ok(a / b),
                BinOp::Pow => return Ok(a.powf(b)),
                BinOp::Rem => return Ok(a % b),
                BinOp::BitAnd => "&",
                BinOp::BitOr => "|",
                BinOp::Shl => "<<",
                BinOp::Shr => ">>",
            };
            Err(EvalError::UnsupportedOperation(symbol.to_string()))
        }
        Ast::UnaryOp(op, operand) => {
            let v = eval(operand)?;
            match op {
                UnaryOp::Neg => Ok(-v),
                UnaryOp::BitNot => Err(EvalError::UnsupportedOperation("~".to_string())),
                #[cfg(feature = "cond")]
                UnaryOp::Not => Ok(flag(v == 0.0)),
            }
        }
        #[cfg(feature = "cond")]
        Ast::Compare(op, lhs, rhs) => {
            let (a, b) = (eval(lhs)?, eval(rhs)?);
            Ok(flag(match op {
                CompareOp::Lt => a < b,
                CompareOp::Le => a <= b,
                CompareOp::Gt => a > b,
                CompareOp::Ge => a >= b,
                CompareOp::Eq => a == b,
                CompareOp::Ne => a != b,
            }))
        }
        #[cfg(feature = "cond")]
        Ast::And(lhs, rhs) => {
            // A false left side decides the result; the right side is skipped.
            if eval(lhs)? == 0.0 {
                return Ok(0.0);
            }
            Ok(flag(eval(rhs)? != 0.0))
        }
        #[cfg(feature = "cond")]
        Ast::Or(lhs, rhs) => {
            // A true left side decides the result; the right side is skipped.
            if eval(lhs)? != 0.0 {
                return Ok(1.0);
            }
            Ok(flag(eval(rhs)? != 0.0))
        }
        #[cfg(feature = "cond")]
        Ast::If(cond, then_expr, else_expr) => {
            let chosen = if eval(cond)? != 0.0 { then_expr } else { else_expr };
            eval(chosen)
        }
        Ast::Call(name, args) => {
            let func = match funcs.get(name) {
                Some(func) => func,
                None => return Err(EvalError::UnknownFunction(name.clone())),
            };

            // Arity is checked before any argument is evaluated.
            let expected = func.arg_count();
            if args.len() != expected {
                return Err(EvalError::WrongArgCount {
                    func: name.clone(),
                    expected,
                    got: args.len(),
                });
            }

            let values = args
                .iter()
                .map(eval)
                .collect::<Result<Vec<f32>, EvalError>>()?;
            Ok(func.call(&values))
        }
        Ast::Let { name, value, body } => {
            let bound = eval(value)?;
            let mut scope = vars.clone();
            scope.insert(name.clone(), bound);
            eval_ast(body, &scope, funcs)
        }
    }
}
1513
1514#[cfg(not(feature = "func"))]
1515fn eval_ast(ast: &Ast, vars: &HashMap<String, f32>) -> Result<f32, EvalError> {
1516    match ast {
1517        Ast::Num(n) => Ok(*n as f32),
1518        Ast::Var(name) => vars
1519            .get(name)
1520            .copied()
1521            .ok_or_else(|| EvalError::UnknownVariable(name.clone())),
1522        Ast::BinOp(op, l, r) => {
1523            let l = eval_ast(l, vars)?;
1524            let r = eval_ast(r, vars)?;
1525            match op {
1526                BinOp::Add => Ok(l + r),
1527                BinOp::Sub => Ok(l - r),
1528                BinOp::Mul => Ok(l * r),
1529                BinOp::Div => Ok(l / r),
1530                BinOp::Pow => Ok(l.powf(r)),
1531                BinOp::Rem => Ok(l % r),
1532                BinOp::BitAnd => Err(EvalError::UnsupportedOperation("&".to_string())),
1533                BinOp::BitOr => Err(EvalError::UnsupportedOperation("|".to_string())),
1534                BinOp::Shl => Err(EvalError::UnsupportedOperation("<<".to_string())),
1535                BinOp::Shr => Err(EvalError::UnsupportedOperation(">>".to_string())),
1536            }
1537        }
1538        Ast::UnaryOp(op, inner) => {
1539            let v = eval_ast(inner, vars)?;
1540            match op {
1541                UnaryOp::Neg => Ok(-v),
1542                UnaryOp::BitNot => Err(EvalError::UnsupportedOperation("~".to_string())),
1543                #[cfg(feature = "cond")]
1544                UnaryOp::Not => {
1545                    if v == 0.0 {
1546                        Ok(1.0)
1547                    } else {
1548                        Ok(0.0)
1549                    }
1550                }
1551            }
1552        }
1553        #[cfg(feature = "cond")]
1554        Ast::Compare(op, l, r) => {
1555            let l = eval_ast(l, vars)?;
1556            let r = eval_ast(r, vars)?;
1557            let result = match op {
1558                CompareOp::Lt => l < r,
1559                CompareOp::Le => l <= r,
1560                CompareOp::Gt => l > r,
1561                CompareOp::Ge => l >= r,
1562                CompareOp::Eq => l == r,
1563                CompareOp::Ne => l != r,
1564            };
1565            Ok(if result { 1.0 } else { 0.0 })
1566        }
1567        #[cfg(feature = "cond")]
1568        Ast::And(l, r) => {
1569            let l = eval_ast(l, vars)?;
1570            if l == 0.0 {
1571                Ok(0.0) // Short-circuit
1572            } else {
1573                let r = eval_ast(r, vars)?;
1574                Ok(if r != 0.0 { 1.0 } else { 0.0 })
1575            }
1576        }
1577        #[cfg(feature = "cond")]
1578        Ast::Or(l, r) => {
1579            let l = eval_ast(l, vars)?;
1580            if l != 0.0 {
1581                Ok(1.0) // Short-circuit
1582            } else {
1583                let r = eval_ast(r, vars)?;
1584                Ok(if r != 0.0 { 1.0 } else { 0.0 })
1585            }
1586        }
1587        #[cfg(feature = "cond")]
1588        Ast::If(cond, then_expr, else_expr) => {
1589            let cond = eval_ast(cond, vars)?;
1590            if cond != 0.0 {
1591                eval_ast(then_expr, vars)
1592            } else {
1593                eval_ast(else_expr, vars)
1594            }
1595        }
1596        Ast::Let { name, value, body } => {
1597            let val = eval_ast(value, vars)?;
1598            let mut new_vars = vars.clone();
1599            new_vars.insert(name.clone(), val);
1600            eval_ast(body, &new_vars)
1601        }
1602    }
1603}
1604
1605// ============================================================================
1606// Tests
1607// ============================================================================
1608
1609#[cfg(test)]
1610mod tests {
1611    use super::*;
1612
1613    #[cfg(feature = "func")]
1614    fn eval(expr_str: &str, vars: &[(&str, f32)]) -> f32 {
1615        let registry = FunctionRegistry::new();
1616        let expr = Expr::parse(expr_str).unwrap();
1617        let var_map: HashMap<String, f32> = vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
1618        expr.eval(&var_map, &registry).unwrap()
1619    }
1620
1621    #[cfg(not(feature = "func"))]
1622    fn eval(expr_str: &str, vars: &[(&str, f32)]) -> f32 {
1623        let expr = Expr::parse(expr_str).unwrap();
1624        let var_map: HashMap<String, f32> = vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
1625        expr.eval(&var_map).unwrap()
1626    }
1627
1628    #[test]
1629    fn test_parse_number() {
1630        assert_eq!(eval("42", &[]), 42.0);
1631    }
1632
1633    #[test]
1634    fn test_parse_float() {
1635        assert!((eval("1.234", &[]) - 1.234).abs() < 0.001);
1636    }
1637
1638    #[test]
1639    fn test_parse_variable() {
1640        assert_eq!(eval("x", &[("x", 5.0)]), 5.0);
1641        assert_eq!(eval("foo", &[("foo", 3.0)]), 3.0);
1642    }
1643
1644    #[test]
1645    fn test_parse_add() {
1646        assert_eq!(eval("1 + 2", &[]), 3.0);
1647    }
1648
1649    #[test]
1650    fn test_parse_mul() {
1651        assert_eq!(eval("3 * 4", &[]), 12.0);
1652    }
1653
1654    #[test]
1655    fn test_precedence() {
1656        assert_eq!(eval("2 + 3 * 4", &[]), 14.0);
1657    }
1658
1659    #[test]
1660    fn test_parentheses() {
1661        assert_eq!(eval("(2 + 3) * 4", &[]), 20.0);
1662    }
1663
1664    #[test]
1665    fn test_negation() {
1666        assert_eq!(eval("-5", &[]), -5.0);
1667    }
1668
1669    #[test]
1670    fn test_power() {
1671        assert_eq!(eval("2 ^ 3", &[]), 8.0);
1672    }
1673
1674    #[cfg(feature = "func")]
1675    #[test]
1676    fn test_unknown_variable() {
1677        let registry = FunctionRegistry::new();
1678        let expr = Expr::parse("unknown").unwrap();
1679        let vars = HashMap::new();
1680        let result = expr.eval(&vars, &registry);
1681        assert!(matches!(result, Err(EvalError::UnknownVariable(_))));
1682    }
1683
1684    #[cfg(not(feature = "func"))]
1685    #[test]
1686    fn test_unknown_variable() {
1687        let expr = Expr::parse("unknown").unwrap();
1688        let vars = HashMap::new();
1689        let result = expr.eval(&vars);
1690        assert!(matches!(result, Err(EvalError::UnknownVariable(_))));
1691    }
1692
1693    #[cfg(feature = "func")]
1694    #[test]
1695    fn test_unknown_function() {
1696        let registry = FunctionRegistry::new();
1697        let expr = Expr::parse("unknown(1)").unwrap();
1698        let vars = HashMap::new();
1699        let result = expr.eval(&vars, &registry);
1700        assert!(matches!(result, Err(EvalError::UnknownFunction(_))));
1701    }
1702
1703    #[cfg(feature = "func")]
1704    #[test]
1705    fn test_custom_function() {
1706        struct Double;
1707        impl ExprFn for Double {
1708            fn name(&self) -> &str {
1709                "double"
1710            }
1711            fn arg_count(&self) -> usize {
1712                1
1713            }
1714            fn call(&self, args: &[f32]) -> f32 {
1715                args[0] * 2.0
1716            }
1717        }
1718
1719        let mut registry = FunctionRegistry::new();
1720        registry.register(Double);
1721
1722        let expr = Expr::parse("double(5)").unwrap();
1723        let vars = HashMap::new();
1724        assert_eq!(expr.eval(&vars, &registry).unwrap(), 10.0);
1725    }
1726
1727    #[cfg(feature = "func")]
1728    #[test]
1729    fn test_zero_arg_function() {
1730        struct Pi;
1731        impl ExprFn for Pi {
1732            fn name(&self) -> &str {
1733                "pi"
1734            }
1735            fn arg_count(&self) -> usize {
1736                0
1737            }
1738            fn call(&self, _args: &[f32]) -> f32 {
1739                std::f32::consts::PI
1740            }
1741        }
1742
1743        let mut registry = FunctionRegistry::new();
1744        registry.register(Pi);
1745
1746        let expr = Expr::parse("pi()").unwrap();
1747        let vars = HashMap::new();
1748        assert!((expr.eval(&vars, &registry).unwrap() - std::f32::consts::PI).abs() < 0.001);
1749    }
1750
1751    #[cfg(feature = "func")]
1752    #[test]
1753    fn test_wrong_arg_count() {
1754        struct OneArg;
1755        impl ExprFn for OneArg {
1756            fn name(&self) -> &str {
1757                "one_arg"
1758            }
1759            fn arg_count(&self) -> usize {
1760                1
1761            }
1762            fn call(&self, args: &[f32]) -> f32 {
1763                args[0]
1764            }
1765        }
1766
1767        let mut registry = FunctionRegistry::new();
1768        registry.register(OneArg);
1769
1770        let expr = Expr::parse("one_arg(1, 2)").unwrap();
1771        let vars = HashMap::new();
1772        let result = expr.eval(&vars, &registry);
1773        assert!(matches!(result, Err(EvalError::WrongArgCount { .. })));
1774    }
1775
1776    #[cfg(feature = "func")]
1777    #[test]
1778    fn test_complex_expression() {
1779        struct Add;
1780        impl ExprFn for Add {
1781            fn name(&self) -> &str {
1782                "add"
1783            }
1784            fn arg_count(&self) -> usize {
1785                2
1786            }
1787            fn call(&self, args: &[f32]) -> f32 {
1788                args[0] + args[1]
1789            }
1790        }
1791
1792        let mut registry = FunctionRegistry::new();
1793        registry.register(Add);
1794
1795        let expr = Expr::parse("add(x * 2, y + 1)").unwrap();
1796        let vars: HashMap<String, f32> = [("x".to_string(), 3.0), ("y".to_string(), 4.0)].into();
1797        assert_eq!(expr.eval(&vars, &registry).unwrap(), 11.0); // (3*2) + (4+1) = 11
1798    }
1799
1800    // Comparison tests (cond feature)
1801    #[cfg(feature = "cond")]
1802    #[test]
1803    fn test_compare_lt() {
1804        assert_eq!(eval("1 < 2", &[]), 1.0);
1805        assert_eq!(eval("2 < 1", &[]), 0.0);
1806        assert_eq!(eval("1 < 1", &[]), 0.0);
1807    }
1808
1809    #[cfg(feature = "cond")]
1810    #[test]
1811    fn test_compare_le() {
1812        assert_eq!(eval("1 <= 2", &[]), 1.0);
1813        assert_eq!(eval("2 <= 1", &[]), 0.0);
1814        assert_eq!(eval("1 <= 1", &[]), 1.0);
1815    }
1816
1817    #[cfg(feature = "cond")]
1818    #[test]
1819    fn test_compare_gt() {
1820        assert_eq!(eval("2 > 1", &[]), 1.0);
1821        assert_eq!(eval("1 > 2", &[]), 0.0);
1822    }
1823
1824    #[cfg(feature = "cond")]
1825    #[test]
1826    fn test_compare_ge() {
1827        assert_eq!(eval("2 >= 1", &[]), 1.0);
1828        assert_eq!(eval("1 >= 1", &[]), 1.0);
1829    }
1830
1831    #[cfg(feature = "cond")]
1832    #[test]
1833    fn test_compare_eq() {
1834        assert_eq!(eval("1 == 1", &[]), 1.0);
1835        assert_eq!(eval("1 == 2", &[]), 0.0);
1836    }
1837
1838    #[cfg(feature = "cond")]
1839    #[test]
1840    fn test_compare_ne() {
1841        assert_eq!(eval("1 != 2", &[]), 1.0);
1842        assert_eq!(eval("1 != 1", &[]), 0.0);
1843    }
1844
1845    // Boolean logic tests (cond feature)
1846    #[cfg(feature = "cond")]
1847    #[test]
1848    fn test_and() {
1849        assert_eq!(eval("1 and 1", &[]), 1.0);
1850        assert_eq!(eval("1 and 0", &[]), 0.0);
1851        assert_eq!(eval("0 and 1", &[]), 0.0);
1852        assert_eq!(eval("0 and 0", &[]), 0.0);
1853    }
1854
1855    #[cfg(feature = "cond")]
1856    #[test]
1857    fn test_or() {
1858        assert_eq!(eval("1 or 1", &[]), 1.0);
1859        assert_eq!(eval("1 or 0", &[]), 1.0);
1860        assert_eq!(eval("0 or 1", &[]), 1.0);
1861        assert_eq!(eval("0 or 0", &[]), 0.0);
1862    }
1863
1864    #[cfg(feature = "cond")]
1865    #[test]
1866    fn test_not() {
1867        assert_eq!(eval("not 0", &[]), 1.0);
1868        assert_eq!(eval("not 1", &[]), 0.0);
1869        assert_eq!(eval("not 5", &[]), 0.0); // any non-zero is truthy
1870    }
1871
1872    // Conditional tests (cond feature)
1873    #[cfg(feature = "cond")]
1874    #[test]
1875    fn test_if_then_else() {
1876        assert_eq!(eval("if 1 then 10 else 20", &[]), 10.0);
1877        assert_eq!(eval("if 0 then 10 else 20", &[]), 20.0);
1878    }
1879
1880    #[cfg(feature = "cond")]
1881    #[test]
1882    fn test_if_with_comparison() {
1883        assert_eq!(eval("if x > 5 then 1 else 0", &[("x", 10.0)]), 1.0);
1884        assert_eq!(eval("if x > 5 then 1 else 0", &[("x", 3.0)]), 0.0);
1885    }
1886
1887    #[cfg(feature = "cond")]
1888    #[test]
1889    fn test_nested_if() {
1890        // if x > 0 then (if x > 10 then 2 else 1) else 0
1891        assert_eq!(
1892            eval(
1893                "if x > 0 then if x > 10 then 2 else 1 else 0",
1894                &[("x", 15.0)]
1895            ),
1896            2.0
1897        );
1898        assert_eq!(
1899            eval(
1900                "if x > 0 then if x > 10 then 2 else 1 else 0",
1901                &[("x", 5.0)]
1902            ),
1903            1.0
1904        );
1905        assert_eq!(
1906            eval(
1907                "if x > 0 then if x > 10 then 2 else 1 else 0",
1908                &[("x", -1.0)]
1909            ),
1910            0.0
1911        );
1912    }
1913
1914    #[cfg(feature = "cond")]
1915    #[test]
1916    fn test_compound_boolean() {
1917        assert_eq!(eval("x > 0 and x < 10", &[("x", 5.0)]), 1.0);
1918        assert_eq!(eval("x > 0 and x < 10", &[("x", 15.0)]), 0.0);
1919        assert_eq!(eval("x < 0 or x > 10", &[("x", 5.0)]), 0.0);
1920        assert_eq!(eval("x < 0 or x > 10", &[("x", 15.0)]), 1.0);
1921    }
1922
1923    #[cfg(feature = "cond")]
1924    #[test]
1925    fn test_precedence_compare_vs_arithmetic() {
1926        // 1 + 2 < 4 should be (1 + 2) < 4, not 1 + (2 < 4)
1927        assert_eq!(eval("1 + 2 < 4", &[]), 1.0);
1928        assert_eq!(eval("1 + 2 < 3", &[]), 0.0);
1929    }
1930
1931    #[cfg(feature = "cond")]
1932    #[test]
1933    fn test_precedence_and_vs_or() {
1934        // a or b and c should be a or (b and c)
1935        assert_eq!(eval("1 or 0 and 0", &[]), 1.0); // 1 or (0 and 0) = 1 or 0 = 1
1936        assert_eq!(eval("0 or 1 and 1", &[]), 1.0); // 0 or (1 and 1) = 0 or 1 = 1
1937        assert_eq!(eval("0 or 0 and 1", &[]), 0.0); // 0 or (0 and 1) = 0 or 0 = 0
1938    }
1939
1940    // Free variables tests (introspect feature)
1941    #[cfg(feature = "introspect")]
1942    #[test]
1943    fn test_free_vars_simple() {
1944        let expr = Expr::parse("x + y").unwrap();
1945        let vars = expr.free_vars();
1946        assert_eq!(vars.len(), 2);
1947        assert!(vars.contains("x"));
1948        assert!(vars.contains("y"));
1949    }
1950
1951    #[cfg(feature = "introspect")]
1952    #[test]
1953    fn test_free_vars_no_vars() {
1954        let expr = Expr::parse("1 + 2 * 3").unwrap();
1955        let vars = expr.free_vars();
1956        assert!(vars.is_empty());
1957    }
1958
1959    #[cfg(feature = "introspect")]
1960    #[test]
1961    fn test_free_vars_duplicates() {
1962        let expr = Expr::parse("x + x * x").unwrap();
1963        let vars = expr.free_vars();
1964        assert_eq!(vars.len(), 1);
1965        assert!(vars.contains("x"));
1966    }
1967
1968    #[cfg(all(feature = "introspect", feature = "func"))]
1969    #[test]
1970    fn test_free_vars_in_call() {
1971        let expr = Expr::parse("sin(x) + cos(y)").unwrap();
1972        let vars = expr.free_vars();
1973        assert_eq!(vars.len(), 2);
1974        assert!(vars.contains("x"));
1975        assert!(vars.contains("y"));
1976    }
1977
1978    #[cfg(all(feature = "introspect", feature = "cond"))]
1979    #[test]
1980    fn test_free_vars_in_conditional() {
1981        let expr = Expr::parse("if a > b then x else y").unwrap();
1982        let vars = expr.free_vars();
1983        assert_eq!(vars.len(), 4);
1984        assert!(vars.contains("a"));
1985        assert!(vars.contains("b"));
1986        assert!(vars.contains("x"));
1987        assert!(vars.contains("y"));
1988    }
1989
1990    // Let binding tests
1991    #[test]
1992    fn test_let_simple() {
1993        assert_eq!(eval("let a = 1; a", &[]), 1.0);
1994    }
1995
1996    #[test]
1997    fn test_let_with_computation() {
1998        assert_eq!(eval("let a = 1 + 2; a * 2", &[]), 6.0);
1999    }
2000
2001    #[test]
2002    fn test_let_chained() {
2003        assert_eq!(eval("let a = 1; let b = 2; a + b", &[]), 3.0);
2004    }
2005
2006    #[test]
2007    fn test_let_with_var() {
2008        assert_eq!(eval("let a = x * 2; a + 1", &[("x", 5.0)]), 11.0);
2009    }
2010
2011    #[test]
2012    fn test_let_shadow() {
2013        // Inner binding shadows outer variable
2014        assert_eq!(eval("let x = 10; x + 1", &[("x", 5.0)]), 11.0);
2015    }
2016
2017    #[test]
2018    fn test_let_uses_outer_in_value() {
2019        // The value expression uses the outer x, then binds to x
2020        assert_eq!(eval("let x = x + 1; x", &[("x", 5.0)]), 6.0);
2021    }
2022
2023    #[cfg(feature = "introspect")]
2024    #[test]
2025    fn test_free_vars_in_let() {
2026        // let a = x; a + y -> free vars are x and y (not a)
2027        let expr = Expr::parse("let a = x; a + y").unwrap();
2028        let vars = expr.free_vars();
2029        assert_eq!(vars.len(), 2);
2030        assert!(vars.contains("x"));
2031        assert!(vars.contains("y"));
2032        assert!(!vars.contains("a"));
2033    }
2034
2035    #[cfg(feature = "introspect")]
2036    #[test]
2037    fn test_free_vars_in_let_shadow() {
2038        // let x = x + 1; x -> free var is x (from the value expression)
2039        let expr = Expr::parse("let x = x + 1; x").unwrap();
2040        let vars = expr.free_vars();
2041        assert_eq!(vars.len(), 1);
2042        assert!(vars.contains("x"));
2043    }
2044
2045    // AST Display / roundtrip tests
2046    #[test]
2047    fn test_ast_display_simple() {
2048        let expr = Expr::parse("1 + 2").unwrap();
2049        let s = expr.ast().to_string();
2050        assert_eq!(s, "(1 + 2)");
2051    }
2052
2053    #[test]
2054    fn test_ast_display_nested() {
2055        let expr = Expr::parse("1 + 2 * 3").unwrap();
2056        let s = expr.ast().to_string();
2057        // Should be fully parenthesized
2058        assert_eq!(s, "(1 + (2 * 3))");
2059    }
2060
2061    #[test]
2062    fn test_ast_roundtrip() {
2063        let cases = [
2064            "1 + 2",
2065            "x * y",
2066            "1 + 2 * 3",
2067            "(1 + 2) * 3",
2068            "-x",
2069            "x ^ 2",
2070            "2 ^ 3 ^ 4", // right-associative
2071        ];
2072        for case in cases {
2073            let expr1 = Expr::parse(case).unwrap();
2074            let stringified = expr1.ast().to_string();
2075            let expr2 = Expr::parse(&stringified).unwrap();
2076            let stringified2 = expr2.ast().to_string();
2077            assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2078        }
2079    }
2080
2081    #[cfg(feature = "func")]
2082    #[test]
2083    fn test_ast_roundtrip_func() {
2084        let cases = ["sin(x)", "foo(a, b, c)", "f()"];
2085        for case in cases {
2086            let expr1 = Expr::parse(case).unwrap();
2087            let stringified = expr1.ast().to_string();
2088            let expr2 = Expr::parse(&stringified).unwrap();
2089            let stringified2 = expr2.ast().to_string();
2090            assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2091        }
2092    }
2093
2094    #[cfg(feature = "cond")]
2095    #[test]
2096    fn test_ast_roundtrip_cond() {
2097        let cases = [
2098            "x < y",
2099            "x and y",
2100            "x or y",
2101            "not x",
2102            "if x then y else z",
2103            "if a > b then x else y",
2104        ];
2105        for case in cases {
2106            let expr1 = Expr::parse(case).unwrap();
2107            let stringified = expr1.ast().to_string();
2108            let expr2 = Expr::parse(&stringified).unwrap();
2109            let stringified2 = expr2.ast().to_string();
2110            assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2111        }
2112    }
2113
2114    #[test]
2115    fn test_ast_roundtrip_let() {
2116        let cases = [
2117            "let a = 1; a",
2118            "let a = 1; let b = 2; a + b",
2119            "let x = y * 2; x + 1",
2120        ];
2121        for case in cases {
2122            let expr1 = Expr::parse(case).unwrap();
2123            let stringified = expr1.ast().to_string();
2124            let expr2 = Expr::parse(&stringified).unwrap();
2125            let stringified2 = expr2.ast().to_string();
2126            assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2127        }
2128    }
2129}
2130
2131// ============================================================================
2132// Property-based tests (proptest)
2133// ============================================================================
2134
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy for generating valid expression strings.
    ///
    /// Each value is one of: a bare number, a bare variable (`x`, `y`, `z`, `a`
    /// or `b`), or a single parenthesized binary operation (`+`, `-`, `*`, `/`)
    /// combining a number with a number or a variable. Numbers come from
    /// `f32::NORMAL` formatted with six decimals, so negative values appear with
    /// a leading `-`.
    fn expr_strategy() -> impl Strategy<Value = String> {
        // Generate simple arithmetic expressions
        let num = prop::num::f32::NORMAL.prop_map(|n| format!("{:.6}", n));
        let var = prop::sample::select(vec!["x", "y", "z", "a", "b"]).prop_map(String::from);

        // Operators
        let binop = prop::sample::select(vec!["+", "-", "*", "/"]);

        // Combine into expressions
        prop::strategy::Union::new(vec![
            num.clone().boxed(),
            var.clone().boxed(),
            (num.clone(), binop.clone(), num.clone())
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
            (var.clone(), binop.clone(), num.clone())
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
            (num, binop, var)
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
        ])
    }

    /// Bindings for every variable name `expr_strategy` can emit, so each
    /// generated variable is bound. `x` and `y` are supplied by the caller; the
    /// rest are fixed.
    fn sample_vars(x: f32, y: f32) -> HashMap<String, f32> {
        [
            ("x".into(), x),
            ("y".into(), y),
            ("z".into(), 1.0),
            ("a".into(), 2.0),
            ("b".into(), 3.0),
        ]
        .into()
    }

    proptest! {
        /// Parser should not panic on arbitrary input
        #[test]
        fn parse_never_panics(s in ".*") {
            // Just check that parsing doesn't panic, result can be Ok or Err
            let _ = Expr::parse(&s);
        }

        /// Valid expressions should parse successfully
        #[test]
        fn valid_expr_parses(expr in expr_strategy()) {
            let result = Expr::parse(&expr);
            prop_assert!(result.is_ok(), "Failed to parse: {}", expr);
        }

        /// Non-negative number literals parse to a bare `Ast::Num` close to the
        /// original value.
        #[test]
        fn number_roundtrip(n in prop::num::f64::NORMAL) {
            // `-a` is the unary negation operator, so a negative value formats to
            // text that is not a bare literal; the previous `if let Ast::Num`
            // silently skipped those cases. Use the magnitude so every generated
            // case exercises the literal path.
            let n = n.abs();
            let expr_str = format!("{:.6}", n);
            let result = Expr::parse(&expr_str);
            prop_assert!(result.is_ok(), "Failed to parse: {}", expr_str);
            let expr = result.unwrap();
            match expr.ast() {
                Ast::Num(parsed) => {
                    // The parsed number should be close to the original
                    let diff = (parsed - n).abs();
                    prop_assert!(diff < 0.001 || diff / n < 0.001,
                        "Number mismatch: {} vs {}", n, parsed);
                }
                _ => {
                    prop_assert!(false, "Expected a number literal for: {}", expr_str);
                }
            }
        }

        /// Evaluation with valid variables shouldn't panic
        #[test]
        #[cfg(not(feature = "func"))]
        fn eval_with_vars_no_panic(
            x in prop::num::f32::NORMAL,
            y in prop::num::f32::NORMAL,
            expr in expr_strategy()
        ) {
            if let Ok(parsed) = Expr::parse(&expr) {
                let vars = sample_vars(x, y);
                // Just check it doesn't panic, result can be anything (including NaN/Inf)
                let _ = parsed.eval(&vars);
            }
        }

        /// Evaluation with valid variables shouldn't panic (func feature)
        #[test]
        #[cfg(feature = "func")]
        fn eval_with_vars_no_panic_func(
            x in prop::num::f32::NORMAL,
            y in prop::num::f32::NORMAL,
            expr in expr_strategy()
        ) {
            if let Ok(parsed) = Expr::parse(&expr) {
                let vars = sample_vars(x, y);
                let registry = FunctionRegistry::new();
                let _ = parsed.eval(&vars, &registry);
            }
        }

        /// Negation is its own inverse
        #[test]
        #[cfg(not(feature = "func"))]
        fn negation_inverse(n in prop::num::f32::NORMAL) {
            let expr = Expr::parse(&format!("--{:.6}", n)).unwrap();
            let vars = HashMap::new();
            let result = expr.eval(&vars).unwrap();
            prop_assert!((result - n).abs() < 0.01,
                "Double negation failed: --{} = {}", n, result);
        }

        /// Negation is its own inverse (func feature)
        #[test]
        #[cfg(feature = "func")]
        fn negation_inverse_func(n in prop::num::f32::NORMAL) {
            let expr = Expr::parse(&format!("--{:.6}", n)).unwrap();
            let vars = HashMap::new();
            let registry = FunctionRegistry::new();
            let result = expr.eval(&vars, &registry).unwrap();
            prop_assert!((result - n).abs() < 0.01,
                "Double negation failed: --{} = {}", n, result);
        }
    }
}