//! Element-wise tensor math functions (ufuncs).
//!
//! Provides both methods on `Tensor<T: Float>` and free functions mirroring
//! `NumPy`'s top-level ufuncs (`np.sin`, `np.exp`, etc.).

6use crate::Float;
7use crate::tensor::Tensor;

// ======================================================================
// Tensor methods
// ======================================================================

13impl<T: Float> Tensor<T> {
14    /// Element-wise absolute value.
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// # use scivex_core::Tensor;
20    /// let t = Tensor::from_vec(vec![-3.0_f64, -1.0, 0.0, 2.0], vec![4]).unwrap();
21    /// assert_eq!(t.abs().as_slice(), &[3.0, 1.0, 0.0, 2.0]);
22    /// ```
23    #[inline]
24    pub fn abs(&self) -> Tensor<T> {
25        self.map(Float::abs)
26    }
27
28    /// Element-wise square root.
29    ///
30    /// # Examples
31    ///
32    /// ```
33    /// # use scivex_core::Tensor;
34    /// let t = Tensor::from_vec(vec![4.0_f64, 9.0, 16.0], vec![3]).unwrap();
35    /// assert_eq!(t.sqrt().as_slice(), &[2.0, 3.0, 4.0]);
36    /// ```
37    #[inline]
38    pub fn sqrt(&self) -> Tensor<T> {
39        self.map(Float::sqrt)
40    }
41
42    /// Element-wise sine.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// # use scivex_core::Tensor;
48    /// let t = Tensor::from_vec(vec![0.0_f64, std::f64::consts::FRAC_PI_2], vec![2]).unwrap();
49    /// let s = t.sin();
50    /// assert!((s.as_slice()[0]).abs() < 1e-15);
51    /// assert!((s.as_slice()[1] - 1.0).abs() < 1e-15);
52    /// ```
53    #[inline]
54    pub fn sin(&self) -> Tensor<T> {
55        self.map(Float::sin)
56    }
57
58    /// Element-wise cosine.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// # use scivex_core::Tensor;
64    /// let t = Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap();
65    /// assert!((t.cos().as_slice()[0] - 1.0).abs() < 1e-15);
66    /// ```
67    #[inline]
68    pub fn cos(&self) -> Tensor<T> {
69        self.map(Float::cos)
70    }
71
72    /// Element-wise tangent.
73    ///
74    /// # Examples
75    ///
76    /// ```
77    /// # use scivex_core::Tensor;
78    /// let t = Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap();
79    /// assert!(t.tan().as_slice()[0].abs() < 1e-15);
80    /// ```
81    #[inline]
82    pub fn tan(&self) -> Tensor<T> {
83        self.map(Float::tan)
84    }
85
86    /// Element-wise natural exponential.
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// # use scivex_core::Tensor;
92    /// let t = Tensor::from_vec(vec![0.0_f64, 1.0], vec![2]).unwrap();
93    /// let e = t.exp();
94    /// assert!((e.as_slice()[0] - 1.0).abs() < 1e-15);
95    /// assert!((e.as_slice()[1] - std::f64::consts::E).abs() < 1e-14);
96    /// ```
97    #[inline]
98    pub fn exp(&self) -> Tensor<T> {
99        self.map(Float::exp)
100    }
101
102    /// Element-wise natural logarithm.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// # use scivex_core::Tensor;
108    /// let t = Tensor::from_vec(vec![1.0_f64, std::f64::consts::E], vec![2]).unwrap();
109    /// let l = t.ln();
110    /// assert!((l.as_slice()[0]).abs() < 1e-15);
111    /// assert!((l.as_slice()[1] - 1.0).abs() < 1e-14);
112    /// ```
113    #[inline]
114    pub fn ln(&self) -> Tensor<T> {
115        self.map(Float::ln)
116    }
117
118    /// Element-wise base-2 logarithm.
119    ///
120    /// # Examples
121    ///
122    /// ```
123    /// # use scivex_core::Tensor;
124    /// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 4.0, 8.0], vec![4]).unwrap();
125    /// assert_eq!(t.log2().as_slice(), &[0.0, 1.0, 2.0, 3.0]);
126    /// ```
127    #[inline]
128    pub fn log2(&self) -> Tensor<T> {
129        self.map(Float::log2)
130    }
131
132    /// Element-wise base-10 logarithm.
133    ///
134    /// # Examples
135    ///
136    /// ```
137    /// # use scivex_core::Tensor;
138    /// let t = Tensor::from_vec(vec![1.0_f64, 10.0, 100.0], vec![3]).unwrap();
139    /// let l = t.log10();
140    /// assert!((l.as_slice()[1] - 1.0).abs() < 1e-15);
141    /// assert!((l.as_slice()[2] - 2.0).abs() < 1e-14);
142    /// ```
143    #[inline]
144    pub fn log10(&self) -> Tensor<T> {
145        self.map(Float::log10)
146    }
147
148    /// Element-wise floor.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// # use scivex_core::Tensor;
154    /// let t = Tensor::from_vec(vec![1.3_f64, 2.7, -0.5], vec![3]).unwrap();
155    /// assert_eq!(t.floor().as_slice(), &[1.0, 2.0, -1.0]);
156    /// ```
157    #[inline]
158    pub fn floor(&self) -> Tensor<T> {
159        self.map(Float::floor)
160    }
161
162    /// Element-wise ceiling.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// # use scivex_core::Tensor;
168    /// let t = Tensor::from_vec(vec![1.3_f64, 2.7, -0.5], vec![3]).unwrap();
169    /// assert_eq!(t.ceil().as_slice(), &[2.0, 3.0, 0.0]);
170    /// ```
171    #[inline]
172    pub fn ceil(&self) -> Tensor<T> {
173        self.map(Float::ceil)
174    }
175
176    /// Element-wise rounding to nearest integer.
177    ///
178    /// # Examples
179    ///
180    /// ```
181    /// # use scivex_core::Tensor;
182    /// let t = Tensor::from_vec(vec![1.3_f64, 2.7], vec![2]).unwrap();
183    /// assert_eq!(t.round().as_slice(), &[1.0, 3.0]);
184    /// ```
185    #[inline]
186    pub fn round(&self) -> Tensor<T> {
187        self.map(Float::round)
188    }
189
190    /// Element-wise reciprocal (`1/x`).
191    ///
192    /// # Examples
193    ///
194    /// ```
195    /// # use scivex_core::Tensor;
196    /// let t = Tensor::from_vec(vec![2.0_f64, 4.0, 5.0], vec![3]).unwrap();
197    /// assert_eq!(t.recip().as_slice(), &[0.5, 0.25, 0.2]);
198    /// ```
199    #[inline]
200    pub fn recip(&self) -> Tensor<T> {
201        self.map(Float::recip)
202    }
203
204    /// Raise every element to a floating-point power.
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// # use scivex_core::Tensor;
210    /// let t = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();
211    /// let p = t.powf(3.0);
212    /// assert!((p.as_slice()[0] - 8.0).abs() < 1e-14);
213    /// assert!((p.as_slice()[1] - 27.0).abs() < 1e-14);
214    /// ```
215    #[inline]
216    pub fn powf(&self, exponent: T) -> Tensor<T> {
217        self.map(|x| x.powf(exponent))
218    }
219
220    /// Raise every element to an integer power.
221    ///
222    /// # Examples
223    ///
224    /// ```
225    /// # use scivex_core::Tensor;
226    /// let t = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();
227    /// let p = t.powi(2);
228    /// assert!((p.as_slice()[0] - 4.0).abs() < 1e-14);
229    /// assert!((p.as_slice()[1] - 9.0).abs() < 1e-14);
230    /// ```
231    #[inline]
232    pub fn powi(&self, n: i32) -> Tensor<T> {
233        self.map(|x| x.powi(n))
234    }
235
236    /// Clamp every element to `[min, max]`.
237    ///
238    /// # Examples
239    ///
240    /// ```
241    /// # use scivex_core::Tensor;
242    /// let t = Tensor::from_vec(vec![-5.0_f64, 0.5, 3.0, 10.0], vec![4]).unwrap();
243    /// assert_eq!(t.clamp(0.0, 2.0).as_slice(), &[0.0, 0.5, 2.0, 2.0]);
244    /// ```
245    #[inline]
246    pub fn clamp(&self, min: T, max: T) -> Tensor<T> {
247        self.map(|x| x.max(min).min(max))
248    }
249}

