Skip to main content

vita_core/tensor/
point.rs

1use core::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub, SubAssign};
2
3use crate::Scalar;
4use crate::tensor::Vector3;
5
6/// A point with three coordinates `x`, `y`, and `z`.
/// A point in three-dimensional space, stored as its `x`, `y`, and `z`
/// coordinates.
///
/// `Clone`, `Copy`, `Debug`, and `PartialEq` are derived, so each is
/// available exactly when the coordinate type `T` implements it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Point3<T> {
    /// The coordinate along the first axis.
    pub x: T,
    /// The coordinate along the second axis.
    pub y: T,
    /// The coordinate along the third axis.
    pub z: T,
}
45
46impl<T> Point3<T> {
47    /// Constructs a point from its three coordinates.
48    #[inline]
49    pub const fn new(x: T, y: T, z: T) -> Self {
50        Self { x, y, z }
51    }
52
53    /// Constructs a point from an array `[x, y, z]`.
54    #[inline]
55    pub fn from_array(array: [T; 3]) -> Self {
56        let [x, y, z] = array;
57        Self { x, y, z }
58    }
59
60    /// Returns the coordinates as an array `[x, y, z]`.
61    #[inline]
62    pub fn to_array(self) -> [T; 3] {
63        [self.x, self.y, self.z]
64    }
65
66    /// Returns a copy of `self` with the `x` coordinate replaced.
67    #[inline]
68    pub fn with_x(self, x: T) -> Self {
69        Self {
70            x,
71            y: self.y,
72            z: self.z,
73        }
74    }
75
76    /// Returns a copy of `self` with the `y` coordinate replaced.
77    #[inline]
78    pub fn with_y(self, y: T) -> Self {
79        Self {
80            x: self.x,
81            y,
82            z: self.z,
83        }
84    }
85
86    /// Returns a copy of `self` with the `z` coordinate replaced.
87    #[inline]
88    pub fn with_z(self, z: T) -> Self {
89        Self {
90            x: self.x,
91            y: self.y,
92            z,
93        }
94    }
95
96    /// Reinterprets a displacement from the origin as a point.
97    #[inline]
98    pub fn from_vector(vector: Vector3<T>) -> Self {
99        Self {
100            x: vector.x,
101            y: vector.y,
102            z: vector.z,
103        }
104    }
105
106    /// Returns the position vector of `self` relative to the origin.
107    #[inline]
108    pub fn to_vector(self) -> Vector3<T> {
109        Vector3::new(self.x, self.y, self.z)
110    }
111
112    /// Applies `f` to every coordinate, returning the resulting point.
113    #[inline]
114    pub fn map<U, F: FnMut(T) -> U>(self, mut f: F) -> Point3<U> {
115        Point3 {
116            x: f(self.x),
117            y: f(self.y),
118            z: f(self.z),
119        }
120    }
121
122    /// Combines `self` and `rhs` coordinate-wise through `f`.
123    #[inline]
124    pub fn zip_map<U, R, F: FnMut(T, U) -> R>(self, rhs: Point3<U>, mut f: F) -> Point3<R> {
125        Point3 {
126            x: f(self.x, rhs.x),
127            y: f(self.y, rhs.y),
128            z: f(self.z, rhs.z),
129        }
130    }
131}
132
133impl<T: Copy> Point3<T> {
134    /// Constructs a point with all three coordinates set to `value`.
135    #[inline]
136    pub const fn splat(value: T) -> Self {
137        Self {
138            x: value,
139            y: value,
140            z: value,
141        }
142    }
143
144    /// Constructs a point from the first three elements of a slice.
145    ///
146    /// # Panics
147    ///
148    /// Panics if `slice` has fewer than three elements.
149    #[inline]
150    pub fn from_slice(slice: &[T]) -> Self {
151        Self {
152            x: slice[0],
153            y: slice[1],
154            z: slice[2],
155        }
156    }
157}
158
159impl<T> Index<usize> for Point3<T> {
160    type Output = T;
161
162    /// Returns the coordinate at `index`, where `0`, `1`, and `2` map to `x`,
163    /// `y`, and `z`.
164    ///
165    /// # Panics
166    ///
167    /// Panics if `index` is greater than `2`.
168    #[inline]
169    fn index(&self, index: usize) -> &T {
170        match index {
171            0 => &self.x,
172            1 => &self.y,
173            2 => &self.z,
174            _ => panic!("index out of bounds: Point3 has 3 coordinates but the index is {index}"),
175        }
176    }
177}
178
179impl<T> IndexMut<usize> for Point3<T> {
180    /// Returns the coordinate at `index`, where `0`, `1`, and `2` map to `x`,
181    /// `y`, and `z`.
182    ///
183    /// # Panics
184    ///
185    /// Panics if `index` is greater than `2`.
186    #[inline]
187    fn index_mut(&mut self, index: usize) -> &mut T {
188        match index {
189            0 => &mut self.x,
190            1 => &mut self.y,
191            2 => &mut self.z,
192            _ => panic!("index out of bounds: Point3 has 3 coordinates but the index is {index}"),
193        }
194    }
195}
196
197impl<T: Default> Default for Point3<T> {
198    /// Returns the origin.
199    #[inline]
200    fn default() -> Self {
201        Self {
202            x: T::default(),
203            y: T::default(),
204            z: T::default(),
205        }
206    }
207}
208
209impl<T: Add<Output = T>> Add<Vector3<T>> for Point3<T> {
210    type Output = Self;
211    /// Translates the point by the displacement `rhs`.
212    #[inline]
213    fn add(self, rhs: Vector3<T>) -> Self {
214        Self::new(self.x + rhs.x, self.y + rhs.y, self.z + rhs.z)
215    }
216}
217
218impl<T: AddAssign> AddAssign<Vector3<T>> for Point3<T> {
219    #[inline]
220    fn add_assign(&mut self, rhs: Vector3<T>) {
221        self.x += rhs.x;
222        self.y += rhs.y;
223        self.z += rhs.z;
224    }
225}
226
227impl<T: Sub<Output = T>> Sub<Vector3<T>> for Point3<T> {
228    type Output = Self;
229    /// Translates the point by the negated displacement `rhs`.
230    #[inline]
231    fn sub(self, rhs: Vector3<T>) -> Self {
232        Self::new(self.x - rhs.x, self.y - rhs.y, self.z - rhs.z)
233    }
234}
235
236impl<T: SubAssign> SubAssign<Vector3<T>> for Point3<T> {
237    #[inline]
238    fn sub_assign(&mut self, rhs: Vector3<T>) {
239        self.x -= rhs.x;
240        self.y -= rhs.y;
241        self.z -= rhs.z;
242    }
243}
244
245impl<T: Sub<Output = T>> Sub for Point3<T> {
246    type Output = Vector3<T>;
247    /// Returns the displacement from `rhs` to `self`.
248    #[inline]
249    fn sub(self, rhs: Self) -> Vector3<T> {
250        Vector3::new(self.x - rhs.x, self.y - rhs.y, self.z - rhs.z)
251    }
252}
253
254impl<T: Add<Output = T> + Sub<Output = T> + Copy> Point3<T> {
255    /// Linearly interpolates from `self` toward `rhs` by the factor `t`.
256    ///
257    /// `t == 0` yields `self`, `t == 1` yields `rhs`.
258    #[inline]
259    pub fn lerp<S: Scalar>(self, rhs: Self, t: S) -> Self
260    where
261        T: Mul<S, Output = T>,
262    {
263        self + (rhs - self) * t
264    }
265}
266
267impl<V: Scalar> Point3<V> {
268    /// The origin, `(0, 0, 0)`.
269    pub const ORIGIN: Self = Self::new(V::ZERO, V::ZERO, V::ZERO);
270
271    /// Returns the squared Euclidean distance between `self` and `rhs`.
272    ///
273    /// Cheaper than [`distance`][Self::distance] and sufficient whenever only
274    /// relative distances are compared.
275    #[inline]
276    pub fn distance_squared(self, rhs: Self) -> V {
277        (self - rhs).norm_squared()
278    }
279
280    /// Returns the Euclidean distance between `self` and `rhs`.
281    #[inline]
282    pub fn distance(self, rhs: Self) -> V {
283        (self - rhs).norm()
284    }
285
286    /// Returns the midpoint of the segment between `self` and `rhs`.
287    #[inline]
288    pub fn midpoint(self, rhs: Self) -> Self {
289        self + (rhs - self) * V::from_f64(0.5)
290    }
291
292    /// Returns the centroid (arithmetic mean) of `points`, or the
293    /// [`ORIGIN`][Self::ORIGIN] when `points` is empty.
294    #[inline]
295    pub fn centroid(points: &[Self]) -> Self {
296        if points.is_empty() {
297            return Self::ORIGIN;
298        }
299        let mut sum = Vector3::<V>::ZERO;
300        for p in points {
301            sum += p.to_vector();
302        }
303        Self::from_vector(sum / V::from_f64(points.len() as f64))
304    }
305
306    /// Returns the component-wise minimum of `self` and `rhs`.
307    #[inline]
308    pub fn min(self, rhs: Self) -> Self {
309        Self::new(self.x.min(rhs.x), self.y.min(rhs.y), self.z.min(rhs.z))
310    }
311
312    /// Returns the component-wise maximum of `self` and `rhs`.
313    #[inline]
314    pub fn max(self, rhs: Self) -> Self {
315        Self::new(self.x.max(rhs.x), self.y.max(rhs.y), self.z.max(rhs.z))
316    }
317
318    /// Restricts every coordinate to the interval `[min, max]`, confining the
319    /// point to an axis-aligned bounding box.
320    ///
321    /// # Panics
322    ///
323    /// Panics if any coordinate of `min` exceeds the corresponding coordinate
324    /// of `max`.
325    #[inline]
326    pub fn clamp(self, min: Self, max: Self) -> Self {
327        Self::new(
328            self.x.clamp(min.x, max.x),
329            self.y.clamp(min.y, max.y),
330            self.z.clamp(min.z, max.z),
331        )
332    }
333
334    /// Returns the component-wise floor (largest integer not exceeding each
335    /// coordinate).
336    #[inline]
337    pub fn floor(self) -> Self {
338        Self::new(self.x.floor(), self.y.floor(), self.z.floor())
339    }
340
341    /// Returns the component-wise ceiling (smallest integer not less than each
342    /// coordinate).
343    #[inline]
344    pub fn ceil(self) -> Self {
345        Self::new(self.x.ceil(), self.y.ceil(), self.z.ceil())
346    }
347
348    /// Returns the component-wise nearest integer, rounding halves away from
349    /// zero.
350    #[inline]
351    pub fn round(self) -> Self {
352        Self::new(self.x.round(), self.y.round(), self.z.round())
353    }
354
355    /// Returns the component-wise nearest integer, rounding halves to even.
356    #[inline]
357    pub fn round_ties_even(self) -> Self {
358        Self::new(
359            self.x.round_ties_even(),
360            self.y.round_ties_even(),
361            self.z.round_ties_even(),
362        )
363    }
364
365    /// Returns the component-wise truncation toward zero.
366    #[inline]
367    pub fn trunc(self) -> Self {
368        Self::new(self.x.trunc(), self.y.trunc(), self.z.trunc())
369    }
370
371    /// Returns the component-wise fractional part, i.e. the position of
372    /// `self` within the unit-grid cell it occupies.
373    #[inline]
374    pub fn fract(self) -> Self {
375        Self::new(self.x.fract(), self.y.fract(), self.z.fract())
376    }
377
378    /// Returns `true` if every coordinate is finite.
379    #[inline]
380    pub fn is_finite(self) -> bool {
381        self.x.is_finite() && self.y.is_finite() && self.z.is_finite()
382    }
383
384    /// Returns `true` if any coordinate is positive or negative infinity.
385    #[inline]
386    pub fn is_infinite(self) -> bool {
387        self.x.is_infinite() || self.y.is_infinite() || self.z.is_infinite()
388    }
389
390    /// Returns `true` if any coordinate is `NaN`.
391    #[inline]
392    pub fn is_nan(self) -> bool {
393        self.x.is_nan() || self.y.is_nan() || self.z.is_nan()
394    }
395}
396
// Unit tests for `Point3`: constructors, accessors, trait impls, affine
// arithmetic with `Vector3`, and the `Scalar`-only geometry helpers. Tests
// with integer, `char`, and `String` coordinates exercise the generic
// (non-`Scalar`) impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        let p = Point3::new(1.0, 2.0, 3.0);
        assert_eq!((p.x, p.y, p.z), (1.0, 2.0, 3.0));
    }

    #[test]
    fn from_array() {
        assert_eq!(
            Point3::from_array([1.0, 2.0, 3.0]),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn to_array() {
        assert_eq!(Point3::new(1.0, 2.0, 3.0).to_array(), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn splat() {
        assert_eq!(Point3::splat(5.0), Point3::new(5.0, 5.0, 5.0));
    }

    #[test]
    fn from_slice() {
        assert_eq!(
            Point3::from_slice(&[1.0, 2.0, 3.0, 4.0]),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    #[should_panic]
    fn from_slice_panics_when_too_short() {
        Point3::<f64>::from_slice(&[1.0, 2.0]);
    }

    #[test]
    fn with_x() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).with_x(9.0),
            Point3::new(9.0, 2.0, 3.0)
        );
    }

    #[test]
    fn with_y() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).with_y(9.0),
            Point3::new(1.0, 9.0, 3.0)
        );
    }

    #[test]
    fn with_z() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).with_z(9.0),
            Point3::new(1.0, 2.0, 9.0)
        );
    }

    #[test]
    fn with_replaces_non_copy_coordinate() {
        let p = Point3::new(String::from("a"), String::from("b"), String::from("c"));
        assert_eq!(
            p.with_y(String::from("z")),
            Point3::new(String::from("a"), String::from("z"), String::from("c"))
        );
    }

    #[test]
    fn from_vector() {
        assert_eq!(
            Point3::from_vector(Vector3::new(1.0, 2.0, 3.0)),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn to_vector() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).to_vector(),
            Vector3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn map() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).map(|c| c * 2.0),
            Point3::new(2.0, 4.0, 6.0)
        );
    }

    #[test]
    fn map_changes_coordinate_type() {
        assert_eq!(
            Point3::new(1, 2, 3).map(|c| f64::from(c) * 0.5),
            Point3::new(0.5, 1.0, 1.5)
        );
    }

    #[test]
    fn map_visits_coordinates_in_order() {
        let mut calls = 0;
        let p = Point3::new('a', 'b', 'c').map(|c| {
            calls += 1;
            (calls, c)
        });
        assert_eq!(p, Point3::new((1, 'a'), (2, 'b'), (3, 'c')));
    }

    #[test]
    fn zip_map() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0).zip_map(Point3::new(4.0, 5.0, 6.0), |a, b| a + b),
            Point3::new(5.0, 7.0, 9.0)
        );
    }

    #[test]
    fn default_is_origin() {
        assert_eq!(Point3::<f64>::default(), Point3::new(0.0, 0.0, 0.0));
    }

    #[test]
    fn integer_coordinates() {
        let p = Point3::new(1, 2, 3);
        assert_eq!(p.map(|c| c * 10), Point3::new(10, 20, 30));
        assert_eq!((p[0], p[1], p[2]), (1, 2, 3));
        assert_eq!(Point3::from_slice(&[4, 5, 6]), Point3::new(4, 5, 6));
        assert_eq!(Point3::splat(7).to_array(), [7, 7, 7]);
        assert_eq!(Point3::<i32>::default(), Point3::new(0, 0, 0));
    }

    #[test]
    fn copy_and_clone() {
        let a = Point3::new(1.0, 2.0, 3.0);
        let b = a;
        let c = ::core::clone::Clone::clone(&a);
        assert_eq!(a, b);
        assert_eq!(a, c);
    }

    #[test]
    fn eq() {
        let a = Point3::new(1.0, 2.0, 3.0);
        assert_eq!(a, Point3::new(1.0, 2.0, 3.0));
        assert_ne!(a, Point3::new(1.0, 2.0, 4.0));
    }

    #[test]
    fn debug() {
        assert_eq!(
            format!("{:?}", Point3::new(1.0, 2.0, 3.0)),
            "Point3 { x: 1.0, y: 2.0, z: 3.0 }"
        );
    }

    #[test]
    fn index() {
        let p = Point3::new(1.0, 2.0, 3.0);
        assert_eq!((p[0], p[1], p[2]), (1.0, 2.0, 3.0));
    }

    #[test]
    fn index_mut() {
        let mut p = Point3::new(1.0, 2.0, 3.0);
        p[1] = 9.0;
        assert_eq!(p.y, 9.0);
    }

    #[test]
    #[should_panic]
    fn index_panics_when_out_of_bounds() {
        let _ = Point3::new(1.0, 2.0, 3.0)[3];
    }

    #[test]
    #[should_panic]
    fn index_mut_panics_when_out_of_bounds() {
        Point3::new(1.0, 2.0, 3.0)[3] = 0.0;
    }

    #[test]
    fn origin_constant() {
        assert_eq!(Point3::<f64>::ORIGIN, Point3::new(0.0, 0.0, 0.0));
    }

    #[test]
    fn add_vector() {
        assert_eq!(
            Point3::new(1.0, 2.0, 3.0) + Vector3::new(1.0, 1.0, 1.0),
            Point3::new(2.0, 3.0, 4.0)
        );
    }

    #[test]
    fn add_assign_vector() {
        let mut p = Point3::new(1.0, 2.0, 3.0);
        p += Vector3::new(1.0, 1.0, 1.0);
        assert_eq!(p, Point3::new(2.0, 3.0, 4.0));
    }

    #[test]
    fn sub_vector() {
        assert_eq!(
            Point3::new(2.0, 3.0, 4.0) - Vector3::new(1.0, 1.0, 1.0),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn sub_assign_vector() {
        let mut p = Point3::new(2.0, 3.0, 4.0);
        p -= Vector3::new(1.0, 1.0, 1.0);
        assert_eq!(p, Point3::new(1.0, 2.0, 3.0));
    }

    #[test]
    fn sub_point_yields_vector() {
        assert_eq!(
            Point3::new(3.0, 3.0, 3.0) - Point3::new(1.0, 1.0, 1.0),
            Vector3::new(2.0, 2.0, 2.0)
        );
    }

    #[test]
    fn lerp() {
        assert_eq!(
            Point3::new(0.0, 0.0, 0.0).lerp(Point3::new(2.0, 4.0, 6.0), 0.5),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn lerp_endpoints() {
        // Values chosen so every intermediate is exactly representable.
        let a = Point3::new(1.0, 2.0, 3.0);
        let b = Point3::new(5.0, -2.0, 7.0);
        assert_eq!(a.lerp(b, 0.0), a);
        assert_eq!(a.lerp(b, 1.0), b);
    }

    #[test]
    fn distance_squared() {
        assert_eq!(
            Point3::new(0.0, 0.0, 0.0).distance_squared(Point3::new(3.0, 4.0, 0.0)),
            25.0
        );
    }

    #[test]
    fn distance() {
        assert_eq!(
            Point3::new(0.0, 0.0, 0.0).distance(Point3::new(3.0, 4.0, 0.0)),
            5.0
        );
    }

    #[test]
    fn midpoint() {
        assert_eq!(
            Point3::new(0.0, 0.0, 0.0).midpoint(Point3::new(2.0, 4.0, 6.0)),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn centroid() {
        let points = [
            Point3::new(0.0, 0.0, 0.0),
            Point3::new(2.0, 0.0, 0.0),
            Point3::new(1.0, 3.0, 0.0),
        ];
        assert_eq!(Point3::centroid(&points), Point3::new(1.0, 1.0, 0.0));
    }

    #[test]
    fn centroid_empty_is_origin() {
        assert_eq!(Point3::<f64>::centroid(&[]), Point3::ORIGIN);
    }

    #[test]
    fn centroid_of_single_point_is_that_point() {
        let p = Point3::new(1.5, -2.0, 4.0);
        assert_eq!(Point3::centroid(&[p]), p);
    }

    #[test]
    fn min() {
        assert_eq!(
            Point3::new(1.0, 5.0, 3.0).min(Point3::new(4.0, 2.0, 6.0)),
            Point3::new(1.0, 2.0, 3.0)
        );
    }

    #[test]
    fn max() {
        assert_eq!(
            Point3::new(1.0, 5.0, 3.0).max(Point3::new(4.0, 2.0, 6.0)),
            Point3::new(4.0, 5.0, 6.0)
        );
    }

    #[test]
    fn clamp() {
        assert_eq!(
            Point3::new(5.0, -1.0, 2.0).clamp(Point3::splat(0.0), Point3::splat(3.0)),
            Point3::new(3.0, 0.0, 2.0)
        );
    }

    #[test]
    #[should_panic]
    fn clamp_panics_when_min_gt_max() {
        Point3::new(1.0, 1.0, 1.0).clamp(Point3::splat(3.0), Point3::splat(0.0));
    }

    #[test]
    fn floor() {
        assert_eq!(
            Point3::new(1.7, -1.2, 2.0).floor(),
            Point3::new(1.0, -2.0, 2.0)
        );
    }

    #[test]
    fn ceil() {
        assert_eq!(
            Point3::new(1.2, -1.7, 2.0).ceil(),
            Point3::new(2.0, -1.0, 2.0)
        );
    }

    #[test]
    fn round() {
        assert_eq!(
            Point3::new(1.5, -1.5, 2.4).round(),
            Point3::new(2.0, -2.0, 2.0)
        );
    }

    #[test]
    fn round_ties_even() {
        assert_eq!(
            Point3::new(1.5, 2.5, -1.5).round_ties_even(),
            Point3::new(2.0, 2.0, -2.0)
        );
    }

    #[test]
    fn trunc() {
        assert_eq!(
            Point3::new(1.7, -1.7, 2.0).trunc(),
            Point3::new(1.0, -1.0, 2.0)
        );
    }

    #[test]
    fn fract() {
        assert_eq!(
            Point3::new(2.5, -1.25, 0.0).fract(),
            Point3::new(0.5, -0.25, 0.0)
        );
    }

    #[test]
    fn fract_keeps_sign_unlike_cell_offset() {
        let p = Point3::new(-1.25, -0.5, 0.75);
        // `fract` keeps the sign of negative coordinates...
        assert_eq!(p.fract(), Point3::new(-0.25, -0.5, 0.75));
        // ...whereas the offset within the unit-grid cell is non-negative.
        assert_eq!(p - p.floor(), Vector3::new(0.75, 0.5, 0.75));
    }

    #[test]
    fn is_finite() {
        assert!(Point3::new(1.0, 2.0, 3.0).is_finite());
        assert!(!Point3::new(1.0, f64::INFINITY, 3.0).is_finite());
    }

    #[test]
    fn is_infinite() {
        assert!(Point3::new(1.0, f64::INFINITY, 3.0).is_infinite());
        assert!(!Point3::new(1.0, 2.0, 3.0).is_infinite());
    }

    #[test]
    fn is_nan() {
        assert!(Point3::new(1.0, f64::NAN, 3.0).is_nan());
        assert!(!Point3::new(1.0, 2.0, 3.0).is_nan());
    }

    #[test]
    fn f32_distance() {
        assert_eq!(
            Point3::<f32>::new(0.0, 0.0, 0.0).distance(Point3::new(3.0, 4.0, 0.0)),
            5.0
        );
    }
}