1use num_traits::Float;
77use std::collections::HashMap;
78use std::sync::Arc;
79use wick_core::{Ast, BinOp, CompareOp, UnaryOp};
80
81mod funcs;
82pub mod ops;
83#[cfg(test)]
84mod parity_tests;
85
86#[cfg(feature = "wgsl")]
87pub mod wgsl;
88
89#[cfg(feature = "glsl")]
90pub mod glsl;
91
92#[cfg(feature = "rust")]
93pub mod rust;
94
95#[cfg(feature = "c")]
96pub mod c;
97
98#[cfg(feature = "opencl")]
99pub mod opencl;
100
101#[cfg(feature = "cuda")]
102pub mod cuda;
103
104#[cfg(feature = "hip")]
105pub mod hip;
106
107#[cfg(feature = "tokenstream")]
108pub mod tokenstream;
109
110#[cfg(feature = "lua-codegen")]
111pub mod lua;
112
113#[cfg(feature = "cranelift")]
114pub mod cranelift;
115
116#[cfg(feature = "optimize")]
117pub mod optimize;
118
119pub use funcs::{
120 AxisAngle, Conj, Dot, Inverse, Length, Lerp, Normalize, QuatConstructor, Rotate, Slerp,
121 Vec3Constructor, quaternion_registry, register_quaternion,
122};
123
/// Static type tag of an evaluated value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single scalar.
    Scalar,
    /// A three-component vector.
    Vec3,
    /// A four-component quaternion.
    Quaternion,
}

