Skip to main content

wick_linalg/
funcs.rs

1//! Standard linalg functions: dot, cross, normalize, length, etc.
2
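// `mut` is only needed when a feature-gated (`3d`/`4d`) `sigs.push()` is compiled in;
// without those pushes the binding is never mutated, so `unused_mut` is allowed here.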
4#![allow(unused_mut)]
5
6use crate::{LinalgFn, LinalgValue, Signature, Type};
7use num_traits::Float;
8use wick_core::Numeric;
9
10// ============================================================================
11// Dot product
12// ============================================================================
13
14/// Dot product: dot(a, b) -> scalar
15pub struct Dot;
16
17impl<T, V> LinalgFn<T, V> for Dot
18where
19    T: Numeric,
20    V: LinalgValue<T>,
21{
22    fn name(&self) -> &str {
23        "dot"
24    }
25
26    fn signatures(&self) -> Vec<Signature> {
27        let mut sigs = vec![Signature {
28            args: vec![Type::Vec2, Type::Vec2],
29            ret: Type::Scalar,
30        }];
31        #[cfg(feature = "3d")]
32        sigs.push(Signature {
33            args: vec![Type::Vec3, Type::Vec3],
34            ret: Type::Scalar,
35        });
36        #[cfg(feature = "4d")]
37        sigs.push(Signature {
38            args: vec![Type::Vec4, Type::Vec4],
39            ret: Type::Scalar,
40        });
41        sigs
42    }
43
44    fn call(&self, args: &[V]) -> V {
45        match (args[0].typ(), args[1].typ()) {
46            (Type::Vec2, Type::Vec2) => {
47                let a = args[0].as_vec2().unwrap();
48                let b = args[1].as_vec2().unwrap();
49                V::from_scalar(a[0] * b[0] + a[1] * b[1])
50            }
51            #[cfg(feature = "3d")]
52            (Type::Vec3, Type::Vec3) => {
53                let a = args[0].as_vec3().unwrap();
54                let b = args[1].as_vec3().unwrap();
55                V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2])
56            }
57            #[cfg(feature = "4d")]
58            (Type::Vec4, Type::Vec4) => {
59                let a = args[0].as_vec4().unwrap();
60                let b = args[1].as_vec4().unwrap();
61                V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3])
62            }
63            _ => unreachable!("signature mismatch"),
64        }
65    }
66}
67
68// ============================================================================
69// Cross product (3D only)
70// ============================================================================
71
/// Cross product: cross(a, b) -> vec3
///
/// Only defined for three-component vectors, so this function is compiled in
/// only with the `3d` feature.
#[cfg(feature = "3d")]
pub struct Cross;

#[cfg(feature = "3d")]
impl<T, V> LinalgFn<T, V> for Cross
where
    T: Numeric,
    V: LinalgValue<T>,
{
    fn name(&self) -> &str {
        "cross"
    }

    fn signatures(&self) -> Vec<Signature> {
        let only = Signature {
            args: vec![Type::Vec3, Type::Vec3],
            ret: Type::Vec3,
        };
        vec![only]
    }

    /// # Panics
    ///
    /// Panics if `args` has fewer than two elements or if `as_vec3()` returns
    /// `None` for either of them.
    fn call(&self, args: &[V]) -> V {
        let u = args[0].as_vec3().unwrap();
        let w = args[1].as_vec3().unwrap();
        let (ux, uy, uz) = (u[0], u[1], u[2]);
        let (wx, wy, wz) = (w[0], w[1], w[2]);
        V::from_vec3([uy * wz - uz * wy, uz * wx - ux * wz, ux * wy - uy * wx])
    }
}
103
104// ============================================================================
105// Length
106// ============================================================================
107
108/// Vector length: length(v) -> scalar
109pub struct Length;
110
111impl<T, V> LinalgFn<T, V> for Length
112where
113    T: Float + Numeric,
114    V: LinalgValue<T>,
115{
116    fn name(&self) -> &str {
117        "length"
118    }
119
120    fn signatures(&self) -> Vec<Signature> {
121        let mut sigs = vec![Signature {
122            args: vec![Type::Vec2],
123            ret: Type::Scalar,
124        }];
125        #[cfg(feature = "3d")]
126        sigs.push(Signature {
127            args: vec![Type::Vec3],
128            ret: Type::Scalar,
129        });
130        #[cfg(feature = "4d")]
131        sigs.push(Signature {
132            args: vec![Type::Vec4],
133            ret: Type::Scalar,
134        });
135        sigs
136    }
137
138    fn call(&self, args: &[V]) -> V {
139        match args[0].typ() {
140            Type::Vec2 => {
141                let v = args[0].as_vec2().unwrap();
142                V::from_scalar((v[0] * v[0] + v[1] * v[1]).sqrt())
143            }
144            #[cfg(feature = "3d")]
145            Type::Vec3 => {
146                let v = args[0].as_vec3().unwrap();
147                V::from_scalar((v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt())
148            }
149            #[cfg(feature = "4d")]
150            Type::Vec4 => {
151                let v = args[0].as_vec4().unwrap();
152                V::from_scalar((v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3]).sqrt())
153            }
154            _ => unreachable!("signature mismatch"),
155        }
156    }
157}
158
159// ============================================================================
160// Normalize
161// ============================================================================
162
163/// Normalize vector: normalize(v) -> vec (same type, unit length)
164pub struct Normalize;
165
166impl<T, V> LinalgFn<T, V> for Normalize
167where
168    T: Float + Numeric,
169    V: LinalgValue<T>,
170{
171    fn name(&self) -> &str {
172        "normalize"
173    }
174
175    fn signatures(&self) -> Vec<Signature> {
176        let mut sigs = vec![Signature {
177            args: vec![Type::Vec2],
178            ret: Type::Vec2,
179        }];
180        #[cfg(feature = "3d")]
181        sigs.push(Signature {
182            args: vec![Type::Vec3],
183            ret: Type::Vec3,
184        });
185        #[cfg(feature = "4d")]
186        sigs.push(Signature {
187            args: vec![Type::Vec4],
188            ret: Type::Vec4,
189        });
190        sigs
191    }
192
193    fn call(&self, args: &[V]) -> V {
194        match args[0].typ() {
195            Type::Vec2 => {
196                let v = args[0].as_vec2().unwrap();
197                let len = (v[0] * v[0] + v[1] * v[1]).sqrt();
198                V::from_vec2([v[0] / len, v[1] / len])
199            }
200            #[cfg(feature = "3d")]
201            Type::Vec3 => {
202                let v = args[0].as_vec3().unwrap();
203                let len = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
204                V::from_vec3([v[0] / len, v[1] / len, v[2] / len])
205            }
206            #[cfg(feature = "4d")]
207            Type::Vec4 => {
208                let v = args[0].as_vec4().unwrap();
209                let len = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3]).sqrt();
210                V::from_vec4([v[0] / len, v[1] / len, v[2] / len, v[3] / len])
211            }
212            _ => unreachable!("signature mismatch"),
213        }
214    }
215}
216
217// ============================================================================
218// Distance
219// ============================================================================
220
221/// Distance between two points: distance(a, b) -> scalar
222pub struct Distance;
223
224impl<T, V> LinalgFn<T, V> for Distance
225where
226    T: Float + Numeric,
227    V: LinalgValue<T>,
228{
229    fn name(&self) -> &str {
230        "distance"
231    }
232
233    fn signatures(&self) -> Vec<Signature> {
234        let mut sigs = vec![Signature {
235            args: vec![Type::Vec2, Type::Vec2],
236            ret: Type::Scalar,
237        }];
238        #[cfg(feature = "3d")]
239        sigs.push(Signature {
240            args: vec![Type::Vec3, Type::Vec3],
241            ret: Type::Scalar,
242        });
243        #[cfg(feature = "4d")]
244        sigs.push(Signature {
245            args: vec![Type::Vec4, Type::Vec4],
246            ret: Type::Scalar,
247        });
248        sigs
249    }
250
251    fn call(&self, args: &[V]) -> V {
252        match (args[0].typ(), args[1].typ()) {
253            (Type::Vec2, Type::Vec2) => {
254                let a = args[0].as_vec2().unwrap();
255                let b = args[1].as_vec2().unwrap();
256                let dx = a[0] - b[0];
257                let dy = a[1] - b[1];
258                V::from_scalar((dx * dx + dy * dy).sqrt())
259            }
260            #[cfg(feature = "3d")]
261            (Type::Vec3, Type::Vec3) => {
262                let a = args[0].as_vec3().unwrap();
263                let b = args[1].as_vec3().unwrap();
264                let dx = a[0] - b[0];
265                let dy = a[1] - b[1];
266                let dz = a[2] - b[2];
267                V::from_scalar((dx * dx + dy * dy + dz * dz).sqrt())
268            }
269            #[cfg(feature = "4d")]
270            (Type::Vec4, Type::Vec4) => {
271                let a = args[0].as_vec4().unwrap();
272                let b = args[1].as_vec4().unwrap();
273                let dx = a[0] - b[0];
274                let dy = a[1] - b[1];
275                let dz = a[2] - b[2];
276                let dw = a[3] - b[3];
277                V::from_scalar((dx * dx + dy * dy + dz * dz + dw * dw).sqrt())
278            }
279            _ => unreachable!("signature mismatch"),
280        }
281    }
282}
283
284// ============================================================================
285// Reflect
286// ============================================================================
287
288/// Reflect vector: reflect(incident, normal) -> vec
289/// Returns incident - 2 * dot(normal, incident) * normal
290pub struct Reflect;
291
292impl<T, V> LinalgFn<T, V> for Reflect
293where
294    T: Float + Numeric,
295    V: LinalgValue<T>,
296{
297    fn name(&self) -> &str {
298        "reflect"
299    }
300
301    fn signatures(&self) -> Vec<Signature> {
302        let mut sigs = vec![Signature {
303            args: vec![Type::Vec2, Type::Vec2],
304            ret: Type::Vec2,
305        }];
306        #[cfg(feature = "3d")]
307        sigs.push(Signature {
308            args: vec![Type::Vec3, Type::Vec3],
309            ret: Type::Vec3,
310        });
311        #[cfg(feature = "4d")]
312        sigs.push(Signature {
313            args: vec![Type::Vec4, Type::Vec4],
314            ret: Type::Vec4,
315        });
316        sigs
317    }
318
319    fn call(&self, args: &[V]) -> V {
320        let two = T::from(2.0).unwrap();
321        match (args[0].typ(), args[1].typ()) {
322            (Type::Vec2, Type::Vec2) => {
323                let i = args[0].as_vec2().unwrap();
324                let n = args[1].as_vec2().unwrap();
325                let d = i[0] * n[0] + i[1] * n[1];
326                V::from_vec2([i[0] - two * d * n[0], i[1] - two * d * n[1]])
327            }
328            #[cfg(feature = "3d")]
329            (Type::Vec3, Type::Vec3) => {
330                let i = args[0].as_vec3().unwrap();
331                let n = args[1].as_vec3().unwrap();
332                let d = i[0] * n[0] + i[1] * n[1] + i[2] * n[2];
333                V::from_vec3([
334                    i[0] - two * d * n[0],
335                    i[1] - two * d * n[1],
336                    i[2] - two * d * n[2],
337                ])
338            }
339            #[cfg(feature = "4d")]
340            (Type::Vec4, Type::Vec4) => {
341                let i = args[0].as_vec4().unwrap();
342                let n = args[1].as_vec4().unwrap();
343                let d = i[0] * n[0] + i[1] * n[1] + i[2] * n[2] + i[3] * n[3];
344                V::from_vec4([
345                    i[0] - two * d * n[0],
346                    i[1] - two * d * n[1],
347                    i[2] - two * d * n[2],
348                    i[3] - two * d * n[3],
349                ])
350            }
351            _ => unreachable!("signature mismatch"),
352        }
353    }
354}
355
356// ============================================================================
357// Hadamard (element-wise multiply)
358// ============================================================================
359
360/// Element-wise vector multiply: hadamard(a, b) -> vec
361pub struct Hadamard;
362
363impl<T, V> LinalgFn<T, V> for Hadamard
364where
365    T: Numeric,
366    V: LinalgValue<T>,
367{
368    fn name(&self) -> &str {
369        "hadamard"
370    }
371
372    fn signatures(&self) -> Vec<Signature> {
373        let mut sigs = vec![Signature {
374            args: vec![Type::Vec2, Type::Vec2],
375            ret: Type::Vec2,
376        }];
377        #[cfg(feature = "3d")]
378        sigs.push(Signature {
379            args: vec![Type::Vec3, Type::Vec3],
380            ret: Type::Vec3,
381        });
382        #[cfg(feature = "4d")]
383        sigs.push(Signature {
384            args: vec![Type::Vec4, Type::Vec4],
385            ret: Type::Vec4,
386        });
387        sigs
388    }
389
390    fn call(&self, args: &[V]) -> V {
391        match (args[0].typ(), args[1].typ()) {
392            (Type::Vec2, Type::Vec2) => {
393                let a = args[0].as_vec2().unwrap();
394                let b = args[1].as_vec2().unwrap();
395                V::from_vec2([a[0] * b[0], a[1] * b[1]])
396            }
397            #[cfg(feature = "3d")]
398            (Type::Vec3, Type::Vec3) => {
399                let a = args[0].as_vec3().unwrap();
400                let b = args[1].as_vec3().unwrap();
401                V::from_vec3([a[0] * b[0], a[1] * b[1], a[2] * b[2]])
402            }
403            #[cfg(feature = "4d")]
404            (Type::Vec4, Type::Vec4) => {
405                let a = args[0].as_vec4().unwrap();
406                let b = args[1].as_vec4().unwrap();
407                V::from_vec4([a[0] * b[0], a[1] * b[1], a[2] * b[2], a[3] * b[3]])
408            }
409            _ => unreachable!("signature mismatch"),
410        }
411    }
412}
413
414// ============================================================================
415// Lerp (linear interpolation for vectors)
416// ============================================================================
417
418/// Linear interpolation: lerp(a, b, t) -> vec
419/// Returns a + (b - a) * t
420pub struct Lerp;
421
422impl<T, V> LinalgFn<T, V> for Lerp
423where
424    T: Numeric,
425    V: LinalgValue<T>,
426{
427    fn name(&self) -> &str {
428        "lerp"
429    }
430
431    fn signatures(&self) -> Vec<Signature> {
432        let mut sigs = vec![Signature {
433            args: vec![Type::Vec2, Type::Vec2, Type::Scalar],
434            ret: Type::Vec2,
435        }];
436        #[cfg(feature = "3d")]
437        sigs.push(Signature {
438            args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
439            ret: Type::Vec3,
440        });
441        #[cfg(feature = "4d")]
442        sigs.push(Signature {
443            args: vec![Type::Vec4, Type::Vec4, Type::Scalar],
444            ret: Type::Vec4,
445        });
446        sigs
447    }
448
449    fn call(&self, args: &[V]) -> V {
450        let t = args[2].as_scalar().unwrap();
451        match (args[0].typ(), args[1].typ()) {
452            (Type::Vec2, Type::Vec2) => {
453                let a = args[0].as_vec2().unwrap();
454                let b = args[1].as_vec2().unwrap();
455                V::from_vec2([a[0] + (b[0] - a[0]) * t, a[1] + (b[1] - a[1]) * t])
456            }
457            #[cfg(feature = "3d")]
458            (Type::Vec3, Type::Vec3) => {
459                let a = args[0].as_vec3().unwrap();
460                let b = args[1].as_vec3().unwrap();
461                V::from_vec3([
462                    a[0] + (b[0] - a[0]) * t,
463                    a[1] + (b[1] - a[1]) * t,
464                    a[2] + (b[2] - a[2]) * t,
465                ])
466            }
467            #[cfg(feature = "4d")]
468            (Type::Vec4, Type::Vec4) => {
469                let a = args[0].as_vec4().unwrap();
470                let b = args[1].as_vec4().unwrap();
471                V::from_vec4([
472                    a[0] + (b[0] - a[0]) * t,
473                    a[1] + (b[1] - a[1]) * t,
474                    a[2] + (b[2] - a[2]) * t,
475                    a[3] + (b[3] - a[3]) * t,
476                ])
477            }
478            _ => unreachable!("signature mismatch"),
479        }
480    }
481}
482
483// ============================================================================
484// Mix (alias for lerp, GLSL naming)
485// ============================================================================
486
487/// Linear interpolation (GLSL naming): mix(a, b, t) -> vec
488pub struct Mix;
489
490impl<T, V> LinalgFn<T, V> for Mix
491where
492    T: Numeric,
493    V: LinalgValue<T>,
494{
495    fn name(&self) -> &str {
496        "mix"
497    }
498
499    fn signatures(&self) -> Vec<Signature> {
500        let mut sigs = vec![Signature {
501            args: vec![Type::Vec2, Type::Vec2, Type::Scalar],
502            ret: Type::Vec2,
503        }];
504        #[cfg(feature = "3d")]
505        sigs.push(Signature {
506            args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
507            ret: Type::Vec3,
508        });
509        #[cfg(feature = "4d")]
510        sigs.push(Signature {
511            args: vec![Type::Vec4, Type::Vec4, Type::Scalar],
512            ret: Type::Vec4,
513        });
514        sigs
515    }
516
517    fn call(&self, args: &[V]) -> V {
518        let t = args[2].as_scalar().unwrap();
519        match (args[0].typ(), args[1].typ()) {
520            (Type::Vec2, Type::Vec2) => {
521                let a = args[0].as_vec2().unwrap();
522                let b = args[1].as_vec2().unwrap();
523                V::from_vec2([a[0] + (b[0] - a[0]) * t, a[1] + (b[1] - a[1]) * t])
524            }
525            #[cfg(feature = "3d")]
526            (Type::Vec3, Type::Vec3) => {
527                let a = args[0].as_vec3().unwrap();
528                let b = args[1].as_vec3().unwrap();
529                V::from_vec3([
530                    a[0] + (b[0] - a[0]) * t,
531                    a[1] + (b[1] - a[1]) * t,
532                    a[2] + (b[2] - a[2]) * t,
533                ])
534            }
535            #[cfg(feature = "4d")]
536            (Type::Vec4, Type::Vec4) => {
537                let a = args[0].as_vec4().unwrap();
538                let b = args[1].as_vec4().unwrap();
539                V::from_vec4([
540                    a[0] + (b[0] - a[0]) * t,
541                    a[1] + (b[1] - a[1]) * t,
542                    a[2] + (b[2] - a[2]) * t,
543                    a[3] + (b[3] - a[3]) * t,
544                ])
545            }
546            _ => unreachable!("signature mismatch"),
547        }
548    }
549}
550
551// ============================================================================
552// Vector constructors
553// ============================================================================
554
555/// Construct Vec2 from two scalars: vec2(x, y) -> Vec2
556pub struct Vec2Constructor;
557
558impl<T, V> LinalgFn<T, V> for Vec2Constructor
559where
560    T: Numeric,
561    V: LinalgValue<T>,
562{
563    fn name(&self) -> &str {
564        "vec2"
565    }
566
567    fn signatures(&self) -> Vec<Signature> {
568        vec![Signature {
569            args: vec![Type::Scalar, Type::Scalar],
570            ret: Type::Vec2,
571        }]
572    }
573
574    fn call(&self, args: &[V]) -> V {
575        let x = args[0].as_scalar().unwrap();
576        let y = args[1].as_scalar().unwrap();
577        V::from_vec2([x, y])
578    }
579}
580
/// Construct Vec3 from three scalars: vec3(x, y, z) -> Vec3
#[cfg(feature = "3d")]
pub struct Vec3Constructor;

#[cfg(feature = "3d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for Vec3Constructor {
    fn name(&self) -> &str {
        "vec3"
    }

    fn signatures(&self) -> Vec<Signature> {
        let xyz = vec![Type::Scalar, Type::Scalar, Type::Scalar];
        vec![Signature {
            args: xyz,
            ret: Type::Vec3,
        }]
    }

    /// Packs `args[0..3]` into a Vec3, in order.
    ///
    /// # Panics
    ///
    /// Panics if fewer than three arguments are given or if `as_scalar()`
    /// returns `None` for any of them.
    fn call(&self, args: &[V]) -> V {
        let [x, y, z] = [0usize, 1, 2].map(|i| args[i].as_scalar().unwrap());
        V::from_vec3([x, y, z])
    }
}
609
/// Construct Vec4 from four scalars: vec4(x, y, z, w) -> Vec4
#[cfg(feature = "4d")]
pub struct Vec4Constructor;

#[cfg(feature = "4d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for Vec4Constructor {
    fn name(&self) -> &str {
        "vec4"
    }

    fn signatures(&self) -> Vec<Signature> {
        let params = vec![Type::Scalar, Type::Scalar, Type::Scalar, Type::Scalar];
        vec![Signature {
            args: params,
            ret: Type::Vec4,
        }]
    }

    /// Packs `args[0..4]` into a Vec4, in order.
    ///
    /// # Panics
    ///
    /// Panics if fewer than four arguments are given or if `as_scalar()`
    /// returns `None` for any of them.
    fn call(&self, args: &[V]) -> V {
        let lanes = [0usize, 1, 2, 3].map(|i| args[i].as_scalar().unwrap());
        V::from_vec4(lanes)
    }
}
639
640// ============================================================================
641// Matrix constructors
642// ============================================================================
643
644/// Construct Mat2 from four scalars (column-major): mat2(c0r0, c0r1, c1r0, c1r1) -> Mat2
645pub struct Mat2Constructor;
646
647impl<T, V> LinalgFn<T, V> for Mat2Constructor
648where
649    T: Numeric,
650    V: LinalgValue<T>,
651{
652    fn name(&self) -> &str {
653        "mat2"
654    }
655
656    fn signatures(&self) -> Vec<Signature> {
657        vec![Signature {
658            args: vec![Type::Scalar, Type::Scalar, Type::Scalar, Type::Scalar],
659            ret: Type::Mat2,
660        }]
661    }
662
663    fn call(&self, args: &[V]) -> V {
664        let a = args[0].as_scalar().unwrap();
665        let b = args[1].as_scalar().unwrap();
666        let c = args[2].as_scalar().unwrap();
667        let d = args[3].as_scalar().unwrap();
668        V::from_mat2([a, b, c, d])
669    }
670}
671
/// Construct Mat3 from nine scalars (column-major): mat3(...) -> Mat3
///
/// The nine arguments are forwarded to `from_mat3` in the order given.
#[cfg(feature = "3d")]
pub struct Mat3Constructor;

#[cfg(feature = "3d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for Mat3Constructor {
    fn name(&self) -> &str {
        "mat3"
    }

    fn signatures(&self) -> Vec<Signature> {
        // Nine scalar parameters.
        let args = (0..9).map(|_| Type::Scalar).collect();
        vec![Signature {
            args,
            ret: Type::Mat3,
        }]
    }

    /// # Panics
    ///
    /// Panics if fewer than nine arguments are given or if `as_scalar()`
    /// returns `None` for any of them.
    fn call(&self, args: &[V]) -> V {
        let indices = [0usize, 1, 2, 3, 4, 5, 6, 7, 8];
        V::from_mat3(indices.map(|k| args[k].as_scalar().unwrap()))
    }
}
717
/// Construct Mat4 from sixteen scalars (column-major): mat4(...) -> Mat4
///
/// The sixteen arguments are forwarded to `from_mat4` in the order given.
#[cfg(feature = "4d")]
pub struct Mat4Constructor;

#[cfg(feature = "4d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for Mat4Constructor {
    fn name(&self) -> &str {
        "mat4"
    }

    fn signatures(&self) -> Vec<Signature> {
        // Sixteen scalar parameters.
        let args = (0..16).map(|_| Type::Scalar).collect();
        vec![Signature {
            args,
            ret: Type::Mat4,
        }]
    }

    /// # Panics
    ///
    /// Panics if fewer than sixteen arguments are given or if `as_scalar()`
    /// returns `None` for any of them.
    fn call(&self, args: &[V]) -> V {
        let indices = [0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        V::from_mat4(indices.map(|k| args[k].as_scalar().unwrap()))
    }
}
777
778// ============================================================================
779// Component extraction
780// ============================================================================
781
782/// Extract x component: x(v) -> Scalar
783pub struct ExtractX;
784
785impl<T, V> LinalgFn<T, V> for ExtractX
786where
787    T: Numeric,
788    V: LinalgValue<T>,
789{
790    fn name(&self) -> &str {
791        "x"
792    }
793
794    fn signatures(&self) -> Vec<Signature> {
795        let mut sigs = vec![Signature {
796            args: vec![Type::Vec2],
797            ret: Type::Scalar,
798        }];
799        #[cfg(feature = "3d")]
800        sigs.push(Signature {
801            args: vec![Type::Vec3],
802            ret: Type::Scalar,
803        });
804        #[cfg(feature = "4d")]
805        sigs.push(Signature {
806            args: vec![Type::Vec4],
807            ret: Type::Scalar,
808        });
809        sigs
810    }
811
812    fn call(&self, args: &[V]) -> V {
813        match args[0].typ() {
814            Type::Vec2 => V::from_scalar(args[0].as_vec2().unwrap()[0]),
815            #[cfg(feature = "3d")]
816            Type::Vec3 => V::from_scalar(args[0].as_vec3().unwrap()[0]),
817            #[cfg(feature = "4d")]
818            Type::Vec4 => V::from_scalar(args[0].as_vec4().unwrap()[0]),
819            _ => unreachable!("signature mismatch"),
820        }
821    }
822}
823
824/// Extract y component: y(v) -> Scalar
825pub struct ExtractY;
826
827impl<T, V> LinalgFn<T, V> for ExtractY
828where
829    T: Numeric,
830    V: LinalgValue<T>,
831{
832    fn name(&self) -> &str {
833        "y"
834    }
835
836    fn signatures(&self) -> Vec<Signature> {
837        let mut sigs = vec![Signature {
838            args: vec![Type::Vec2],
839            ret: Type::Scalar,
840        }];
841        #[cfg(feature = "3d")]
842        sigs.push(Signature {
843            args: vec![Type::Vec3],
844            ret: Type::Scalar,
845        });
846        #[cfg(feature = "4d")]
847        sigs.push(Signature {
848            args: vec![Type::Vec4],
849            ret: Type::Scalar,
850        });
851        sigs
852    }
853
854    fn call(&self, args: &[V]) -> V {
855        match args[0].typ() {
856            Type::Vec2 => V::from_scalar(args[0].as_vec2().unwrap()[1]),
857            #[cfg(feature = "3d")]
858            Type::Vec3 => V::from_scalar(args[0].as_vec3().unwrap()[1]),
859            #[cfg(feature = "4d")]
860            Type::Vec4 => V::from_scalar(args[0].as_vec4().unwrap()[1]),
861            _ => unreachable!("signature mismatch"),
862        }
863    }
864}
865
/// Extract z component: z(v) -> Scalar (Vec3 and Vec4 only)
///
/// Returns the third component. This function exists only with the `3d`
/// feature and never accepts a Vec2.
#[cfg(feature = "3d")]
pub struct ExtractZ;

#[cfg(feature = "3d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for ExtractZ {
    fn name(&self) -> &str {
        "z"
    }

    /// `(Vec3) -> Scalar`, plus `(Vec4) -> Scalar` with the `4d` feature.
    fn signatures(&self) -> Vec<Signature> {
        let mut accepted = Vec::new();
        accepted.push(Signature {
            args: vec![Type::Vec3],
            ret: Type::Scalar,
        });
        #[cfg(feature = "4d")]
        accepted.push(Signature {
            args: vec![Type::Vec4],
            ret: Type::Scalar,
        });
        accepted
    }

    /// # Panics
    ///
    /// Panics if `args` is empty, if the matching `as_vecN()` accessor returns
    /// `None`, or (via `unreachable!`) for any other argument type.
    fn call(&self, args: &[V]) -> V {
        let src = &args[0];
        let z = match src.typ() {
            Type::Vec3 => src.as_vec3().unwrap()[2],
            #[cfg(feature = "4d")]
            Type::Vec4 => src.as_vec4().unwrap()[2],
            _ => unreachable!("signature mismatch"),
        };
        V::from_scalar(z)
    }
}
902
/// Extract w component: w(v) -> Scalar (Vec4 only)
///
/// Returns the fourth component. Only Vec4 has one, so this function exists
/// only with the `4d` feature.
#[cfg(feature = "4d")]
pub struct ExtractW;

#[cfg(feature = "4d")]
impl<T: Numeric, V: LinalgValue<T>> LinalgFn<T, V> for ExtractW {
    fn name(&self) -> &str {
        "w"
    }

    fn signatures(&self) -> Vec<Signature> {
        let only = Signature {
            args: vec![Type::Vec4],
            ret: Type::Scalar,
        };
        vec![only]
    }

    /// # Panics
    ///
    /// Panics if `args` is empty or if `as_vec4()` returns `None`.
    fn call(&self, args: &[V]) -> V {
        let xyzw = args[0].as_vec4().unwrap();
        let w = xyzw[3];
        V::from_scalar(w)
    }
}
928
929// ============================================================================
930// Vectorized math functions
931// ============================================================================
932
933macro_rules! define_vectorized_fn {
934    ($name:ident, $fn_name:expr, $method:ident) => {
935        pub struct $name;
936
937        impl<T, V> LinalgFn<T, V> for $name
938        where
939            T: Float + Numeric,
940            V: LinalgValue<T>,
941        {
942            fn name(&self) -> &str {
943                $fn_name
944            }
945
946            fn signatures(&self) -> Vec<Signature> {
947                let mut sigs = vec![Signature {
948                    args: vec![Type::Vec2],
949                    ret: Type::Vec2,
950                }];
951                #[cfg(feature = "3d")]
952                sigs.push(Signature {
953                    args: vec![Type::Vec3],
954                    ret: Type::Vec3,
955                });
956                #[cfg(feature = "4d")]
957                sigs.push(Signature {
958                    args: vec![Type::Vec4],
959                    ret: Type::Vec4,
960                });
961                sigs
962            }
963
964            fn call(&self, args: &[V]) -> V {
965                match args[0].typ() {
966                    Type::Vec2 => {
967                        let v = args[0].as_vec2().unwrap();
968                        V::from_vec2([v[0].$method(), v[1].$method()])
969                    }
970                    #[cfg(feature = "3d")]
971                    Type::Vec3 => {
972                        let v = args[0].as_vec3().unwrap();
973                        V::from_vec3([v[0].$method(), v[1].$method(), v[2].$method()])
974                    }
975                    #[cfg(feature = "4d")]
976                    Type::Vec4 => {
977                        let v = args[0].as_vec4().unwrap();
978                        V::from_vec4([
979                            v[0].$method(),
980                            v[1].$method(),
981                            v[2].$method(),
982                            v[3].$method(),
983                        ])
984                    }
985                    _ => unreachable!("signature mismatch"),
986                }
987            }
988        }
989    };
990}
991
992define_vectorized_fn!(VecSin, "sin", sin);
993define_vectorized_fn!(VecCos, "cos", cos);
994define_vectorized_fn!(VecAbs, "abs", abs);
995define_vectorized_fn!(VecFloor, "floor", floor);
996define_vectorized_fn!(VecSqrt, "sqrt", sqrt);
997
998/// Vectorized fract: fract(v) -> VecN (fractional part)
999pub struct VecFract;
1000
1001impl<T, V> LinalgFn<T, V> for VecFract
1002where
1003    T: Float + Numeric,
1004    V: LinalgValue<T>,
1005{
1006    fn name(&self) -> &str {
1007        "fract"
1008    }
1009
1010    fn signatures(&self) -> Vec<Signature> {
1011        let mut sigs = vec![Signature {
1012            args: vec![Type::Vec2],
1013            ret: Type::Vec2,
1014        }];
1015        #[cfg(feature = "3d")]
1016        sigs.push(Signature {
1017            args: vec![Type::Vec3],
1018            ret: Type::Vec3,
1019        });
1020        #[cfg(feature = "4d")]
1021        sigs.push(Signature {
1022            args: vec![Type::Vec4],
1023            ret: Type::Vec4,
1024        });
1025        sigs
1026    }
1027
1028    fn call(&self, args: &[V]) -> V {
1029        match args[0].typ() {
1030            Type::Vec2 => {
1031                let v = args[0].as_vec2().unwrap();
1032                V::from_vec2([v[0].fract(), v[1].fract()])
1033            }
1034            #[cfg(feature = "3d")]
1035            Type::Vec3 => {
1036                let v = args[0].as_vec3().unwrap();
1037                V::from_vec3([v[0].fract(), v[1].fract(), v[2].fract()])
1038            }
1039            #[cfg(feature = "4d")]
1040            Type::Vec4 => {
1041                let v = args[0].as_vec4().unwrap();
1042                V::from_vec4([v[0].fract(), v[1].fract(), v[2].fract(), v[3].fract()])
1043            }
1044            _ => unreachable!("signature mismatch"),
1045        }
1046    }
1047}
1048
1049// ============================================================================
1050// Vectorized comparison functions
1051// ============================================================================
1052
1053/// Vectorized min: min(a, b) -> VecN (component-wise minimum)
1054pub struct VecMin;
1055
1056impl<T, V> LinalgFn<T, V> for VecMin
1057where
1058    T: Float + Numeric,
1059    V: LinalgValue<T>,
1060{
1061    fn name(&self) -> &str {
1062        "min"
1063    }
1064
1065    fn signatures(&self) -> Vec<Signature> {
1066        let mut sigs = vec![Signature {
1067            args: vec![Type::Vec2, Type::Vec2],
1068            ret: Type::Vec2,
1069        }];
1070        #[cfg(feature = "3d")]
1071        sigs.push(Signature {
1072            args: vec![Type::Vec3, Type::Vec3],
1073            ret: Type::Vec3,
1074        });
1075        #[cfg(feature = "4d")]
1076        sigs.push(Signature {
1077            args: vec![Type::Vec4, Type::Vec4],
1078            ret: Type::Vec4,
1079        });
1080        sigs
1081    }
1082
1083    fn call(&self, args: &[V]) -> V {
1084        match (args[0].typ(), args[1].typ()) {
1085            (Type::Vec2, Type::Vec2) => {
1086                let a = args[0].as_vec2().unwrap();
1087                let b = args[1].as_vec2().unwrap();
1088                V::from_vec2([a[0].min(b[0]), a[1].min(b[1])])
1089            }
1090            #[cfg(feature = "3d")]
1091            (Type::Vec3, Type::Vec3) => {
1092                let a = args[0].as_vec3().unwrap();
1093                let b = args[1].as_vec3().unwrap();
1094                V::from_vec3([a[0].min(b[0]), a[1].min(b[1]), a[2].min(b[2])])
1095            }
1096            #[cfg(feature = "4d")]
1097            (Type::Vec4, Type::Vec4) => {
1098                let a = args[0].as_vec4().unwrap();
1099                let b = args[1].as_vec4().unwrap();
1100                V::from_vec4([
1101                    a[0].min(b[0]),
1102                    a[1].min(b[1]),
1103                    a[2].min(b[2]),
1104                    a[3].min(b[3]),
1105                ])
1106            }
1107            _ => unreachable!("signature mismatch"),
1108        }
1109    }
1110}
1111
1112/// Vectorized max: max(a, b) -> VecN (component-wise maximum)
1113pub struct VecMax;
1114
1115impl<T, V> LinalgFn<T, V> for VecMax
1116where
1117    T: Float + Numeric,
1118    V: LinalgValue<T>,
1119{
1120    fn name(&self) -> &str {
1121        "max"
1122    }
1123
1124    fn signatures(&self) -> Vec<Signature> {
1125        let mut sigs = vec![Signature {
1126            args: vec![Type::Vec2, Type::Vec2],
1127            ret: Type::Vec2,
1128        }];
1129        #[cfg(feature = "3d")]
1130        sigs.push(Signature {
1131            args: vec![Type::Vec3, Type::Vec3],
1132            ret: Type::Vec3,
1133        });
1134        #[cfg(feature = "4d")]
1135        sigs.push(Signature {
1136            args: vec![Type::Vec4, Type::Vec4],
1137            ret: Type::Vec4,
1138        });
1139        sigs
1140    }
1141
1142    fn call(&self, args: &[V]) -> V {
1143        match (args[0].typ(), args[1].typ()) {
1144            (Type::Vec2, Type::Vec2) => {
1145                let a = args[0].as_vec2().unwrap();
1146                let b = args[1].as_vec2().unwrap();
1147                V::from_vec2([a[0].max(b[0]), a[1].max(b[1])])
1148            }
1149            #[cfg(feature = "3d")]
1150            (Type::Vec3, Type::Vec3) => {
1151                let a = args[0].as_vec3().unwrap();
1152                let b = args[1].as_vec3().unwrap();
1153                V::from_vec3([a[0].max(b[0]), a[1].max(b[1]), a[2].max(b[2])])
1154            }
1155            #[cfg(feature = "4d")]
1156            (Type::Vec4, Type::Vec4) => {
1157                let a = args[0].as_vec4().unwrap();
1158                let b = args[1].as_vec4().unwrap();
1159                V::from_vec4([
1160                    a[0].max(b[0]),
1161                    a[1].max(b[1]),
1162                    a[2].max(b[2]),
1163                    a[3].max(b[3]),
1164                ])
1165            }
1166            _ => unreachable!("signature mismatch"),
1167        }
1168    }
1169}
1170
1171/// Vectorized clamp: clamp(x, min, max) -> VecN
1172pub struct VecClamp;
1173
1174impl<T, V> LinalgFn<T, V> for VecClamp
1175where
1176    T: Float + Numeric,
1177    V: LinalgValue<T>,
1178{
1179    fn name(&self) -> &str {
1180        "clamp"
1181    }
1182
1183    fn signatures(&self) -> Vec<Signature> {
1184        let mut sigs = vec![Signature {
1185            args: vec![Type::Vec2, Type::Vec2, Type::Vec2],
1186            ret: Type::Vec2,
1187        }];
1188        #[cfg(feature = "3d")]
1189        sigs.push(Signature {
1190            args: vec![Type::Vec3, Type::Vec3, Type::Vec3],
1191            ret: Type::Vec3,
1192        });
1193        #[cfg(feature = "4d")]
1194        sigs.push(Signature {
1195            args: vec![Type::Vec4, Type::Vec4, Type::Vec4],
1196            ret: Type::Vec4,
1197        });
1198        sigs
1199    }
1200
1201    fn call(&self, args: &[V]) -> V {
1202        match (args[0].typ(), args[1].typ(), args[2].typ()) {
1203            (Type::Vec2, Type::Vec2, Type::Vec2) => {
1204                let x = args[0].as_vec2().unwrap();
1205                let lo = args[1].as_vec2().unwrap();
1206                let hi = args[2].as_vec2().unwrap();
1207                V::from_vec2([x[0].max(lo[0]).min(hi[0]), x[1].max(lo[1]).min(hi[1])])
1208            }
1209            #[cfg(feature = "3d")]
1210            (Type::Vec3, Type::Vec3, Type::Vec3) => {
1211                let x = args[0].as_vec3().unwrap();
1212                let lo = args[1].as_vec3().unwrap();
1213                let hi = args[2].as_vec3().unwrap();
1214                V::from_vec3([
1215                    x[0].max(lo[0]).min(hi[0]),
1216                    x[1].max(lo[1]).min(hi[1]),
1217                    x[2].max(lo[2]).min(hi[2]),
1218                ])
1219            }
1220            #[cfg(feature = "4d")]
1221            (Type::Vec4, Type::Vec4, Type::Vec4) => {
1222                let x = args[0].as_vec4().unwrap();
1223                let lo = args[1].as_vec4().unwrap();
1224                let hi = args[2].as_vec4().unwrap();
1225                V::from_vec4([
1226                    x[0].max(lo[0]).min(hi[0]),
1227                    x[1].max(lo[1]).min(hi[1]),
1228                    x[2].max(lo[2]).min(hi[2]),
1229                    x[3].max(lo[3]).min(hi[3]),
1230                ])
1231            }
1232            _ => unreachable!("signature mismatch"),
1233        }
1234    }
1235}
1236
1237// ============================================================================
1238// Interpolation functions
1239// ============================================================================
1240
1241/// Vectorized step: step(edge, x) -> VecN (0 if x < edge, 1 otherwise)
1242pub struct VecStep;
1243
1244impl<T, V> LinalgFn<T, V> for VecStep
1245where
1246    T: Float + Numeric,
1247    V: LinalgValue<T>,
1248{
1249    fn name(&self) -> &str {
1250        "step"
1251    }
1252
1253    fn signatures(&self) -> Vec<Signature> {
1254        let mut sigs = vec![Signature {
1255            args: vec![Type::Vec2, Type::Vec2],
1256            ret: Type::Vec2,
1257        }];
1258        #[cfg(feature = "3d")]
1259        sigs.push(Signature {
1260            args: vec![Type::Vec3, Type::Vec3],
1261            ret: Type::Vec3,
1262        });
1263        #[cfg(feature = "4d")]
1264        sigs.push(Signature {
1265            args: vec![Type::Vec4, Type::Vec4],
1266            ret: Type::Vec4,
1267        });
1268        sigs
1269    }
1270
1271    fn call(&self, args: &[V]) -> V {
1272        fn step<T: Float>(edge: T, x: T) -> T {
1273            if x < edge { T::zero() } else { T::one() }
1274        }
1275        match (args[0].typ(), args[1].typ()) {
1276            (Type::Vec2, Type::Vec2) => {
1277                let edge = args[0].as_vec2().unwrap();
1278                let x = args[1].as_vec2().unwrap();
1279                V::from_vec2([step(edge[0], x[0]), step(edge[1], x[1])])
1280            }
1281            #[cfg(feature = "3d")]
1282            (Type::Vec3, Type::Vec3) => {
1283                let edge = args[0].as_vec3().unwrap();
1284                let x = args[1].as_vec3().unwrap();
1285                V::from_vec3([
1286                    step(edge[0], x[0]),
1287                    step(edge[1], x[1]),
1288                    step(edge[2], x[2]),
1289                ])
1290            }
1291            #[cfg(feature = "4d")]
1292            (Type::Vec4, Type::Vec4) => {
1293                let edge = args[0].as_vec4().unwrap();
1294                let x = args[1].as_vec4().unwrap();
1295                V::from_vec4([
1296                    step(edge[0], x[0]),
1297                    step(edge[1], x[1]),
1298                    step(edge[2], x[2]),
1299                    step(edge[3], x[3]),
1300                ])
1301            }
1302            _ => unreachable!("signature mismatch"),
1303        }
1304    }
1305}
1306
1307/// Vectorized smoothstep: smoothstep(edge0, edge1, x) -> VecN
1308pub struct VecSmoothstep;
1309
1310impl<T, V> LinalgFn<T, V> for VecSmoothstep
1311where
1312    T: Float + Numeric,
1313    V: LinalgValue<T>,
1314{
1315    fn name(&self) -> &str {
1316        "smoothstep"
1317    }
1318
1319    fn signatures(&self) -> Vec<Signature> {
1320        let mut sigs = vec![Signature {
1321            args: vec![Type::Vec2, Type::Vec2, Type::Vec2],
1322            ret: Type::Vec2,
1323        }];
1324        #[cfg(feature = "3d")]
1325        sigs.push(Signature {
1326            args: vec![Type::Vec3, Type::Vec3, Type::Vec3],
1327            ret: Type::Vec3,
1328        });
1329        #[cfg(feature = "4d")]
1330        sigs.push(Signature {
1331            args: vec![Type::Vec4, Type::Vec4, Type::Vec4],
1332            ret: Type::Vec4,
1333        });
1334        sigs
1335    }
1336
1337    fn call(&self, args: &[V]) -> V {
1338        fn smoothstep<T: Float>(edge0: T, edge1: T, x: T) -> T {
1339            let t = ((x - edge0) / (edge1 - edge0)).max(T::zero()).min(T::one());
1340            let three = T::from(3.0).unwrap();
1341            let two = T::from(2.0).unwrap();
1342            t * t * (three - two * t)
1343        }
1344        match (args[0].typ(), args[1].typ(), args[2].typ()) {
1345            (Type::Vec2, Type::Vec2, Type::Vec2) => {
1346                let e0 = args[0].as_vec2().unwrap();
1347                let e1 = args[1].as_vec2().unwrap();
1348                let x = args[2].as_vec2().unwrap();
1349                V::from_vec2([
1350                    smoothstep(e0[0], e1[0], x[0]),
1351                    smoothstep(e0[1], e1[1], x[1]),
1352                ])
1353            }
1354            #[cfg(feature = "3d")]
1355            (Type::Vec3, Type::Vec3, Type::Vec3) => {
1356                let e0 = args[0].as_vec3().unwrap();
1357                let e1 = args[1].as_vec3().unwrap();
1358                let x = args[2].as_vec3().unwrap();
1359                V::from_vec3([
1360                    smoothstep(e0[0], e1[0], x[0]),
1361                    smoothstep(e0[1], e1[1], x[1]),
1362                    smoothstep(e0[2], e1[2], x[2]),
1363                ])
1364            }
1365            #[cfg(feature = "4d")]
1366            (Type::Vec4, Type::Vec4, Type::Vec4) => {
1367                let e0 = args[0].as_vec4().unwrap();
1368                let e1 = args[1].as_vec4().unwrap();
1369                let x = args[2].as_vec4().unwrap();
1370                V::from_vec4([
1371                    smoothstep(e0[0], e1[0], x[0]),
1372                    smoothstep(e0[1], e1[1], x[1]),
1373                    smoothstep(e0[2], e1[2], x[2]),
1374                    smoothstep(e0[3], e1[3], x[3]),
1375                ])
1376            }
1377            _ => unreachable!("signature mismatch"),
1378        }
1379    }
1380}
1381
1382// ============================================================================
1383// Transform functions
1384// ============================================================================
1385
1386/// Rotate a 2D vector by an angle: rotate2d(v, angle) -> Vec2
1387pub struct Rotate2D;
1388
1389impl<T, V> LinalgFn<T, V> for Rotate2D
1390where
1391    T: Float + Numeric,
1392    V: LinalgValue<T>,
1393{
1394    fn name(&self) -> &str {
1395        "rotate2d"
1396    }
1397
1398    fn signatures(&self) -> Vec<Signature> {
1399        vec![Signature {
1400            args: vec![Type::Vec2, Type::Scalar],
1401            ret: Type::Vec2,
1402        }]
1403    }
1404
1405    fn call(&self, args: &[V]) -> V {
1406        let v = args[0].as_vec2().unwrap();
1407        let angle = args[1].as_scalar().unwrap();
1408        let c = angle.cos();
1409        let s = angle.sin();
1410        V::from_vec2([v[0] * c - v[1] * s, v[0] * s + v[1] * c])
1411    }
1412}
1413
/// Rotate a 3D vector around the X axis: `rotate_x(v, angle) -> Vec3`.
///
/// `x` is left unchanged. `(y, z)` turns by `angle` radians using the standard
/// right-handed X-rotation matrix.
#[cfg(feature = "3d")]
pub struct RotateX;

#[cfg(feature = "3d")]
impl<T, V> LinalgFn<T, V> for RotateX
where
    T: Float + Numeric,
    V: LinalgValue<T>,
{
    fn name(&self) -> &str {
        "rotate_x"
    }

    fn signatures(&self) -> Vec<Signature> {
        let sig = Signature {
            args: vec![Type::Vec3, Type::Scalar],
            ret: Type::Vec3,
        };
        vec![sig]
    }

    fn call(&self, args: &[V]) -> V {
        let p = args[0].as_vec3().unwrap();
        let theta = args[1].as_scalar().unwrap();
        let (cos_t, sin_t) = (theta.cos(), theta.sin());
        let y = p[1] * cos_t - p[2] * sin_t;
        let z = p[1] * sin_t + p[2] * cos_t;
        V::from_vec3([p[0], y, z])
    }
}
1444
/// Rotate a 3D vector around the Y axis: `rotate_y(v, angle) -> Vec3`.
///
/// `y` is left unchanged. `(x, z)` turns by `angle` radians using the standard
/// right-handed Y-rotation matrix.
#[cfg(feature = "3d")]
pub struct RotateY;

#[cfg(feature = "3d")]
impl<T, V> LinalgFn<T, V> for RotateY
where
    T: Float + Numeric,
    V: LinalgValue<T>,
{
    fn name(&self) -> &str {
        "rotate_y"
    }

    fn signatures(&self) -> Vec<Signature> {
        let sig = Signature {
            args: vec![Type::Vec3, Type::Scalar],
            ret: Type::Vec3,
        };
        vec![sig]
    }

    fn call(&self, args: &[V]) -> V {
        let p = args[0].as_vec3().unwrap();
        let theta = args[1].as_scalar().unwrap();
        let (cos_t, sin_t) = (theta.cos(), theta.sin());
        let x = p[0] * cos_t + p[2] * sin_t;
        let z = -p[0] * sin_t + p[2] * cos_t;
        V::from_vec3([x, p[1], z])
    }
}
1475
/// Rotate a 3D vector around the Z axis: `rotate_z(v, angle) -> Vec3`.
///
/// `z` is left unchanged. `(x, y)` turns by `angle` radians exactly as in
/// `rotate2d`.
#[cfg(feature = "3d")]
pub struct RotateZ;

#[cfg(feature = "3d")]
impl<T, V> LinalgFn<T, V> for RotateZ
where
    T: Float + Numeric,
    V: LinalgValue<T>,
{
    fn name(&self) -> &str {
        "rotate_z"
    }

    fn signatures(&self) -> Vec<Signature> {
        let sig = Signature {
            args: vec![Type::Vec3, Type::Scalar],
            ret: Type::Vec3,
        };
        vec![sig]
    }

    fn call(&self, args: &[V]) -> V {
        let p = args[0].as_vec3().unwrap();
        let theta = args[1].as_scalar().unwrap();
        let (cos_t, sin_t) = (theta.cos(), theta.sin());
        let x = p[0] * cos_t - p[1] * sin_t;
        let y = p[0] * sin_t + p[1] * cos_t;
        V::from_vec3([x, y, p[2]])
    }
}
1506
/// Rotate a 3D vector around an arbitrary axis: `rotate3d(v, axis, angle) -> Vec3`.
///
/// Implements Rodrigues' rotation formula
/// `v' = v cos(a) + (k x v) sin(a) + k (k . v)(1 - cos(a))`.
/// The axis `k` is used as given and is not normalized here, so callers must pass
/// a unit vector to get a pure rotation.
#[cfg(feature = "3d")]
pub struct Rotate3D;

#[cfg(feature = "3d")]
impl<T, V> LinalgFn<T, V> for Rotate3D
where
    T: Float + Numeric,
    V: LinalgValue<T>,
{
    fn name(&self) -> &str {
        "rotate3d"
    }

    fn signatures(&self) -> Vec<Signature> {
        let sig = Signature {
            args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
            ret: Type::Vec3,
        };
        vec![sig]
    }

    fn call(&self, args: &[V]) -> V {
        let v = args[0].as_vec3().unwrap();
        let axis = args[1].as_vec3().unwrap();
        let theta = args[2].as_scalar().unwrap();

        let cos_t = theta.cos();
        let sin_t = theta.sin();
        let one_minus_cos = T::one() - cos_t;

        // Projection term: axis . v
        let dot = axis[0] * v[0] + axis[1] * v[1] + axis[2] * v[2];
        // Perpendicular term: axis x v
        let cross = [
            axis[1] * v[2] - axis[2] * v[1],
            axis[2] * v[0] - axis[0] * v[2],
            axis[0] * v[1] - axis[1] * v[0],
        ];

        // Evaluate Rodrigues' formula one lane at a time.
        let lane = |i: usize| v[i] * cos_t + cross[i] * sin_t + axis[i] * dot * one_minus_cos;
        V::from_vec3([lane(0), lane(1), lane(2)])
    }
}
1558
1559// ============================================================================
1560// Registry helper
1561// ============================================================================
1562
1563use crate::{FunctionRegistry, Value};
1564
/// Register all standard linalg functions.
///
/// This is a superset of [`register_linalg_numeric`]. On top of that set it also
/// registers length, normalize, distance, reflect, the vectorized
/// math/comparison/interpolation functions and the rotations. Several of these
/// are only implemented for `T: Float`, which is why this function needs that
/// bound. Entries gated behind the `3d` / `4d` features are registered only when
/// the feature is enabled.
// NOTE(review): the shared part of this list duplicates `register_linalg_numeric`
// and must be kept in sync by hand. The registration order is left untouched in
// case `FunctionRegistry` is order-sensitive; confirm before deduplicating.
pub fn register_linalg<T, V>(registry: &mut FunctionRegistry<T, V>)
where
    T: Float + Numeric + 'static,
    V: LinalgValue<T> + 'static,
{
    registry.register(Dot);
    #[cfg(feature = "3d")]
    registry.register(Cross);
    registry.register(Length);
    registry.register(Normalize);
    registry.register(Distance);
    registry.register(Reflect);
    registry.register(Hadamard);
    registry.register(Lerp);
    registry.register(Mix);

    // Vector constructors
    registry.register(Vec2Constructor);
    #[cfg(feature = "3d")]
    registry.register(Vec3Constructor);
    #[cfg(feature = "4d")]
    registry.register(Vec4Constructor);

    // Matrix constructors
    registry.register(Mat2Constructor);
    #[cfg(feature = "3d")]
    registry.register(Mat3Constructor);
    #[cfg(feature = "4d")]
    registry.register(Mat4Constructor);

    // Component extraction
    registry.register(ExtractX);
    registry.register(ExtractY);
    #[cfg(feature = "3d")]
    registry.register(ExtractZ);
    #[cfg(feature = "4d")]
    registry.register(ExtractW);

    // Vectorized math (component-wise unary functions)
    registry.register(VecSin);
    registry.register(VecCos);
    registry.register(VecAbs);
    registry.register(VecFloor);
    registry.register(VecFract);
    registry.register(VecSqrt);

    // Vectorized comparison
    registry.register(VecMin);
    registry.register(VecMax);
    registry.register(VecClamp);

    // Interpolation
    registry.register(VecStep);
    registry.register(VecSmoothstep);

    // Transform (only rotate2d is available without the `3d` feature)
    registry.register(Rotate2D);
    #[cfg(feature = "3d")]
    registry.register(RotateX);
    #[cfg(feature = "3d")]
    registry.register(RotateY);
    #[cfg(feature = "3d")]
    registry.register(RotateZ);
    #[cfg(feature = "3d")]
    registry.register(Rotate3D);
}
1632
1633/// Create a new registry with all standard linalg functions using the default Value type.
1634pub fn linalg_registry<T: Float + Numeric + 'static>() -> FunctionRegistry<T, Value<T>> {
1635    let mut registry = FunctionRegistry::new();
1636    register_linalg(&mut registry);
1637    registry
1638}
1639
/// Register only Numeric-compatible linalg functions (no sqrt, trig, etc.).
///
/// This is useful for integer vector math where Float methods aren't available.
/// Includes: dot, cross (with the `3d` feature only), hadamard, lerp, mix, the
/// vector and matrix constructors, and the component extractors.
///
/// Every function registered here is also registered by [`register_linalg`].
// NOTE(review): this list is duplicated inside `register_linalg`; when adding a
// Numeric-only function, add it there as well.
pub fn register_linalg_numeric<T, V>(registry: &mut FunctionRegistry<T, V>)
where
    T: Numeric + 'static,
    V: LinalgValue<T> + 'static,
{
    // Basic operations
    registry.register(Dot);
    #[cfg(feature = "3d")]
    registry.register(Cross);
    registry.register(Hadamard);
    registry.register(Lerp);
    registry.register(Mix);

    // Vector constructors
    registry.register(Vec2Constructor);
    #[cfg(feature = "3d")]
    registry.register(Vec3Constructor);
    #[cfg(feature = "4d")]
    registry.register(Vec4Constructor);

    // Matrix constructors
    registry.register(Mat2Constructor);
    #[cfg(feature = "3d")]
    registry.register(Mat3Constructor);
    #[cfg(feature = "4d")]
    registry.register(Mat4Constructor);

    // Component extraction
    registry.register(ExtractX);
    registry.register(ExtractY);
    #[cfg(feature = "3d")]
    registry.register(ExtractZ);
    #[cfg(feature = "4d")]
    registry.register(ExtractW);
}
1679
1680/// Create a registry with Numeric-compatible functions for integer vectors.
1681///
1682/// Use this for `i32` or `i64` vector math. For float vectors, use `linalg_registry()`.
1683pub fn linalg_registry_int<T: Numeric + 'static>() -> FunctionRegistry<T, Value<T>> {
1684    let mut registry = FunctionRegistry::new();
1685    register_linalg_numeric(&mut registry);
1686    registry
1687}
1688
1689// ============================================================================
1690// Tests
1691// ============================================================================
1692
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use wick_core::Expr;

    /// Parses `expr`, binds `vars`, and evaluates it against the full f32 registry.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Value<f32> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = linalg_registry();
        crate::eval(expr.ast(), &var_map, &registry).unwrap()
    }

    #[test]
    fn test_dot_vec2() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Vec2([1.0, 2.0])),
                ("b", Value::Vec2([3.0, 4.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(11.0)); // 1*3 + 2*4 = 11
    }

    #[cfg(feature = "3d")]
    #[test]
    fn test_dot_vec3() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Vec3([1.0, 2.0, 3.0])),
                ("b", Value::Vec3([4.0, 5.0, 6.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(32.0)); // 1*4 + 2*5 + 3*6 = 32
    }

    #[cfg(feature = "3d")]
    #[test]
    fn test_cross() {
        let result = eval_expr(
            "cross(a, b)",
            &[
                ("a", Value::Vec3([1.0, 0.0, 0.0])),
                ("b", Value::Vec3([0.0, 1.0, 0.0])),
            ],
        );
        assert_eq!(result, Value::Vec3([0.0, 0.0, 1.0])); // x cross y = z
    }

    #[test]
    fn test_length_vec2() {
        let result = eval_expr("length(v)", &[("v", Value::Vec2([3.0, 4.0]))]);
        assert_eq!(result, Value::Scalar(5.0)); // 3-4-5 triangle
    }

    #[test]
    fn test_normalize_vec2() {
        let result = eval_expr("normalize(v)", &[("v", Value::Vec2([3.0, 4.0]))]);
        if let Value::Vec2(v) = result {
            assert!((v[0] - 0.6).abs() < 0.001);
            assert!((v[1] - 0.8).abs() < 0.001);
        } else {
            panic!("expected Vec2");
        }
    }

    #[test]
    fn test_distance_vec2() {
        let result = eval_expr(
            "distance(a, b)",
            &[
                ("a", Value::Vec2([0.0, 0.0])),
                ("b", Value::Vec2([3.0, 4.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_reflect_vec2() {
        // Reflect (1, -1) off horizontal surface with normal (0, 1)
        let result = eval_expr(
            "reflect(i, n)",
            &[
                ("i", Value::Vec2([1.0, -1.0])),
                ("n", Value::Vec2([0.0, 1.0])),
            ],
        );
        if let Value::Vec2(v) = result {
            assert!((v[0] - 1.0).abs() < 0.001);
            assert!((v[1] - 1.0).abs() < 0.001);
        } else {
            panic!("expected Vec2");
        }
    }

    #[test]
    fn test_hadamard_vec2() {
        let result = eval_expr(
            "hadamard(a, b)",
            &[
                ("a", Value::Vec2([2.0, 3.0])),
                ("b", Value::Vec2([4.0, 5.0])),
            ],
        );
        assert_eq!(result, Value::Vec2([8.0, 15.0]));
    }

    #[test]
    fn test_lerp_vec2() {
        let result = eval_expr(
            "lerp(a, b, t)",
            &[
                ("a", Value::Vec2([0.0, 0.0])),
                ("b", Value::Vec2([10.0, 20.0])),
                ("t", Value::Scalar(0.5)),
            ],
        );
        assert_eq!(result, Value::Vec2([5.0, 10.0]));
    }

    #[test]
    fn test_mix_vec2() {
        let result = eval_expr(
            "mix(a, b, t)",
            &[
                ("a", Value::Vec2([0.0, 0.0])),
                ("b", Value::Vec2([10.0, 20.0])),
                ("t", Value::Scalar(0.25)),
            ],
        );
        assert_eq!(result, Value::Vec2([2.5, 5.0]));
    }

    #[test]
    fn test_min_max_vec2() {
        let vars = [
            ("a", Value::Vec2([1.0, 5.0])),
            ("b", Value::Vec2([3.0, 2.0])),
        ];
        assert_eq!(eval_expr("min(a, b)", &vars), Value::Vec2([1.0, 2.0]));
        assert_eq!(eval_expr("max(a, b)", &vars), Value::Vec2([3.0, 5.0]));
    }

    #[test]
    fn test_clamp_vec2() {
        let result = eval_expr(
            "clamp(x, lo, hi)",
            &[
                ("x", Value::Vec2([-1.0, 5.0])),
                ("lo", Value::Vec2([0.0, 0.0])),
                ("hi", Value::Vec2([1.0, 1.0])),
            ],
        );
        assert_eq!(result, Value::Vec2([0.0, 1.0]));
    }

    #[test]
    fn test_step_vec2() {
        // Below the edge gives 0; exactly on the edge counts as "not below" and gives 1.
        let result = eval_expr(
            "step(edge, x)",
            &[
                ("edge", Value::Vec2([0.5, 0.5])),
                ("x", Value::Vec2([0.25, 0.5])),
            ],
        );
        assert_eq!(result, Value::Vec2([0.0, 1.0]));
    }

    #[test]
    fn test_smoothstep_vec2() {
        // t = 0.5 -> 0.25 * (3 - 1) = 0.5; x past edge1 clamps t to 1 -> 1.
        let result = eval_expr(
            "smoothstep(e0, e1, x)",
            &[
                ("e0", Value::Vec2([0.0, 0.0])),
                ("e1", Value::Vec2([1.0, 1.0])),
                ("x", Value::Vec2([0.5, 2.0])),
            ],
        );
        assert_eq!(result, Value::Vec2([0.5, 1.0]));
    }

    #[test]
    fn test_fract_vec2() {
        let result = eval_expr("fract(v)", &[("v", Value::Vec2([1.25, 2.5]))]);
        assert_eq!(result, Value::Vec2([0.25, 0.5]));
    }

    #[test]
    fn test_rotate2d_quarter_turn() {
        let result = eval_expr(
            "rotate2d(v, a)",
            &[
                ("v", Value::Vec2([1.0, 0.0])),
                ("a", Value::Scalar(std::f32::consts::FRAC_PI_2)),
            ],
        );
        if let Value::Vec2(v) = result {
            assert!(v[0].abs() < 0.001);
            assert!((v[1] - 1.0).abs() < 0.001);
        } else {
            panic!("expected Vec2");
        }
    }

    #[cfg(feature = "3d")]
    #[test]
    fn test_rotate3d_about_z_matches_rotate_z() {
        let vars = [
            ("v", Value::Vec3([1.0, 2.0, 3.0])),
            ("k", Value::Vec3([0.0, 0.0, 1.0])),
            ("a", Value::Scalar(0.7)),
        ];
        let general = eval_expr("rotate3d(v, k, a)", &vars);
        let fixed = eval_expr("rotate_z(v, a)", &vars);
        match (general, fixed) {
            (Value::Vec3(g), Value::Vec3(f)) => {
                for (x, y) in g.iter().zip(f.iter()) {
                    assert!((x - y).abs() < 1e-5);
                }
            }
            _ => panic!("expected Vec3"),
        }
    }

    #[test]
    fn test_integer_vectors() {
        // Create integer registry
        let registry: crate::FunctionRegistry<i32, Value<i32>> = linalg_registry_int();

        // Test dot product with integers
        let expr = Expr::parse("dot(a, b)").unwrap();
        let vars: HashMap<String, Value<i32>> = [
            ("a".to_string(), Value::Vec2([1, 2])),
            ("b".to_string(), Value::Vec2([3, 4])),
        ]
        .into();
        let result = crate::eval(expr.ast(), &vars, &registry).unwrap();
        assert_eq!(result, Value::Scalar(11)); // 1*3 + 2*4 = 11

        // Test vec2 constructor
        let expr = Expr::parse("vec2(x, y)").unwrap();
        let vars: HashMap<String, Value<i32>> = [
            ("x".to_string(), Value::Scalar(5)),
            ("y".to_string(), Value::Scalar(7)),
        ]
        .into();
        let result = crate::eval(expr.ast(), &vars, &registry).unwrap();
        assert_eq!(result, Value::Vec2([5, 7]));

        // Test hadamard (element-wise multiply)
        let expr = Expr::parse("hadamard(a, b)").unwrap();
        let vars: HashMap<String, Value<i32>> = [
            ("a".to_string(), Value::Vec2([2, 3])),
            ("b".to_string(), Value::Vec2([4, 5])),
        ]
        .into();
        let result = crate::eval(expr.ast(), &vars, &registry).unwrap();
        assert_eq!(result, Value::Vec2([8, 15]));
    }
}