Skip to main content

wick_scalar/
lib.rs

1//! Standard scalar function library for dew expressions.
2//!
3//! This crate provides the foundation for numeric expressions: standard math functions
4//! (sin, cos, sqrt, etc.), constants (pi, e, tau), and evaluation for scalar values.
5//! All functions are generic over `T: Float`, supporting both `f32` and `f64`.
6//!
7//! # Quick Start
8//!
9//! ```
10//! use wick_core::Expr;
11//! use wick_scalar::{eval, scalar_registry};
12//! use std::collections::HashMap;
13//!
14//! // Parse and evaluate an expression
15//! let expr = Expr::parse("sin(x * pi()) + 1").unwrap();
16//! let vars: HashMap<String, f32> = [("x".into(), 0.5)].into();
17//! let result = eval(expr.ast(), &vars, &scalar_registry()).unwrap();
18//! assert!((result - 2.0).abs() < 0.001); // sin(0.5 * π) + 1 = 2
19//! ```
20//!
21//! # Features
22//!
23//! | Feature       | Description                                |
24//! |---------------|--------------------------------------------|
25//! | `wgsl`        | WGSL shader code generation                |
26//! | `lua-codegen` | Lua code generation (pure Rust, WASM-safe) |
27//! | `lua`         | Lua codegen + mlua execution               |
28//! | `cranelift`   | Cranelift JIT compilation                  |
29//!
30//! # Available Functions
31//!
32//! ## Constants
33//!
34//! | Function | Description              |
35//! |----------|--------------------------|
36//! | `pi()`   | π ≈ 3.14159              |
37//! | `e()`    | Euler's number ≈ 2.71828 |
38//! | `tau()`  | τ = 2π ≈ 6.28318         |
39//!
40//! ## Trigonometric
41//!
42//! | Function       | Description                    |
43//! |----------------|--------------------------------|
44//! | `sin(x)`       | Sine                           |
45//! | `cos(x)`       | Cosine                         |
46//! | `tan(x)`       | Tangent                        |
47//! | `asin(x)`      | Arcsine                        |
48//! | `acos(x)`      | Arccosine                      |
49//! | `atan(x)`      | Arctangent                     |
50//! | `atan2(y, x)`  | Two-argument arctangent        |
51//! | `sinh(x)`      | Hyperbolic sine                |
52//! | `cosh(x)`      | Hyperbolic cosine              |
53//! | `tanh(x)`      | Hyperbolic tangent             |
54//!
55//! ## Exponential & Logarithmic
56//!
57//! | Function        | Description                  |
58//! |-----------------|------------------------------|
59//! | `exp(x)`        | e^x                          |
60//! | `exp2(x)`       | 2^x                          |
61//! | `log(x)`        | Natural logarithm (alias ln) |
62//! | `ln(x)`         | Natural logarithm            |
63//! | `log2(x)`       | Base-2 logarithm             |
64//! | `log10(x)`      | Base-10 logarithm            |
65//! | `pow(x, y)`     | x^y                          |
66//! | `sqrt(x)`       | Square root                  |
67//! | `inversesqrt(x)`| 1 / sqrt(x)                  |
68//!
69//! ## Common Math
70//!
71//! | Function         | Description                    |
72//! |------------------|--------------------------------|
73//! | `abs(x)`         | Absolute value                 |
74//! | `sign(x)`        | Sign (-1, 0, or 1)             |
75//! | `floor(x)`       | Round down                     |
76//! | `ceil(x)`        | Round up                       |
77//! | `round(x)`       | Round to nearest               |
78//! | `trunc(x)`       | Truncate toward zero           |
79//! | `fract(x)`       | Fractional part                |
80//! | `min(a, b)`      | Minimum of two values          |
81//! | `max(a, b)`      | Maximum of two values          |
82//! | `clamp(x, lo, hi)`| Clamp to range                |
83//! | `saturate(x)`    | Clamp to [0, 1]                |
84//!
85//! ## Interpolation
86//!
87//! | Function                      | Description                           |
88//! |-------------------------------|---------------------------------------|
89//! | `lerp(a, b, t)`               | Linear interpolation: a + (b-a)*t     |
90//! | `mix(a, b, t)`                | Alias for lerp (GLSL naming)          |
91//! | `step(edge, x)`               | 0 if x < edge, else 1                 |
92//! | `smoothstep(e0, e1, x)`       | Smooth Hermite interpolation          |
93//! | `inverse_lerp(a, b, v)`       | Inverse of lerp: (v-a) / (b-a)        |
94//! | `remap(x, i0, i1, o0, o1)`    | Remap from [i0,i1] to [o0,o1]         |
95//!
96//! # Custom Functions
97//!
98//! You can register custom functions by implementing the [`ScalarFn`] trait:
99//!
100//! ```
101//! use wick_scalar::{ScalarFn, FunctionRegistry, scalar_registry};
102//!
103//! struct Double;
104//! impl ScalarFn<f32> for Double {
105//!     fn name(&self) -> &str { "double" }
106//!     fn arg_count(&self) -> usize { 1 }
107//!     fn call(&self, args: &[f32]) -> f32 { args[0] * 2.0 }
108//! }
109//!
110//! let mut registry = scalar_registry();
111//! registry.register(Double);
112//! ```
113//!
114//! # Using f64
115//!
116//! All functions work with `f64` by specifying the type parameter:
117//!
118//! ```
119//! use wick_core::Expr;
120//! use wick_scalar::{eval, scalar_registry, FunctionRegistry};
121//! use std::collections::HashMap;
122//!
123//! let registry: FunctionRegistry<f64> = scalar_registry();
124//! let expr = Expr::parse("sqrt(2)").unwrap();
125//! let result: f64 = eval(expr.ast(), &HashMap::new(), &registry).unwrap();
126//! assert!((result - std::f64::consts::SQRT_2).abs() < 1e-10);
127//! ```
128
129use num_traits::{Float, NumCast, PrimInt};
130use wick_core::{Ast, BinOp, CompareOp, UnaryOp};
131
// Standard-library imports, followed by the re-export of `Numeric` from wick_core
133use std::collections::HashMap;
134use std::sync::Arc;
135pub use wick_core::Numeric;
136
137#[cfg(feature = "wgsl")]
138pub mod wgsl;
139
140#[cfg(feature = "glsl")]
141pub mod glsl;
142
143#[cfg(feature = "rust")]
144pub mod rust;
145
146#[cfg(feature = "c")]
147pub mod c;
148
149#[cfg(feature = "opencl")]
150pub mod opencl;
151
152#[cfg(feature = "cuda")]
153pub mod cuda;
154
155#[cfg(feature = "hip")]
156pub mod hip;
157
158#[cfg(feature = "tokenstream")]
159pub mod tokenstream;
160
161#[cfg(any(feature = "lua", feature = "lua-codegen"))]
162pub mod lua;
163
164#[cfg(feature = "cranelift")]
165pub mod cranelift;
166
167#[cfg(feature = "optimize")]
168pub mod optimize;
169
170// ============================================================================
171// Errors
172// ============================================================================
173
/// Scalar evaluation error.
///
/// Returned by [`eval`] and [`eval_int`]. The `Display` implementation
/// renders each variant as a short human-readable message.
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// Unknown variable.
    ///
    /// Carries the name that was not found in the variable map.
    UnknownVariable(String),
    /// Unknown function.
    ///
    /// Carries the name that was not found in the function registry.
    UnknownFunction(String),
    /// Wrong number of arguments to function.
    WrongArgCount {
        /// Name of the function as written in the expression.
        func: String,
        /// Argument count reported by the function's `arg_count()`.
        expected: usize,
        /// Number of arguments actually supplied at the call site.
        got: usize,
    },
    /// Operation not supported for this numeric type.
    ///
    /// Carries the operator symbol (for example `"&"` or `"<<"`).
    UnsupportedOperation(String),
    /// Literal cannot be converted to target type (e.g., 3.14 to i32).
    InvalidLiteral(f64),
    /// Negative exponent not allowed for integer types.
    NegativeExponent,
}
194
195impl std::fmt::Display for Error {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        match self {
198            Error::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
199            Error::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
200            Error::WrongArgCount {
201                func,
202                expected,
203                got,
204            } => {
205                write!(f, "function '{func}' expects {expected} args, got {got}")
206            }
207            Error::UnsupportedOperation(op) => {
208                write!(f, "operation '{op}' not supported for this numeric type")
209            }
210            Error::InvalidLiteral(n) => {
211                write!(f, "literal {n} cannot be converted to integer type")
212            }
213            Error::NegativeExponent => {
214                write!(f, "negative exponent not allowed for integer types")
215            }
216        }
217    }
218}
219
// Marker impl: relies entirely on the trait's default methods (no `source`
// override), using the `Debug` and `Display` impls defined above.
impl std::error::Error for Error {}
221
222// ============================================================================
223// Function Registry
224// ============================================================================
225
/// A scalar function that can be called from expressions.
///
/// Implementors are stored as trait objects inside [`FunctionRegistry`],
/// hence the `Send + Sync` supertraits.
pub trait ScalarFn<T>: Send + Sync {
    /// Function name.
    ///
    /// This is the key under which [`FunctionRegistry::register`] stores the
    /// function, and therefore the name used to call it from an expression.
    fn name(&self) -> &str;

