Skip to main content

spintronics/
vector3.rs

//! Simple 3D vector implementation for spintronics
//!
//! This is a temporary internal implementation until scirs2-linalg is available.
4
5use std::fmt;
6use std::ops::{Add, Mul, Sub};
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
/// Three-component Cartesian vector, generic over the component type `T`.
///
/// All fields are public, so a value can be built with a struct literal or
/// taken apart by destructuring. `Serialize`/`Deserialize` are derived only
/// when the `serde` feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Vector3<T> {
    /// Component along the x-axis
    pub x: T,
    /// Component along the y-axis
    pub y: T,
    /// Component along the z-axis
    pub z: T,
}
22
23impl<T> Vector3<T> {
24    /// Create a new 3D vector
25    pub const fn new(x: T, y: T, z: T) -> Self {
26        Self { x, y, z }
27    }
28}
29
30impl Vector3<f64> {
31    /// Create a zero vector
32    pub const fn zero() -> Self {
33        Self {
34            x: 0.0,
35            y: 0.0,
36            z: 0.0,
37        }
38    }
39
40    /// Create a unit vector along the x-axis
41    pub const fn unit_x() -> Self {
42        Self {
43            x: 1.0,
44            y: 0.0,
45            z: 0.0,
46        }
47    }
48
49    /// Create a unit vector along the y-axis
50    pub const fn unit_y() -> Self {
51        Self {
52            x: 0.0,
53            y: 1.0,
54            z: 0.0,
55        }
56    }
57
58    /// Create a unit vector along the z-axis
59    pub const fn unit_z() -> Self {
60        Self {
61            x: 0.0,
62            y: 0.0,
63            z: 1.0,
64        }
65    }
66
67    /// Calculate the dot product with another vector
68    #[inline]
69    pub fn dot(&self, other: &Self) -> f64 {
70        self.x * other.x + self.y * other.y + self.z * other.z
71    }
72
73    /// Calculate the cross product with another vector
74    #[inline]
75    pub fn cross(&self, other: &Self) -> Self {
76        Self {
77            x: self.y * other.z - self.z * other.y,
78            y: self.z * other.x - self.x * other.z,
79            z: self.x * other.y - self.y * other.x,
80        }
81    }
82
83    /// Calculate the magnitude (length) of the vector
84    #[inline]
85    pub fn magnitude(&self) -> f64 {
86        (self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
87    }
88
89    /// Calculate the squared magnitude (avoids sqrt for performance)
90    #[inline]
91    pub fn magnitude_squared(&self) -> f64 {
92        self.x * self.x + self.y * self.y + self.z * self.z
93    }
94
95    /// Return a normalized (unit) vector in the same direction
96    #[inline]
97    pub fn normalize(&self) -> Self {
98        let mag = self.magnitude();
99        if mag > 0.0 {
100            Self {
101                x: self.x / mag,
102                y: self.y / mag,
103                z: self.z / mag,
104            }
105        } else {
106            *self
107        }
108    }
109
110    /// Check if the vector is normalized (unit length)
111    ///
112    /// Uses a tolerance of 1e-10 for floating point comparison
113    #[inline]
114    pub fn is_normalized(&self) -> bool {
115        (self.magnitude_squared() - 1.0).abs() < 1e-10
116    }
117
118    /// Calculate angle between two vectors in radians
119    ///
120    /// Returns angle in range \[0, π\]
121    #[inline]
122    pub fn angle_between(&self, other: &Self) -> f64 {
123        let dot = self.dot(other);
124        let mags = self.magnitude() * other.magnitude();
125        if mags > 0.0 {
126            (dot / mags).clamp(-1.0, 1.0).acos()
127        } else {
128            0.0
129        }
130    }
131
132    /// Project this vector onto another vector
133    ///
134    /// Returns the component of self in the direction of other
135    #[inline]
136    pub fn project(&self, other: &Self) -> Self {
137        let other_mag_sq = other.magnitude_squared();
138        if other_mag_sq > 0.0 {
139            *other * (self.dot(other) / other_mag_sq)
140        } else {
141            Self::zero()
142        }
143    }
144}
145
146impl Add for Vector3<f64> {
147    type Output = Self;
148
149    fn add(self, other: Self) -> Self {
150        Self {
151            x: self.x + other.x,
152            y: self.y + other.y,
153            z: self.z + other.z,
154        }
155    }
156}
157
158impl Sub for Vector3<f64> {
159    type Output = Self;
160
161    fn sub(self, other: Self) -> Self {
162        Self {
163            x: self.x - other.x,
164            y: self.y - other.y,
165            z: self.z - other.z,
166        }
167    }
168}
169
170impl Mul<f64> for Vector3<f64> {
171    type Output = Self;
172
173    fn mul(self, scalar: f64) -> Self {
174        Self {
175            x: self.x * scalar,
176            y: self.y * scalar,
177            z: self.z * scalar,
178        }
179    }
180}
181
182impl fmt::Display for Vector3<f64> {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "({:.6}, {:.6}, {:.6})", self.x, self.y, self.z)
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Absolute tolerance shared by every approximate comparison below.
    const TOL: f64 = 1e-10;

    /// Assert that `actual` lies strictly within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_cross_product() {
        // x-axis crossed with y-axis must give the z-axis.
        let ex = Vector3::new(1.0, 0.0, 0.0);
        let ey = Vector3::new(0.0, 1.0, 0.0);
        let Vector3 { x, y, z } = ex.cross(&ey);
        assert_close(x, 0.0);
        assert_close(y, 0.0);
        assert_close(z, 1.0);
    }

    #[test]
    fn test_normalize() {
        let unit = Vector3::new(3.0, 4.0, 0.0).normalize();
        assert_close(unit.magnitude(), 1.0);
    }

    #[test]
    fn test_zero() {
        let Vector3 { x, y, z } = Vector3::zero();
        assert_eq!(x, 0.0);
        assert_eq!(y, 0.0);
        assert_eq!(z, 0.0);
    }

    #[test]
    fn test_unit_vectors() {
        let axes = [
            (Vector3::unit_x(), [1.0, 0.0, 0.0]),
            (Vector3::unit_y(), [0.0, 1.0, 0.0]),
            (Vector3::unit_z(), [0.0, 0.0, 1.0]),
        ];

        for (axis, [ex, ey, ez]) in axes {
            assert_eq!(axis.x, ex);
            assert_eq!(axis.y, ey);
            assert_eq!(axis.z, ez);

            // Every basis vector has unit length.
            assert!(axis.is_normalized());
        }
    }

    #[test]
    fn test_magnitude_squared() {
        // Classic 3-4-5 triangle.
        let v = Vector3::new(3.0, 4.0, 0.0);
        assert_eq!(v.magnitude_squared(), 25.0);
        assert_close(v.magnitude(), 5.0);
    }

    #[test]
    fn test_is_normalized() {
        let already_unit = Vector3::new(1.0, 0.0, 0.0);
        assert!(already_unit.is_normalized());

        let long = Vector3::new(3.0, 4.0, 0.0);
        assert!(!long.is_normalized());

        let rescaled = long.normalize();
        assert!(rescaled.is_normalized());
    }

    #[test]
    fn test_angle_between() {
        let ex = Vector3::new(1.0, 0.0, 0.0);

        // Perpendicular vectors: 90 degrees (π/2 radians)
        let ey = Vector3::new(0.0, 1.0, 0.0);
        assert_close(ex.angle_between(&ey), FRAC_PI_2);

        // Parallel vectors: zero angle regardless of length
        let stretched = Vector3::new(2.0, 0.0, 0.0);
        assert_close(ex.angle_between(&stretched), 0.0);

        // Anti-parallel vectors: π radians
        let reversed = Vector3::new(-1.0, 0.0, 0.0);
        assert_close(ex.angle_between(&reversed), PI);
    }

    #[test]
    fn test_project() {
        let v = Vector3::new(3.0, 4.0, 0.0);

        // Projection onto the x-axis keeps only the x component: (3, 0, 0)
        let onto_x = v.project(&Vector3::new(1.0, 0.0, 0.0));
        assert_close(onto_x.x, 3.0);
        assert_close(onto_x.y, 0.0);
        assert_close(onto_x.z, 0.0);

        // Projection onto the xy diagonal should be (3.5, 3.5, 0)
        let onto_diag = v.project(&Vector3::new(1.0, 1.0, 0.0));
        assert_close(onto_diag.x, 3.5);
        assert_close(onto_diag.y, 3.5);
        assert_close(onto_diag.z, 0.0);
    }
}