impl std::fmt::Display for Type {
    /// Writes the lowercase type name: `scalar`, `vec3` or `quaternion`.
    ///
    /// Width/fill/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Type::Scalar => "scalar",
            Type::Vec3 => "vec3",
            Type::Quaternion => "quaternion",
        };
        f.write_str(label)
    }
}
148
149pub trait QuaternionValue<T: Float>: Clone + PartialEq + Sized + std::fmt::Debug {
157 fn typ(&self) -> Type;
159
160 fn from_scalar(v: T) -> Self;
162 fn from_vec3(v: [T; 3]) -> Self;
163 fn from_quaternion(q: [T; 4]) -> Self;
164
165 fn as_scalar(&self) -> Option<T>;
167 fn as_vec3(&self) -> Option<[T; 3]>;
168 fn as_quaternion(&self) -> Option<[T; 4]>;
169}
170
171#[derive(Debug, Clone, PartialEq)]
179pub enum Value<T> {
180 Scalar(T),
182 Vec3([T; 3]),
184 Quaternion([T; 4]),
186}
187
188impl<T> Value<T> {
189 pub fn typ(&self) -> Type {
191 match self {
192 Value::Scalar(_) => Type::Scalar,
193 Value::Vec3(_) => Type::Vec3,
194 Value::Quaternion(_) => Type::Quaternion,
195 }
196 }
197}
198
199impl<T: Copy> Value<T> {
200 pub fn as_scalar(&self) -> Option<T> {
202 match self {
203 Value::Scalar(v) => Some(*v),
204 _ => None,
205 }
206 }
207
208 pub fn as_vec3(&self) -> Option<[T; 3]> {
210 match self {
211 Value::Vec3(v) => Some(*v),
212 _ => None,
213 }
214 }
215
216 pub fn as_quaternion(&self) -> Option<[T; 4]> {
218 match self {
219 Value::Quaternion(q) => Some(*q),
220 _ => None,
221 }
222 }
223}
224
225impl<T: Float + std::fmt::Debug> QuaternionValue<T> for Value<T> {
226 fn typ(&self) -> Type {
227 Value::typ(self)
228 }
229
230 fn from_scalar(v: T) -> Self {
231 Value::Scalar(v)
232 }
233
234 fn from_vec3(v: [T; 3]) -> Self {
235 Value::Vec3(v)
236 }
237
238 fn from_quaternion(q: [T; 4]) -> Self {
239 Value::Quaternion(q)
240 }
241
242 fn as_scalar(&self) -> Option<T> {
243 Value::as_scalar(self)
244 }
245
246 fn as_vec3(&self) -> Option<[T; 3]> {
247 Value::as_vec3(self)
248 }
249
250 fn as_quaternion(&self) -> Option<[T; 4]> {
251 Value::as_quaternion(self)
252 }
253}
254
255#[derive(Debug, Clone, PartialEq)]
261pub enum Error {
262 UnknownVariable(String),
264 UnknownFunction(String),
266 BinaryTypeMismatch { op: BinOp, left: Type, right: Type },
268 UnaryTypeMismatch { op: UnaryOp, operand: Type },
270 WrongArgCount {
272 func: String,
273 expected: usize,
274 got: usize,
275 },
276 FunctionTypeMismatch {
278 func: String,
279 expected: Vec<Type>,
280 got: Vec<Type>,
281 },
282 UnsupportedTypeForConditional(Type),
284}
285
286impl std::fmt::Display for Error {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match self {
289 Error::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
290 Error::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
291 Error::BinaryTypeMismatch { op, left, right } => {
292 write!(f, "cannot apply {op:?} to {left} and {right}")
293 }
294 Error::UnaryTypeMismatch { op, operand } => {
295 write!(f, "cannot apply {op:?} to {operand}")
296 }
297 Error::WrongArgCount {
298 func,
299 expected,
300 got,
301 } => {
302 write!(f, "function '{func}' expects {expected} args, got {got}")
303 }
304 Error::FunctionTypeMismatch {
305 func,
306 expected,
307 got,
308 } => {
309 write!(
310 f,
311 "function '{func}' expects types {expected:?}, got {got:?}"
312 )
313 }
314 Error::UnsupportedTypeForConditional(t) => {
315 write!(f, "conditionals require scalar type, got {t}")
316 }
317 }
318 }
319}
320
321impl std::error::Error for Error {}
322
323#[derive(Debug, Clone, PartialEq)]
329pub struct Signature {
330 pub args: Vec<Type>,
331 pub ret: Type,
332}
333
334pub trait QuaternionFn<T, V>: Send + Sync
338where
339 T: Float,
340 V: QuaternionValue<T>,
341{
342 fn name(&self) -> &str;
344
345 fn signatures(&self) -> Vec<Signature>;
347
348 fn call(&self, args: &[V]) -> V;
350}
351
352#[derive(Clone)]
354pub struct FunctionRegistry<T, V>
355where
356 T: Float,
357 V: QuaternionValue<T>,
358{
359 funcs: HashMap<String, Arc<dyn QuaternionFn<T, V>>>,
360}
361
362impl<T, V> Default for FunctionRegistry<T, V>
363where
364 T: Float,
365 V: QuaternionValue<T>,
366{
367 fn default() -> Self {
368 Self {
369 funcs: HashMap::new(),
370 }
371 }
372}
373
374impl<T, V> FunctionRegistry<T, V>
375where
376 T: Float,
377 V: QuaternionValue<T>,
378{
379 pub fn new() -> Self {
380 Self::default()
381 }
382
383 pub fn register<F: QuaternionFn<T, V> + 'static>(&mut self, func: F) {
384 self.funcs.insert(func.name().to_string(), Arc::new(func));
385 }
386
387 pub fn get(&self, name: &str) -> Option<&Arc<dyn QuaternionFn<T, V>>> {
388 self.funcs.get(name)
389 }
390}
391
392pub fn eval<T, V>(
400 ast: &Ast,
401 vars: &HashMap<String, V>,
402 funcs: &FunctionRegistry<T, V>,
403) -> Result<V, Error>
404where
405 T: Float,
406 V: QuaternionValue<T>,
407{
408 match ast {
409 Ast::Num(n) => Ok(V::from_scalar(T::from(*n).unwrap())),
410
411 Ast::Var(name) => vars
412 .get(name)
413 .cloned()
414 .ok_or_else(|| Error::UnknownVariable(name.clone())),
415
416 Ast::BinOp(op, left, right) => {
417 let left_val = eval(left, vars, funcs)?;
418 let right_val = eval(right, vars, funcs)?;
419 ops::apply_binop(*op, left_val, right_val)
420 }
421
422 Ast::UnaryOp(op, inner) => {
423 let val = eval(inner, vars, funcs)?;
424 ops::apply_unaryop(*op, val)
425 }
426
427 Ast::Call(name, args) => {
428 let func = funcs
429 .get(name)
430 .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
431
432 let arg_vals: Vec<V> = args
433 .iter()
434 .map(|a| eval(a, vars, funcs))
435 .collect::<Result<_, _>>()?;
436
437 let arg_types: Vec<Type> = arg_vals.iter().map(|v| v.typ()).collect();
438
439 let matched = func.signatures().iter().any(|sig| sig.args == arg_types);
441 if !matched {
442 return Err(Error::FunctionTypeMismatch {
443 func: name.clone(),
444 expected: func
445 .signatures()
446 .first()
447 .map(|s| s.args.clone())
448 .unwrap_or_default(),
449 got: arg_types,
450 });
451 }
452
453 Ok(func.call(&arg_vals))
454 }
455
456 Ast::Compare(op, left, right) => {
457 let left_val = eval(left, vars, funcs)?;
458 let right_val = eval(right, vars, funcs)?;
459 match (left_val.as_scalar(), right_val.as_scalar()) {
460 (Some(l), Some(r)) => {
461 let result = match op {
462 CompareOp::Lt => l < r,
463 CompareOp::Le => l <= r,
464 CompareOp::Gt => l > r,
465 CompareOp::Ge => l >= r,
466 CompareOp::Eq => l == r,
467 CompareOp::Ne => l != r,
468 };
469 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
470 }
471 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
472 }
473 }
474
475 Ast::And(left, right) => {
476 let left_val = eval(left, vars, funcs)?;
477 let right_val = eval(right, vars, funcs)?;
478 match (left_val.as_scalar(), right_val.as_scalar()) {
479 (Some(l), Some(r)) => {
480 let result = !l.is_zero() && !r.is_zero();
481 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
482 }
483 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
484 }
485 }
486
487 Ast::Or(left, right) => {
488 let left_val = eval(left, vars, funcs)?;
489 let right_val = eval(right, vars, funcs)?;
490 match (left_val.as_scalar(), right_val.as_scalar()) {
491 (Some(l), Some(r)) => {
492 let result = !l.is_zero() || !r.is_zero();
493 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
494 }
495 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
496 }
497 }
498
499 Ast::If(cond, then_ast, else_ast) => {
500 let cond_val = eval(cond, vars, funcs)?;
501 if let Some(c) = cond_val.as_scalar() {
502 if !c.is_zero() {
503 eval(then_ast, vars, funcs)
504 } else {
505 eval(else_ast, vars, funcs)
506 }
507 } else {
508 Err(Error::UnsupportedTypeForConditional(cond_val.typ()))
509 }
510 }
511
512 Ast::Let { name, value, body } => {
513 let val = eval(value, vars, funcs)?;
514 let mut new_vars = vars.clone();
515 new_vars.insert(name.clone(), val);
516 eval(body, &new_vars, funcs)
517 }
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use wick_core::Expr;

    /// Parses `expr`, binds `vars`, and evaluates it with the quaternion registry.
    ///
    /// Panics if `expr` fails to parse.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Result<Value<f32>, Error> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = quaternion_registry();
        eval(expr.ast(), &var_map, &registry)
    }

    #[test]
    fn test_quaternion_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
                ("b", Value::Quaternion([5.0, 6.0, 7.0, 8.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([6.0, 8.0, 10.0, 12.0]));
    }

    #[test]
    fn test_quaternion_mul() {
        // `b` is treated as the identity quaternion here (this assumes [x, y, z, w]
        // component order), so `a * b` must equal `a`.
        let result = eval_expr(
            "a * b",
            &[
                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
                ("b", Value::Quaternion([0.0, 0.0, 0.0, 1.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([1.0, 2.0, 3.0, 4.0]));
    }

    #[test]
    fn test_quaternion_neg() {
        let result = eval_expr("-q", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        assert_eq!(result.unwrap(), Value::Quaternion([-1.0, -2.0, -3.0, -4.0]));
    }

    #[test]
    fn test_quaternion_scalar_mul() {
        let result = eval_expr(
            "s * q",
            &[
                ("s", Value::Scalar(2.0)),
                ("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([2.0, 4.0, 6.0, 8.0]));
    }

    #[test]
    fn test_unknown_variable() {
        let result = eval_expr("x", &[]);
        assert_eq!(result, Err(Error::UnknownVariable("x".to_string())));
    }
}