    /// Number of arguments.
    ///
    /// [`eval`] and [`eval_int`] compare this against the number of
    /// arguments at the call site and report [`Error::WrongArgCount`] on a
    /// mismatch before `call` is invoked.
    fn arg_count(&self) -> usize;

    /// Call the function with arguments.
    ///
    /// When invoked through [`eval`] / [`eval_int`], `args.len()` equals
    /// `arg_count()`. Direct callers must uphold this themselves; the
    /// built-in functions index `args` directly and panic on a short slice.
    fn call(&self, args: &[T]) -> T;
}
237
/// Registry of scalar functions.
///
/// Maps a function's name to a shared, reference-counted trait object.
#[derive(Clone)]
pub struct FunctionRegistry<T> {
    // Keyed by the value `ScalarFn::name` returned at registration time.
    funcs: HashMap<String, Arc<dyn ScalarFn<T>>>,
}
243
244impl<T> Default for FunctionRegistry<T> {
245    fn default() -> Self {
246        Self {
247            funcs: HashMap::new(),
248        }
249    }
250}
251
252impl<T> FunctionRegistry<T> {
253    pub fn new() -> Self {
254        Self::default()
255    }
256
257    pub fn register<F: ScalarFn<T> + 'static>(&mut self, func: F) {
258        self.funcs.insert(func.name().to_string(), Arc::new(func));
259    }
260
261    pub fn get(&self, name: &str) -> Option<&Arc<dyn ScalarFn<T>>> {
262        self.funcs.get(name)
263    }
264
265    /// Returns an iterator over all function names.
266    pub fn names(&self) -> impl Iterator<Item = &str> {
267        self.funcs.keys().map(|s| s.as_str())
268    }
269}
270
271// ============================================================================
272// Evaluation
273// ============================================================================
274
275/// Evaluate an AST with scalar float values.
276///
277/// For integer types, use [`eval_int`].
278pub fn eval<T: Float>(
279    ast: &Ast,
280    vars: &HashMap<String, T>,
281    funcs: &FunctionRegistry<T>,
282) -> Result<T, Error> {
283    match ast {
284        Ast::Num(n) => Ok(T::from(*n).unwrap()),
285
286        Ast::Var(name) => vars
287            .get(name)
288            .copied()
289            .ok_or_else(|| Error::UnknownVariable(name.clone())),
290
291        Ast::BinOp(op, left, right) => {
292            let l = eval(left, vars, funcs)?;
293            let r = eval(right, vars, funcs)?;
294            match op {
295                BinOp::Add => Ok(l + r),
296                BinOp::Sub => Ok(l - r),
297                BinOp::Mul => Ok(l * r),
298                BinOp::Div => Ok(l / r),
299                BinOp::Pow => Ok(l.powf(r)),
300                BinOp::Rem => Ok(l % r),
301                BinOp::BitAnd => Err(Error::UnsupportedOperation("&".into())),
302                BinOp::BitOr => Err(Error::UnsupportedOperation("|".into())),
303                BinOp::Shl => Err(Error::UnsupportedOperation("<<".into())),
304                BinOp::Shr => Err(Error::UnsupportedOperation(">>".into())),
305            }
306        }
307
308        Ast::UnaryOp(op, inner) => {
309            let v = eval(inner, vars, funcs)?;
310            match op {
311                UnaryOp::Neg => Ok(-v),
312                UnaryOp::BitNot => Err(Error::UnsupportedOperation("~".into())),
313                UnaryOp::Not => {
314                    if v == T::zero() {
315                        Ok(T::one())
316                    } else {
317                        Ok(T::zero())
318                    }
319                }
320            }
321        }
322
323        Ast::Compare(op, left, right) => {
324            let l = eval(left, vars, funcs)?;
325            let r = eval(right, vars, funcs)?;
326            let result = match op {
327                CompareOp::Lt => l < r,
328                CompareOp::Le => l <= r,
329                CompareOp::Gt => l > r,
330                CompareOp::Ge => l >= r,
331                CompareOp::Eq => l == r,
332                CompareOp::Ne => l != r,
333            };
334            Ok(if result { T::one() } else { T::zero() })
335        }
336
337        Ast::And(left, right) => {
338            let l = eval(left, vars, funcs)?;
339            if l == T::zero() {
340                Ok(T::zero()) // Short-circuit
341            } else {
342                let r = eval(right, vars, funcs)?;
343                Ok(if r != T::zero() { T::one() } else { T::zero() })
344            }
345        }
346
347        Ast::Or(left, right) => {
348            let l = eval(left, vars, funcs)?;
349            if l != T::zero() {
350                Ok(T::one()) // Short-circuit
351            } else {
352                let r = eval(right, vars, funcs)?;
353                Ok(if r != T::zero() { T::one() } else { T::zero() })
354            }
355        }
356
357        Ast::If(cond, then_expr, else_expr) => {
358            let c = eval(cond, vars, funcs)?;
359            if c != T::zero() {
360                eval(then_expr, vars, funcs)
361            } else {
362                eval(else_expr, vars, funcs)
363            }
364        }
365
366        Ast::Call(name, args) => {
367            let func = funcs
368                .get(name)
369                .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
370
371            if args.len() != func.arg_count() {
372                return Err(Error::WrongArgCount {
373                    func: name.clone(),
374                    expected: func.arg_count(),
375                    got: args.len(),
376                });
377            }
378
379            let arg_vals: Vec<T> = args
380                .iter()
381                .map(|a| eval(a, vars, funcs))
382                .collect::<Result<_, _>>()?;
383
384            Ok(func.call(&arg_vals))
385        }
386
387        Ast::Let { name, value, body } => {
388            let val = eval(value, vars, funcs)?;
389            let mut new_vars = vars.clone();
390            new_vars.insert(name.clone(), val);
391            eval(body, &new_vars, funcs)
392        }
393    }
394}
395
396/// Evaluate an AST with integer values.
397///
398/// Supports bitwise operations and errors on:
399/// - Fractional literals (e.g., 3.14)
400/// - Negative exponents
401/// - Float-only functions (sin, cos, etc.)
402pub fn eval_int<T: PrimInt + NumCast>(
403    ast: &Ast,
404    vars: &HashMap<String, T>,
405    funcs: &FunctionRegistry<T>,
406) -> Result<T, Error> {
407    match ast {
408        Ast::Num(n) => {
409            // Check if the literal is a whole number
410            if n.fract() != 0.0 {
411                return Err(Error::InvalidLiteral(*n));
412            }
413            T::from(*n).ok_or(Error::InvalidLiteral(*n))
414        }
415
416        Ast::Var(name) => vars
417            .get(name)
418            .copied()
419            .ok_or_else(|| Error::UnknownVariable(name.clone())),
420
421        Ast::BinOp(op, left, right) => {
422            let l = eval_int(left, vars, funcs)?;
423            let r = eval_int(right, vars, funcs)?;
424            match op {
425                BinOp::Add => Ok(l + r),
426                BinOp::Sub => Ok(l - r),
427                BinOp::Mul => Ok(l * r),
428                BinOp::Div => Ok(l / r),
429                BinOp::Rem => Ok(l % r),
430                BinOp::Pow => {
431                    // Check for negative exponent
432                    if r < T::zero() {
433                        return Err(Error::NegativeExponent);
434                    }
435                    // Integer power via repeated multiplication
436                    let mut result = T::one();
437                    let mut exp = r;
438                    let mut base = l;
439                    while exp > T::zero() {
440                        if exp & T::one() == T::one() {
441                            result = result * base;
442                        }
443                        base = base * base;
444                        exp = exp >> 1;
445                    }
446                    Ok(result)
447                }
448                BinOp::BitAnd => Ok(l & r),
449                BinOp::BitOr => Ok(l | r),
450                BinOp::Shl => {
451                    // Convert shift amount to usize
452                    let shift: u32 = r.to_u32().unwrap_or(0);
453                    Ok(l << shift as usize)
454                }
455                BinOp::Shr => {
456                    let shift: u32 = r.to_u32().unwrap_or(0);
457                    Ok(l >> shift as usize)
458                }
459            }
460        }
461
462        Ast::UnaryOp(op, inner) => {
463            let v = eval_int(inner, vars, funcs)?;
464            match op {
465                UnaryOp::Neg => Ok(T::zero() - v),
466                UnaryOp::BitNot => Ok(!v),
467                UnaryOp::Not => {
468                    if v == T::zero() {
469                        Ok(T::one())
470                    } else {
471                        Ok(T::zero())
472                    }
473                }
474            }
475        }
476
477        Ast::Compare(op, left, right) => {
478            let l = eval_int(left, vars, funcs)?;
479            let r = eval_int(right, vars, funcs)?;
480            let result = match op {
481                CompareOp::Lt => l < r,
482                CompareOp::Le => l <= r,
483                CompareOp::Gt => l > r,
484                CompareOp::Ge => l >= r,
485                CompareOp::Eq => l == r,
486                CompareOp::Ne => l != r,
487            };
488            Ok(if result { T::one() } else { T::zero() })
489        }
490
491        Ast::And(left, right) => {
492            let l = eval_int(left, vars, funcs)?;
493            if l == T::zero() {
494                Ok(T::zero())
495            } else {
496                let r = eval_int(right, vars, funcs)?;
497                Ok(if r != T::zero() { T::one() } else { T::zero() })
498            }
499        }
500
501        Ast::Or(left, right) => {
502            let l = eval_int(left, vars, funcs)?;
503            if l != T::zero() {
504                Ok(T::one())
505            } else {
506                let r = eval_int(right, vars, funcs)?;
507                Ok(if r != T::zero() { T::one() } else { T::zero() })
508            }
509        }
510
511        Ast::If(cond, then_expr, else_expr) => {
512            let c = eval_int(cond, vars, funcs)?;
513            if c != T::zero() {
514                eval_int(then_expr, vars, funcs)
515            } else {
516                eval_int(else_expr, vars, funcs)
517            }
518        }
519
520        Ast::Call(name, args) => {
521            let func = funcs
522                .get(name)
523                .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
524
525            if args.len() != func.arg_count() {
526                return Err(Error::WrongArgCount {
527                    func: name.clone(),
528                    expected: func.arg_count(),
529                    got: args.len(),
530                });
531            }
532
533            let arg_vals: Vec<T> = args
534                .iter()
535                .map(|a| eval_int(a, vars, funcs))
536                .collect::<Result<_, _>>()?;
537
538            Ok(func.call(&arg_vals))
539        }
540
541        Ast::Let { name, value, body } => {
542            let val = eval_int(value, vars, funcs)?;
543            let mut new_vars = vars.clone();
544            new_vars.insert(name.clone(), val);
545            eval_int(body, &new_vars, funcs)
546        }
547    }
548}
549
550// ============================================================================
551// Standard Functions - Constants
552// ============================================================================
553
554/// Pi constant: pi() = 3.14159...
555pub struct Pi;
556impl<T: Float> ScalarFn<T> for Pi {
557    fn name(&self) -> &str {
558        "pi"
559    }
560    fn arg_count(&self) -> usize {
561        0
562    }
563    fn call(&self, _args: &[T]) -> T {
564        T::from(std::f64::consts::PI).unwrap()
565    }
566}
567
568/// Euler's number: e() = 2.71828...
569pub struct E;
570impl<T: Float> ScalarFn<T> for E {
571    fn name(&self) -> &str {
572        "e"
573    }
574    fn arg_count(&self) -> usize {
575        0
576    }
577    fn call(&self, _args: &[T]) -> T {
578        T::from(std::f64::consts::E).unwrap()
579    }
580}
581
582/// Tau constant: tau() = 2*pi = 6.28318...
583pub struct Tau;
584impl<T: Float> ScalarFn<T> for Tau {
585    fn name(&self) -> &str {
586        "tau"
587    }
588    fn arg_count(&self) -> usize {
589        0
590    }
591    fn call(&self, _args: &[T]) -> T {
592        T::from(std::f64::consts::TAU).unwrap()
593    }
594}
595
596// ============================================================================
597// Standard Functions - Trigonometric
598// ============================================================================
599
600pub struct Sin;
601impl<T: Float> ScalarFn<T> for Sin {
602    fn name(&self) -> &str {
603        "sin"
604    }
605    fn arg_count(&self) -> usize {
606        1
607    }
608    fn call(&self, args: &[T]) -> T {
609        args[0].sin()
610    }
611}
612
613pub struct Cos;
614impl<T: Float> ScalarFn<T> for Cos {
615    fn name(&self) -> &str {
616        "cos"
617    }
618    fn arg_count(&self) -> usize {
619        1
620    }
621    fn call(&self, args: &[T]) -> T {
622        args[0].cos()
623    }
624}
625
626pub struct Tan;
627impl<T: Float> ScalarFn<T> for Tan {
628    fn name(&self) -> &str {
629        "tan"
630    }
631    fn arg_count(&self) -> usize {
632        1
633    }
634    fn call(&self, args: &[T]) -> T {
635        args[0].tan()
636    }
637}
638
639pub struct Asin;
640impl<T: Float> ScalarFn<T> for Asin {
641    fn name(&self) -> &str {
642        "asin"
643    }
644    fn arg_count(&self) -> usize {
645        1
646    }
647    fn call(&self, args: &[T]) -> T {
648        args[0].asin()
649    }
650}
651
652pub struct Acos;
653impl<T: Float> ScalarFn<T> for Acos {
654    fn name(&self) -> &str {
655        "acos"
656    }
657    fn arg_count(&self) -> usize {
658        1
659    }
660    fn call(&self, args: &[T]) -> T {
661        args[0].acos()
662    }
663}
664
665pub struct Atan;
666impl<T: Float> ScalarFn<T> for Atan {
667    fn name(&self) -> &str {
668        "atan"
669    }
670    fn arg_count(&self) -> usize {
671        1
672    }
673    fn call(&self, args: &[T]) -> T {
674        args[0].atan()
675    }
676}
677
678pub struct Atan2;
679impl<T: Float> ScalarFn<T> for Atan2 {
680    fn name(&self) -> &str {
681        "atan2"
682    }
683    fn arg_count(&self) -> usize {
684        2
685    }
686    fn call(&self, args: &[T]) -> T {
687        args[0].atan2(args[1])
688    }
689}
690
691pub struct Sinh;
692impl<T: Float> ScalarFn<T> for Sinh {
693    fn name(&self) -> &str {
694        "sinh"
695    }
696    fn arg_count(&self) -> usize {
697        1
698    }
699    fn call(&self, args: &[T]) -> T {
700        args[0].sinh()
701    }
702}
703
704pub struct Cosh;
705impl<T: Float> ScalarFn<T> for Cosh {
706    fn name(&self) -> &str {
707        "cosh"
708    }
709    fn arg_count(&self) -> usize {
710        1
711    }
712    fn call(&self, args: &[T]) -> T {
713        args[0].cosh()
714    }
715}
716
717pub struct Tanh;
718impl<T: Float> ScalarFn<T> for Tanh {
719    fn name(&self) -> &str {
720        "tanh"
721    }
722    fn arg_count(&self) -> usize {
723        1
724    }
725    fn call(&self, args: &[T]) -> T {
726        args[0].tanh()
727    }
728}
729
730// ============================================================================
731// Standard Functions - Exponential / Logarithmic
732// ============================================================================
733
734pub struct Exp;
735impl<T: Float> ScalarFn<T> for Exp {
736    fn name(&self) -> &str {
737        "exp"
738    }
739    fn arg_count(&self) -> usize {
740        1
741    }
742    fn call(&self, args: &[T]) -> T {
743        args[0].exp()
744    }
745}
746
747pub struct Exp2;
748impl<T: Float> ScalarFn<T> for Exp2 {
749    fn name(&self) -> &str {
750        "exp2"
751    }
752    fn arg_count(&self) -> usize {
753        1
754    }
755    fn call(&self, args: &[T]) -> T {
756        args[0].exp2()
757    }
758}
759
760pub struct Log;
761impl<T: Float> ScalarFn<T> for Log {
762    fn name(&self) -> &str {
763        "log"
764    }
765    fn arg_count(&self) -> usize {
766        1
767    }
768    fn call(&self, args: &[T]) -> T {
769        args[0].ln()
770    }
771}
772
773pub struct Ln;
774impl<T: Float> ScalarFn<T> for Ln {
775    fn name(&self) -> &str {
776        "ln"
777    }
778    fn arg_count(&self) -> usize {
779        1
780    }
781    fn call(&self, args: &[T]) -> T {
782        args[0].ln()
783    }
784}
785
786pub struct Log2;
787impl<T: Float> ScalarFn<T> for Log2 {
788    fn name(&self) -> &str {
789        "log2"
790    }
791    fn arg_count(&self) -> usize {
792        1
793    }
794    fn call(&self, args: &[T]) -> T {
795        args[0].log2()
796    }
797}
798
799pub struct Log10;
800impl<T: Float> ScalarFn<T> for Log10 {
801    fn name(&self) -> &str {
802        "log10"
803    }
804    fn arg_count(&self) -> usize {
805        1
806    }
807    fn call(&self, args: &[T]) -> T {
808        args[0].log10()
809    }
810}
811
812pub struct Pow;
813impl<T: Float> ScalarFn<T> for Pow {
814    fn name(&self) -> &str {
815        "pow"
816    }
817    fn arg_count(&self) -> usize {
818        2
819    }
820    fn call(&self, args: &[T]) -> T {
821        args[0].powf(args[1])
822    }
823}
824
825pub struct Sqrt;
826impl<T: Float> ScalarFn<T> for Sqrt {
827    fn name(&self) -> &str {
828        "sqrt"
829    }
830    fn arg_count(&self) -> usize {
831        1
832    }
833    fn call(&self, args: &[T]) -> T {
834        args[0].sqrt()
835    }
836}
837
838pub struct InverseSqrt;
839impl<T: Float> ScalarFn<T> for InverseSqrt {
840    fn name(&self) -> &str {
841        "inversesqrt"
842    }
843    fn arg_count(&self) -> usize {
844        1
845    }
846    fn call(&self, args: &[T]) -> T {
847        T::one() / args[0].sqrt()
848    }
849}
850
851// ============================================================================
852// Standard Functions - Common Math
853// ============================================================================
854
855pub struct Abs;
856impl<T: Float> ScalarFn<T> for Abs {
857    fn name(&self) -> &str {
858        "abs"
859    }
860    fn arg_count(&self) -> usize {
861        1
862    }
863    fn call(&self, args: &[T]) -> T {
864        args[0].abs()
865    }
866}
867
868pub struct Sign;
869impl<T: Float> ScalarFn<T> for Sign {
870    fn name(&self) -> &str {
871        "sign"
872    }
873    fn arg_count(&self) -> usize {
874        1
875    }
876    fn call(&self, args: &[T]) -> T {
877        let x = args[0];
878        if x > T::zero() {
879            T::one()
880        } else if x < T::zero() {
881            -T::one()
882        } else {
883            T::zero()
884        }
885    }
886}
887
888pub struct Floor;
889impl<T: Float> ScalarFn<T> for Floor {
890    fn name(&self) -> &str {
891        "floor"
892    }
893    fn arg_count(&self) -> usize {
894        1
895    }
896    fn call(&self, args: &[T]) -> T {
897        args[0].floor()
898    }
899}
900
901pub struct Ceil;
902impl<T: Float> ScalarFn<T> for Ceil {
903    fn name(&self) -> &str {
904        "ceil"
905    }
906    fn arg_count(&self) -> usize {
907        1
908    }
909    fn call(&self, args: &[T]) -> T {
910        args[0].ceil()
911    }
912}
913
914pub struct Round;
915impl<T: Float> ScalarFn<T> for Round {
916    fn name(&self) -> &str {
917        "round"
918    }
919    fn arg_count(&self) -> usize {
920        1
921    }
922    fn call(&self, args: &[T]) -> T {
923        args[0].round()
924    }
925}
926
927pub struct Trunc;
928impl<T: Float> ScalarFn<T> for Trunc {
929    fn name(&self) -> &str {
930        "trunc"
931    }
932    fn arg_count(&self) -> usize {
933        1
934    }
935    fn call(&self, args: &[T]) -> T {
936        args[0].trunc()
937    }
938}
939
940pub struct Fract;
941impl<T: Float> ScalarFn<T> for Fract {
942    fn name(&self) -> &str {
943        "fract"
944    }
945    fn arg_count(&self) -> usize {
946        1
947    }
948    fn call(&self, args: &[T]) -> T {
949        args[0].fract()
950    }
951}
952
953pub struct Min;
954impl<T: Float> ScalarFn<T> for Min {
955    fn name(&self) -> &str {
956        "min"
957    }
958    fn arg_count(&self) -> usize {
959        2
960    }
961    fn call(&self, args: &[T]) -> T {
962        args[0].min(args[1])
963    }
964}
965
966pub struct Max;
967impl<T: Float> ScalarFn<T> for Max {
968    fn name(&self) -> &str {
969        "max"
970    }
971    fn arg_count(&self) -> usize {
972        2
973    }
974    fn call(&self, args: &[T]) -> T {
975        args[0].max(args[1])
976    }
977}
978
979pub struct Clamp;
980impl<T: Float> ScalarFn<T> for Clamp {
981    fn name(&self) -> &str {
982        "clamp"
983    }
984    fn arg_count(&self) -> usize {
985        3
986    }
987    fn call(&self, args: &[T]) -> T {
988        args[0].max(args[1]).min(args[2])
989    }
990}
991
992pub struct Saturate;
993impl<T: Float> ScalarFn<T> for Saturate {
994    fn name(&self) -> &str {
995        "saturate"
996    }
997    fn arg_count(&self) -> usize {
998        1
999    }
1000    fn call(&self, args: &[T]) -> T {
1001        args[0].max(T::zero()).min(T::one())
1002    }
1003}
1004
1005// ============================================================================
1006// Standard Functions - Interpolation
1007// ============================================================================
1008
1009/// Linear interpolation: lerp(a, b, t) = a + (b - a) * t
1010pub struct Lerp;
1011impl<T: Float> ScalarFn<T> for Lerp {
1012    fn name(&self) -> &str {
1013        "lerp"
1014    }
1015    fn arg_count(&self) -> usize {
1016        3
1017    }
1018    fn call(&self, args: &[T]) -> T {
1019        let (a, b, t) = (args[0], args[1], args[2]);
1020        a + (b - a) * t
1021    }
1022}
1023
1024/// Alias for lerp (GLSL naming)
1025pub struct Mix;
1026impl<T: Float> ScalarFn<T> for Mix {
1027    fn name(&self) -> &str {
1028        "mix"
1029    }
1030    fn arg_count(&self) -> usize {
1031        3
1032    }
1033    fn call(&self, args: &[T]) -> T {
1034        let (a, b, t) = (args[0], args[1], args[2]);
1035        a + (b - a) * t
1036    }
1037}
1038
1039/// Step function: step(edge, x) = x < edge ? 0.0 : 1.0
1040pub struct Step;
1041impl<T: Float> ScalarFn<T> for Step {
1042    fn name(&self) -> &str {
1043        "step"
1044    }
1045    fn arg_count(&self) -> usize {
1046        2
1047    }
1048    fn call(&self, args: &[T]) -> T {
1049        if args[1] < args[0] {
1050            T::zero()
1051        } else {
1052            T::one()
1053        }
1054    }
1055}
1056
1057/// Smooth Hermite interpolation
1058pub struct Smoothstep;
1059impl<T: Float> ScalarFn<T> for Smoothstep {
1060    fn name(&self) -> &str {
1061        "smoothstep"
1062    }
1063    fn arg_count(&self) -> usize {
1064        3
1065    }
1066    fn call(&self, args: &[T]) -> T {
1067        let (edge0, edge1, x) = (args[0], args[1], args[2]);
1068        let t = ((x - edge0) / (edge1 - edge0)).max(T::zero()).min(T::one());
1069        let three = T::from(3.0).unwrap();
1070        let two = T::from(2.0).unwrap();
1071        t * t * (three - two * t)
1072    }
1073}
1074
1075/// Inverse lerp: inverse_lerp(a, b, v) = (v - a) / (b - a)
1076pub struct InverseLerp;
1077impl<T: Float> ScalarFn<T> for InverseLerp {
1078    fn name(&self) -> &str {
1079        "inverse_lerp"
1080    }
1081    fn arg_count(&self) -> usize {
1082        3
1083    }
1084    fn call(&self, args: &[T]) -> T {
1085        let (a, b, v) = (args[0], args[1], args[2]);
1086        (v - a) / (b - a)
1087    }
1088}
1089
1090/// Remap: remap(x, in_lo, in_hi, out_lo, out_hi)
1091pub struct Remap;
1092impl<T: Float> ScalarFn<T> for Remap {
1093    fn name(&self) -> &str {
1094        "remap"
1095    }
1096    fn arg_count(&self) -> usize {
1097        5
1098    }
1099    fn call(&self, args: &[T]) -> T {
1100        let (x, in_lo, in_hi, out_lo, out_hi) = (args[0], args[1], args[2], args[3], args[4]);
1101        let t = (x - in_lo) / (in_hi - in_lo);
1102        out_lo + (out_hi - out_lo) * t
1103    }
1104}
1105
1106// ============================================================================
1107// Standard Functions - Integer-specific
1108// ============================================================================
1109
1110/// Bitwise XOR: xor(a, b)
1111pub struct Xor;
1112impl<T: PrimInt> ScalarFn<T> for Xor {
1113    fn name(&self) -> &str {
1114        "xor"
1115    }
1116    fn arg_count(&self) -> usize {
1117        2
1118    }
1119    fn call(&self, args: &[T]) -> T {
1120        args[0] ^ args[1]
1121    }
1122}
1123
1124/// Integer abs: abs(x) for integer types
1125pub struct AbsInt;
1126impl<T: PrimInt> ScalarFn<T> for AbsInt {
1127    fn name(&self) -> &str {
1128        "abs"
1129    }
1130    fn arg_count(&self) -> usize {
1131        1
1132    }
1133    fn call(&self, args: &[T]) -> T {
1134        let x = args[0];
1135        if x < T::zero() { T::zero() - x } else { x }
1136    }
1137}
1138
1139/// Integer min: min(a, b) for integer types
1140pub struct MinInt;
1141impl<T: PrimInt> ScalarFn<T> for MinInt {
1142    fn name(&self) -> &str {
1143        "min"
1144    }
1145    fn arg_count(&self) -> usize {
1146        2
1147    }
1148    fn call(&self, args: &[T]) -> T {
1149        if args[0] < args[1] { args[0] } else { args[1] }
1150    }
1151}
1152
1153/// Integer max: max(a, b) for integer types
1154pub struct MaxInt;
1155impl<T: PrimInt> ScalarFn<T> for MaxInt {
1156    fn name(&self) -> &str {
1157        "max"
1158    }
1159    fn arg_count(&self) -> usize {
1160        2
1161    }
1162    fn call(&self, args: &[T]) -> T {
1163        if args[0] > args[1] { args[0] } else { args[1] }
1164    }
1165}
1166
1167/// Integer clamp: clamp(x, lo, hi) for integer types
1168pub struct ClampInt;
1169impl<T: PrimInt> ScalarFn<T> for ClampInt {
1170    fn name(&self) -> &str {
1171        "clamp"
1172    }
1173    fn arg_count(&self) -> usize {
1174        3
1175    }
1176    fn call(&self, args: &[T]) -> T {
1177        let (x, lo, hi) = (args[0], args[1], args[2]);
1178        if x < lo {
1179            lo
1180        } else if x > hi {
1181            hi
1182        } else {
1183            x
1184        }
1185    }
1186}
1187
1188/// Integer sign: sign(x) for integer types
1189pub struct SignInt;
1190impl<T: PrimInt> ScalarFn<T> for SignInt {
1191    fn name(&self) -> &str {
1192        "sign"
1193    }
1194    fn arg_count(&self) -> usize {
1195        1
1196    }
1197    fn call(&self, args: &[T]) -> T {
1198        let x = args[0];
1199        if x > T::zero() {
1200            T::one()
1201        } else if x < T::zero() {
1202            T::zero() - T::one()
1203        } else {
1204            T::zero()
1205        }
1206    }
1207}
1208
1209// ============================================================================
1210// Registry
1211// ============================================================================
1212
1213/// Registers all standard scalar functions into the given registry.
1214pub fn register_scalar<T: Float + 'static>(registry: &mut FunctionRegistry<T>) {
1215    // Constants
1216    registry.register(Pi);
1217    registry.register(E);
1218    registry.register(Tau);
1219
1220    // Trigonometric
1221    registry.register(Sin);
1222    registry.register(Cos);
1223    registry.register(Tan);
1224    registry.register(Asin);
1225    registry.register(Acos);
1226    registry.register(Atan);
1227    registry.register(Atan2);
1228    registry.register(Sinh);
1229    registry.register(Cosh);
1230    registry.register(Tanh);
1231
1232    // Exponential / logarithmic
1233    registry.register(Exp);
1234    registry.register(Exp2);
1235    registry.register(Log);
1236    registry.register(Ln);
1237    registry.register(Log2);
1238    registry.register(Log10);
1239    registry.register(Pow);
1240    registry.register(Sqrt);
1241    registry.register(InverseSqrt);
1242
1243    // Common math
1244    registry.register(Abs);
1245    registry.register(Sign);
1246    registry.register(Floor);
1247    registry.register(Ceil);
1248    registry.register(Round);
1249    registry.register(Trunc);
1250    registry.register(Fract);
1251    registry.register(Min);
1252    registry.register(Max);
1253    registry.register(Clamp);
1254    registry.register(Saturate);
1255
1256    // Interpolation
1257    registry.register(Lerp);
1258    registry.register(Mix);
1259    registry.register(Step);
1260    registry.register(Smoothstep);
1261    registry.register(InverseLerp);
1262    registry.register(Remap);
1263}
1264
1265/// Creates a new registry with all standard scalar functions.
1266pub fn scalar_registry<T: Float + 'static>() -> FunctionRegistry<T> {
1267    let mut registry = FunctionRegistry::new();
1268    register_scalar(&mut registry);
1269    registry
1270}
1271
1272/// Registers standard integer functions into the given registry.
1273///
1274/// Includes: abs, min, max, clamp, sign, xor
1275pub fn register_scalar_int<T: PrimInt + 'static>(registry: &mut FunctionRegistry<T>) {
1276    registry.register(AbsInt);
1277    registry.register(MinInt);
1278    registry.register(MaxInt);
1279    registry.register(ClampInt);
1280    registry.register(SignInt);
1281    registry.register(Xor);
1282}
1283
1284/// Creates a new registry with standard integer functions.
1285pub fn scalar_registry_int<T: PrimInt + 'static>() -> FunctionRegistry<T> {
1286    let mut registry = FunctionRegistry::new();
1287    register_scalar_int(&mut registry);
1288    registry
1289}
1290
1291// ============================================================================
1292// Tests
1293// ============================================================================
1294
#[cfg(test)]
mod tests {
    use super::*;
    use wick_core::Expr;