// ======================================================================
// Free functions
// ======================================================================

255/// Element-wise absolute value.
256///
257/// # Examples
258///
259/// ```
260/// # use scivex_core::{Tensor, math};
261/// let t = Tensor::from_vec(vec![-2.0_f64, 3.0], vec![2]).unwrap();
262/// assert_eq!(math::abs(&t).as_slice(), &[2.0, 3.0]);
263/// ```
264pub fn abs<T: Float>(t: &Tensor<T>) -> Tensor<T> {
265    t.abs()
266}
267
268/// Element-wise square root.
269///
270/// # Examples
271///
272/// ```
273/// # use scivex_core::{Tensor, math};
274/// let t = Tensor::from_vec(vec![4.0_f64, 9.0], vec![2]).unwrap();
275/// assert_eq!(math::sqrt(&t).as_slice(), &[2.0, 3.0]);
276/// ```
277pub fn sqrt<T: Float>(t: &Tensor<T>) -> Tensor<T> {
278    t.sqrt()
279}
280
281/// Element-wise sine.
282///
283/// # Examples
284///
285/// ```
286/// # use scivex_core::{Tensor, math};
287/// let t = Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap();
288/// assert!(math::sin(&t).as_slice()[0].abs() < 1e-15);
289/// ```
290pub fn sin<T: Float>(t: &Tensor<T>) -> Tensor<T> {
291    t.sin()
292}
293
294/// Element-wise cosine.
295///
296/// # Examples
297///
298/// ```
299/// # use scivex_core::{Tensor, math};
300/// let t = Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap();
301/// assert!((math::cos(&t).as_slice()[0] - 1.0).abs() < 1e-15);
302/// ```
303pub fn cos<T: Float>(t: &Tensor<T>) -> Tensor<T> {
304    t.cos()
305}
306
307/// Element-wise tangent.
308///
309/// # Examples
310///
311/// ```
312/// # use scivex_core::{Tensor, math};
313/// let t = Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap();
314/// assert!(math::tan(&t).as_slice()[0].abs() < 1e-15);
315/// ```
316pub fn tan<T: Float>(t: &Tensor<T>) -> Tensor<T> {
317    t.tan()
318}
319
320/// Element-wise natural exponential.
321///
322/// # Examples
323///
324/// ```
325/// # use scivex_core::{Tensor, math};
326/// let t = Tensor::from_vec(vec![0.0_f64, 1.0], vec![2]).unwrap();
327/// assert!((math::exp(&t).as_slice()[0] - 1.0).abs() < 1e-15);
328/// ```
329pub fn exp<T: Float>(t: &Tensor<T>) -> Tensor<T> {
330    t.exp()
331}
332
333/// Element-wise natural logarithm.
334///
335/// # Examples
336///
337/// ```
338/// # use scivex_core::{Tensor, math};
339/// let t = Tensor::from_vec(vec![1.0_f64, std::f64::consts::E], vec![2]).unwrap();
340/// assert!(math::ln(&t).as_slice()[0].abs() < 1e-15);
341/// ```
342pub fn ln<T: Float>(t: &Tensor<T>) -> Tensor<T> {
343    t.ln()
344}
345
346/// Element-wise base-2 logarithm.
347///
348/// # Examples
349///
350/// ```
351/// # use scivex_core::{Tensor, math};
352/// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 4.0], vec![3]).unwrap();
353/// assert_eq!(math::log2(&t).as_slice(), &[0.0, 1.0, 2.0]);
354/// ```
355pub fn log2<T: Float>(t: &Tensor<T>) -> Tensor<T> {
356    t.log2()
357}
358
359/// Element-wise base-10 logarithm.
360///
361/// # Examples
362///
363/// ```
364/// # use scivex_core::{Tensor, math};
365/// let t = Tensor::from_vec(vec![1.0_f64, 10.0, 100.0], vec![3]).unwrap();
366/// assert!((math::log10(&t).as_slice()[1] - 1.0).abs() < 1e-15);
367/// ```
368pub fn log10<T: Float>(t: &Tensor<T>) -> Tensor<T> {
369    t.log10()
370}
371
372/// Element-wise floor.
373///
374/// # Examples
375///
376/// ```
377/// # use scivex_core::{Tensor, math};
378/// let t = Tensor::from_vec(vec![1.7_f64, -0.3], vec![2]).unwrap();
379/// assert_eq!(math::floor(&t).as_slice(), &[1.0, -1.0]);
380/// ```
381pub fn floor<T: Float>(t: &Tensor<T>) -> Tensor<T> {
382    t.floor()
383}
384
385/// Element-wise ceiling.
386///
387/// # Examples
388///
389/// ```
390/// # use scivex_core::{Tensor, math};
391/// let t = Tensor::from_vec(vec![1.1_f64, -0.3], vec![2]).unwrap();
392/// assert_eq!(math::ceil(&t).as_slice(), &[2.0, 0.0]);
393/// ```
394pub fn ceil<T: Float>(t: &Tensor<T>) -> Tensor<T> {
395    t.ceil()
396}
397
398/// Element-wise rounding.
399///
400/// # Examples
401///
402/// ```
403/// # use scivex_core::{Tensor, math};
404/// let t = Tensor::from_vec(vec![1.3_f64, 2.7], vec![2]).unwrap();
405/// assert_eq!(math::round(&t).as_slice(), &[1.0, 3.0]);
406/// ```
407pub fn round<T: Float>(t: &Tensor<T>) -> Tensor<T> {
408    t.round()
409}
410
411/// Element-wise reciprocal.
412///
413/// # Examples
414///
415/// ```
416/// # use scivex_core::{Tensor, math};
417/// let t = Tensor::from_vec(vec![2.0_f64, 4.0], vec![2]).unwrap();
418/// assert_eq!(math::recip(&t).as_slice(), &[0.5, 0.25]);
419/// ```
420pub fn recip<T: Float>(t: &Tensor<T>) -> Tensor<T> {
421    t.recip()
422}
423
#[cfg(test)]
// Exact float comparisons are intentional: every value checked with
// `assert_eq!` below is exactly representable and is the result of an
// operation expected to be exact (perfect-square roots, floor/ceil, powers
// of two for log2, simple reciprocals).
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_sin_cos_known_values() {
        let t = Tensor::from_vec(vec![0.0_f64, std::f64::consts::FRAC_PI_2], vec![2]).unwrap();
        let s = t.sin();
        assert!((s.as_slice()[0] - 0.0).abs() < 1e-15);
        assert!((s.as_slice()[1] - 1.0).abs() < 1e-15);

        let c = t.cos();
        assert!((c.as_slice()[0] - 1.0).abs() < 1e-15);
        assert!(c.as_slice()[1].abs() < 1e-15);
    }

    #[test]
    fn test_exp_ln() {
        let t = Tensor::from_vec(vec![0.0_f64, 1.0], vec![2]).unwrap();
        let e = t.exp();
        assert!((e.as_slice()[0] - 1.0).abs() < 1e-15);
        assert!((e.as_slice()[1] - std::f64::consts::E).abs() < 1e-14);

        let l = e.ln();
        assert!((l.as_slice()[0] - 0.0).abs() < 1e-15);
        assert!((l.as_slice()[1] - 1.0).abs() < 1e-14);
    }

    #[test]
    fn test_sqrt() {
        let t = Tensor::from_vec(vec![0.0_f64, 1.0, 4.0, 9.0, 16.0], vec![5]).unwrap();
        let s = t.sqrt();
        assert_eq!(s.as_slice(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_abs() {
        let t = Tensor::from_vec(vec![-3.0_f64, -1.0, 0.0, 2.0, 5.0], vec![5]).unwrap();
        let a = t.abs();
        assert_eq!(a.as_slice(), &[3.0, 1.0, 0.0, 2.0, 5.0]);
    }

    #[test]
    fn test_powf_powi() {
        let t = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();
        let p = t.powf(3.0);
        assert!((p.as_slice()[0] - 8.0).abs() < 1e-14);
        assert!((p.as_slice()[1] - 27.0).abs() < 1e-14);

        let p2 = t.powi(2);
        assert!((p2.as_slice()[0] - 4.0).abs() < 1e-14);
        assert!((p2.as_slice()[1] - 9.0).abs() < 1e-14);
    }

    #[test]
    fn test_floor_ceil_round() {
        let t = Tensor::from_vec(vec![1.3_f64, 2.7, -0.5], vec![3]).unwrap();
        assert_eq!(t.floor().as_slice(), &[1.0, 2.0, -1.0]);
        assert_eq!(t.ceil().as_slice(), &[2.0, 3.0, 0.0]);

        let t2 = Tensor::from_vec(vec![1.3_f64, 2.7, 3.5], vec![3]).unwrap();
        let r = t2.round();
        assert_eq!(r.as_slice()[0], 1.0);
        assert_eq!(r.as_slice()[1], 3.0);
        assert_eq!(r.as_slice()[2], 4.0);
    }

    #[test]
    fn test_round_ties_away_from_zero() {
        // `f64::round` rounds half-way cases away from zero (NumPy's `np.round`
        // rounds them to even instead); 2.5 and -2.5 distinguish the two.
        let t = Tensor::from_vec(vec![0.5_f64, 2.5, -2.5], vec![3]).unwrap();
        assert_eq!(t.round().as_slice(), &[1.0, 3.0, -3.0]);
    }

    #[test]
    fn test_recip() {
        let t = Tensor::from_vec(vec![2.0_f64, 4.0, 5.0], vec![3]).unwrap();
        let r = t.recip();
        assert_eq!(r.as_slice(), &[0.5, 0.25, 0.2]);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::from_vec(vec![-5.0_f64, 0.5, 3.0, 10.0], vec![4]).unwrap();
        let c = t.clamp(0.0, 2.0);
        assert_eq!(c.as_slice(), &[0.0, 0.5, 2.0, 2.0]);
    }

    #[test]
    fn test_shape_preserved() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        assert_eq!(t.sin().shape(), &[2, 3]);
        assert_eq!(t.exp().shape(), &[2, 3]);
        assert_eq!(t.sqrt().shape(), &[2, 3]);
    }

    #[test]
    fn test_log2_log10() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 4.0, 8.0], vec![4]).unwrap();
        let l2 = t.log2();
        assert_eq!(l2.as_slice(), &[0.0, 1.0, 2.0, 3.0]);

        let t2 = Tensor::from_vec(vec![1.0_f64, 10.0, 100.0], vec![3]).unwrap();
        let l10 = t2.log10();
        assert!((l10.as_slice()[0] - 0.0).abs() < 1e-15);
        assert!((l10.as_slice()[1] - 1.0).abs() < 1e-15);
        assert!((l10.as_slice()[2] - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_free_functions() {
        let t = Tensor::from_vec(vec![0.0_f64, 1.0], vec![2]).unwrap();
        let s = sin(&t);
        assert!((s.as_slice()[0]).abs() < 1e-15);

        let e = exp(&t);
        assert!((e.as_slice()[0] - 1.0).abs() < 1e-15);

        assert_eq!(sqrt(&t).as_slice(), &[0.0, 1.0]);
        assert_eq!(abs(&t).as_slice(), &[0.0, 1.0]);
    }

    #[test]
    fn test_f32_works() {
        let t = Tensor::from_vec(vec![0.0_f32, 1.0, 4.0], vec![3]).unwrap();
        let s = t.sqrt();
        assert_eq!(s.as_slice(), &[0.0_f32, 1.0, 2.0]);
    }
}