    /// Tolerance for approximate `f32` comparisons.
    const TOLERANCE: f32 = 0.001;

    /// Parses `expr` and evaluates it as `f32` against the standard scalar
    /// registry, binding each `(name, value)` pair in `vars`.
    ///
    /// Panics if parsing or evaluation fails.
    fn eval_expr(expr: &str, vars: &[(&str, f32)]) -> f32 {
        let registry = scalar_registry();
        let parsed = Expr::parse(expr).unwrap();
        let mut bindings: HashMap<String, f32> = HashMap::new();
        for (name, value) in vars {
            bindings.insert(name.to_string(), *value);
        }
        eval(parsed.ast(), &bindings, &registry).unwrap()
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} to be within {} of {}",
            actual,
            TOLERANCE,
            expected
        );
    }

    /// Evaluates each variable-free expression and checks for exact equality
    /// with the paired value.
    fn assert_all_eq(cases: &[(&str, f32)]) {
        for &(source, expected) in cases {
            assert_eq!(eval_expr(source, &[]), expected, "expression: {}", source);
        }
    }

    #[test]
    fn test_constants() {
        let cases = [
            ("pi()", std::f32::consts::PI),
            ("e()", std::f32::consts::E),
            ("tau()", std::f32::consts::TAU),
        ];
        for &(source, expected) in cases.iter() {
            assert_close(eval_expr(source, &[]), expected);
        }
    }

    #[test]
    fn test_trig() {
        assert_close(eval_expr("sin(0)", &[]), 0.0);
        assert_close(eval_expr("cos(0)", &[]), 1.0);
    }

    #[test]
    fn test_exp_log() {
        assert_close(eval_expr("exp(0)", &[]), 1.0);
        assert_close(eval_expr("ln(1)", &[]), 0.0);
        assert_close(eval_expr("sqrt(16)", &[]), 4.0);
    }

    #[test]
    fn test_common() {
        assert_all_eq(&[
            ("abs(-5)", 5.0),
            ("floor(3.7)", 3.0),
            ("ceil(3.2)", 4.0),
            ("min(3, 7)", 3.0),
            ("max(3, 7)", 7.0),
            ("clamp(5, 0, 3)", 3.0),
            ("saturate(1.5)", 1.0),
        ]);
    }

    #[test]
    fn test_interpolation() {
        assert_all_eq(&[
            ("lerp(0, 10, 0.5)", 5.0),
            ("mix(0, 10, 0.5)", 5.0),
            ("step(0.5, 0.3)", 0.0),
            ("step(0.5, 0.7)", 1.0),
        ]);

        // Only a loose bound is checked for the smoothstep midpoint.
        let midpoint = eval_expr("smoothstep(0, 1, 0.5)", &[]);
        assert!((midpoint - 0.5).abs() < 0.1);

        assert_all_eq(&[("inverse_lerp(0, 10, 5)", 0.5)]);
    }

    #[test]
    fn test_remap() {
        assert_all_eq(&[("remap(5, 0, 10, 0, 100)", 50.0)]);
    }

    #[test]
    fn test_with_variables() {
        let value = eval_expr("sin(x * pi())", &[("x", 0.5)]);
        assert_close(value, 1.0);
    }

    #[test]
    fn test_f64() {
        let registry: FunctionRegistry<f64> = scalar_registry();
        let parsed = Expr::parse("sin(x) + 1").unwrap();
        let mut bindings: HashMap<String, f64> = HashMap::new();
        bindings.insert("x".to_string(), 0.0);
        let value = eval(parsed.ast(), &bindings, &registry).unwrap();
        assert!((value - 1.0).abs() < 0.001);
    }

    // Integer expression tests
    mod int_tests {
        use super::*;

        /// Parses `expr_str` and evaluates it as `i32` against the standard
        /// integer registry, binding each `(name, value)` pair in `vars`.
        ///
        /// Panics if parsing or evaluation fails.
        fn eval_int_expr(expr_str: &str, vars: &[(&str, i32)]) -> i32 {
            let registry = scalar_registry_int();
            let parsed = Expr::parse(expr_str).unwrap();
            let mut bindings: HashMap<String, i32> = HashMap::new();
            for (name, value) in vars {
                bindings.insert(name.to_string(), *value);
            }
            eval_int(parsed.ast(), &bindings, &registry).unwrap()
        }

        /// Evaluates each variable-free expression and checks for equality
        /// with the paired value.
        fn assert_int_cases(cases: &[(&str, i32)]) {
            for &(source, expected) in cases {
                assert_eq!(eval_int_expr(source, &[]), expected, "expression: {}", source);
            }
        }

        #[test]
        fn test_int_arithmetic() {
            assert_int_cases(&[
                ("5 + 3", 8),
                ("10 - 4", 6),
                ("6 * 7", 42),
                ("15 / 4", 3), // Integer division
            ]);
        }

        #[test]
        fn test_int_modulo() {
            assert_int_cases(&[("8 % 3", 2), ("10 % 5", 0), ("17 % 7", 3)]);
        }

        #[test]
        fn test_int_power() {
            assert_int_cases(&[("2 ^ 3", 8), ("3 ^ 4", 81), ("5 ^ 0", 1)]);
        }

        #[test]
        fn test_int_bitwise() {
            assert_int_cases(&[
                ("5 & 3", 1),     // 0101 & 0011 = 0001
                ("5 | 3", 7),     // 0101 | 0011 = 0111
                ("xor(5, 3)", 6), // 0101 ^ 0011 = 0110
                ("1 << 4", 16),
                ("16 >> 2", 4),
            ]);
        }

        #[test]
        fn test_int_bitnot() {
            // ~0 for i32 is -1
            assert_int_cases(&[("~0", -1)]);
        }

        #[test]
        fn test_int_functions() {
            assert_int_cases(&[
                ("abs(-5)", 5),
                ("min(3, 7)", 3),
                ("max(3, 7)", 7),
                ("clamp(5, 0, 3)", 3),
                ("sign(-10)", -1),
                ("sign(10)", 1),
                ("sign(0)", 0),
            ]);
        }

        #[test]
        fn test_int_with_variables() {
            let sum = eval_int_expr("x + y", &[("x", 5), ("y", 3)]);
            assert_eq!(sum, 8);

            let remainder = eval_int_expr("steps % beats", &[("steps", 8), ("beats", 3)]);
            assert_eq!(remainder, 2);
        }

        #[test]
        fn test_int_fractional_literal_error() {
            let registry: FunctionRegistry<i32> = scalar_registry_int();
            let parsed = Expr::parse("3.14 + 1").unwrap();
            let outcome = eval_int(parsed.ast(), &HashMap::new(), &registry);
            assert!(matches!(outcome, Err(Error::InvalidLiteral(_))));
        }

        #[test]
        fn test_int_negative_exponent_error() {
            let registry: FunctionRegistry<i32> = scalar_registry_int();
            let parsed = Expr::parse("2 ^ -1").unwrap();
            let no_vars: HashMap<String, i32> = HashMap::new();
            let outcome = eval_int(parsed.ast(), &no_vars, &registry);
            assert!(matches!(outcome, Err(Error::NegativeExponent)));
        }
    }
}