Skip to main content

svod_tensor/
math.rs

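//! Mathematical operations for tensors.
//!
//! This module provides:
//! - Trigonometric functions: sin, cos, tan
//! - Inverse trigonometric functions: asin, acos, atan (polynomial approximation)
//! - Hyperbolic and inverse hyperbolic functions: sinh, cosh, asinh, acosh, atanh
//! - Rounding functions: floor, ceil, round, trunc
//! - Advanced math: erf (error function), reciprocal, square, sign, lerp
//! - NaN / infinity detection: isnan, isinf
//! - Thresholding: shrink
//! - Linear algebra: det (determinant via LU decomposition with partial pivoting)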
7
8use snafu::ResultExt;
9use svod_ir::ConstValue;
10
11use super::*;
12
13/// Horner's method for polynomial evaluation: `coeffs[0]*x^(n-1) + ... + coeffs[n-1]`.
14fn poly_n(x: &Tensor, coefficients: &[f64]) -> Result<Tensor> {
15    let mut acc = x.broadcast_scalar(ConstValue::Float(coefficients[0]))?;
16    for &c in &coefficients[1..] {
17        let c_t = x.broadcast_scalar(ConstValue::Float(c))?;
18        acc = acc.try_mul(x)?.try_add(&c_t)?;
19    }
20    Ok(acc)
21}
22
23impl Tensor {
24    // =========================================================================
25    // Trigonometric Functions
26    // =========================================================================
27
28    /// Sine function: sin(x).
29    ///
30    /// Computes the sine of each element. Requires float dtype.
31    ///
32    /// # Examples
33    /// ```ignore
34    /// use std::f32::consts::PI;
35    /// let t = Tensor::from_slice(&[0.0f32, PI/2.0, PI]);
36    /// let result = t.sin()?;  // [0, 1, 0]
37    /// ```
38    ///
39    /// # Errors
40    /// Returns error if dtype is not float.
41    #[track_caller]
42    pub fn sin(&self) -> Result<Tensor> {
43        self.uop().try_sin().map(Self::new).context(UOpSnafu)
44    }
45
46    /// Cosine function: cos(x).
47    ///
48    /// Computes the cosine of each element. Requires float dtype.
49    ///
50    /// # Examples
51    /// ```ignore
52    /// use std::f32::consts::PI;
53    /// let t = Tensor::from_slice(&[0.0f32, PI/2.0, PI]);
54    /// let result = t.cos()?;  // [1, 0, -1]
55    /// ```
56    ///
57    /// # Errors
58    /// Returns error if dtype is not float.
59    #[track_caller]
60    pub fn cos(&self) -> Result<Tensor> {
61        self.uop().try_cos().map(Self::new).context(UOpSnafu)
62    }
63
64    /// Tangent function: tan(x).
65    ///
66    /// Computes the tangent of each element. Requires float dtype.
67    ///
68    /// # Examples
69    /// ```ignore
70    /// use std::f32::consts::PI;
71    /// let t = Tensor::from_slice(&[0.0f32, PI/4.0]);
72    /// let result = t.tan()?;  // [0, 1]
73    /// ```
74    ///
75    /// # Errors
76    /// Returns error if dtype is not float.
77    #[track_caller]
78    pub fn tan(&self) -> Result<Tensor> {
79        self.uop().try_tan().map(Self::new).context(UOpSnafu)
80    }
81
82    // =========================================================================
83    // Rounding Functions
84    // =========================================================================
85
86    /// Floor function: round towards -∞.
87    ///
88    /// Returns the largest integer less than or equal to each element.
89    /// For integer dtypes, returns the tensor unchanged.
90    ///
91    /// # Examples
92    /// ```ignore
93    /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
94    /// let result = t.floor()?;  // [1.0, -2.0, 2.0, -3.0]
95    /// ```
96    #[track_caller]
97    pub fn floor(&self) -> Result<Tensor> {
98        Ok(Self::new(UOp::floor(self.uop())))
99    }
100
101    /// Ceiling function: round towards +∞.
102    ///
103    /// Returns the smallest integer greater than or equal to each element.
104    /// For integer dtypes, returns the tensor unchanged.
105    ///
106    /// # Examples
107    /// ```ignore
108    /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
109    /// let result = t.ceil()?;  // [2.0, -1.0, 3.0, -2.0]
110    /// ```
111    #[track_caller]
112    pub fn ceil(&self) -> Result<Tensor> {
113        Ok(Self::new(UOp::ceil(self.uop())))
114    }
115
116    /// Round function: round to nearest integer (half to even).
117    ///
118    /// Rounds each element to the nearest integer. Ties are rounded to the nearest even number.
119    /// For integer dtypes, returns the tensor unchanged.
120    ///
121    /// # Examples
122    /// ```ignore
123    /// let t = Tensor::from_slice(&[1.2f32, 1.5, 2.5, -1.5]);
124    /// let result = t.round()?;  // [1.0, 2.0, 2.0, -2.0]
125    /// ```
126    #[track_caller]
127    pub fn round(&self) -> Result<Tensor> {
128        Ok(Self::new(UOp::round(self.uop())))
129    }
130
131    /// Truncate function: round towards zero.
132    ///
133    /// Removes the fractional part, rounding towards zero.
134    /// For integer dtypes, returns the tensor unchanged.
135    ///
136    /// # Examples
137    /// ```ignore
138    /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
139    /// let result = t.trunc()?;  // [1.0, -1.0, 2.0, -2.0]
140    /// ```
141    #[track_caller]
142    pub fn trunc(&self) -> Result<Tensor> {
143        Ok(Self::new(UOp::trunc(self.uop())))
144    }
145
146    // =========================================================================
147    // Advanced Math Functions
148    // =========================================================================
149
150    /// Error function: erf(x).
151    ///
152    /// Computes the error function (Gauss error function) of each element.
153    /// Requires float dtype. Critical for GELU activation.
154    ///
155    /// # Examples
156    /// ```ignore
157    /// let t = Tensor::from_slice(&[-1.0f32, 0.0, 1.0]);
158    /// let result = t.erf()?;  // [-0.8427, 0, 0.8427]
159    /// ```
160    ///
161    /// # Errors
162    /// Returns error if dtype is not float.
163    #[track_caller]
164    pub fn erf(&self) -> Result<Tensor> {
165        self.uop().erf().map(Self::new).context(UOpSnafu)
166    }
167
168    /// Reciprocal: 1/x.
169    ///
170    /// Computes the reciprocal of each element.
171    ///
172    /// # Examples
173    /// ```ignore
174    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 4.0]);
175    /// let result = t.reciprocal()?;  // [1.0, 0.5, 0.25]
176    /// ```
177    #[track_caller]
178    pub fn reciprocal(&self) -> Result<Tensor> {
179        UOp::try_reciprocal(&self.uop()).map(Self::new).context(UOpSnafu)
180    }
181
182    /// Square: x².
183    ///
184    /// Computes the square of each element.
185    ///
186    /// # Examples
187    /// ```ignore
188    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, -4.0]);
189    /// let result = t.square()?;  // [1.0, 4.0, 9.0, 16.0]
190    /// ```
191    #[track_caller]
192    pub fn square(&self) -> Result<Tensor> {
193        Ok(Self::new(self.uop().square()))
194    }
195
196    /// Sign function: -1 for negative, 0 for zero, 1 for positive.
197    ///
198    /// Returns the sign of each element.
199    ///
200    /// # Examples
201    /// ```ignore
202    /// let t = Tensor::from_slice(&[-5.0f32, 0.0, 3.0, -0.0]);
203    /// let result = t.sign()?;  // [-1.0, 0.0, 1.0, 0.0]
204    /// ```
205    #[track_caller]
206    pub fn sign(&self) -> Result<Tensor> {
207        Ok(Self::new(self.uop().sign()))
208    }
209
210    /// Linear interpolation: `self + (end - self) * weight`.
211    #[track_caller]
212    pub fn lerp(&self, end: &Tensor, weight: &Tensor) -> Result<Tensor> {
213        let diff = end.try_sub(self)?;
214        self.try_add(&diff.try_mul(weight)?)
215    }
216
217    // =========================================================================
218    // NaN / Infinity Detection
219    // =========================================================================
220
221    /// Returns `true` where elements are NaN: `self != self`.
222    #[track_caller]
223    pub fn isnan(&self) -> Result<Tensor> {
224        self.try_ne(self)
225    }
226
227    /// Returns `true` where elements are infinite.
228    ///
229    /// Detects ±∞ via bitcast to the corresponding unsigned integer type and a
230    /// bit-pattern compare. Operating in integer space sidesteps Svod's float
231    /// range analysis, which folds `x == ±inf` to false because `dtype_bounds`
232    /// returns finite ±max for floats. Tinygrad gets away with the float compare
233    /// because its `dtype.min/max` are ±inf.
234    #[track_caller]
235    pub fn isinf(&self, detect_positive: bool, detect_negative: bool) -> Result<Tensor> {
236        use svod_dtype::{DType, ScalarDType};
237        let dtype = self.uop().dtype();
238        // (uint_bitcast_dtype, +inf bit pattern, -inf bit pattern, abs-mask)
239        let (uint_dt, pos_bits, neg_bits, abs_mask): (DType, i64, i64, i64) = match dtype {
240            DType::Scalar(ScalarDType::Float16) => (DType::UInt16, 0x7C00, 0xFC00, 0x7FFF),
241            DType::Scalar(ScalarDType::BFloat16) => (DType::UInt16, 0x7F80, 0xFF80, 0x7FFF),
242            DType::Scalar(ScalarDType::Float32) => (DType::UInt32, 0x7F800000, 0xFF800000_u32 as i64, 0x7FFFFFFF),
243            DType::Scalar(ScalarDType::Float64) => {
244                (DType::UInt64, 0x7FF0000000000000, 0xFFF0000000000000_u64 as i64, 0x7FFFFFFFFFFFFFFF)
245            }
246            // Non-float dtypes never have inf.
247            _ => return self.zero()?.cast(DType::Bool),
248        };
249
250        let bits = self.bitcast(uint_dt)?;
251        let pos_pat = bits.broadcast_scalar(ConstValue::Int(pos_bits))?;
252        match (detect_positive, detect_negative) {
253            (true, true) => {
254                // (bits & abs_mask) == +inf bits → matches both +inf and -inf
255                let mask = bits.broadcast_scalar(ConstValue::Int(abs_mask))?;
256                bits.bitwise_and(&mask)?.try_eq(&pos_pat)
257            }
258            (true, false) => bits.try_eq(&pos_pat),
259            (false, true) => {
260                let neg_pat = bits.broadcast_scalar(ConstValue::Int(neg_bits))?;
261                bits.try_eq(&neg_pat)
262            }
263            (false, false) => self.zero()?.cast(DType::Bool),
264        }
265    }
266
267    // =========================================================================
268    // Hyperbolic Functions
269    // =========================================================================
270
271    /// Hyperbolic sine: `(exp(x) - exp(-x)) / 2`.
272    #[track_caller]
273    pub fn sinh(&self) -> Result<Tensor> {
274        let exp_pos = self.try_exp()?;
275        let exp_neg = self.try_neg()?.try_exp()?;
276        let two = self.broadcast_scalar(ConstValue::Int(2))?;
277        exp_pos.try_sub(&exp_neg)?.try_div(&two)
278    }
279
280    /// Hyperbolic cosine: `(exp(x) + exp(-x)) / 2`.
281    #[track_caller]
282    pub fn cosh(&self) -> Result<Tensor> {
283        let exp_pos = self.try_exp()?;
284        let exp_neg = self.try_neg()?.try_exp()?;
285        let two = self.broadcast_scalar(ConstValue::Int(2))?;
286        exp_pos.try_add(&exp_neg)?.try_div(&two)
287    }
288
289    // =========================================================================
290    // Inverse Hyperbolic Functions
291    // =========================================================================
292
293    /// Inverse hyperbolic sine: `log(x + sqrt(x² + 1))`.
294    #[track_caller]
295    pub fn asinh(&self) -> Result<Tensor> {
296        let one = self.one()?;
297        let inner = self.square()?.try_add(&one)?.try_sqrt()?;
298        self.try_add(&inner)?.try_log()
299    }
300
301    /// Inverse hyperbolic cosine: `log(x + sqrt(x² - 1))`.
302    #[track_caller]
303    pub fn acosh(&self) -> Result<Tensor> {
304        let one = self.one()?;
305        let inner = self.square()?.try_sub(&one)?.try_sqrt()?;
306        self.try_add(&inner)?.try_log()
307    }
308
309    /// Inverse hyperbolic tangent: `0.5 * log((1+x)/(1-x))`.
310    #[track_caller]
311    pub fn atanh(&self) -> Result<Tensor> {
312        let one = self.one()?;
313        let half = self.broadcast_scalar(ConstValue::Float(0.5))?;
314        let num = one.try_add(self)?;
315        let den = one.try_sub(self)?;
316        half.try_mul(&num.try_div(&den)?.try_log()?)
317    }
318
319    // =========================================================================
320    // Inverse Trigonometric Functions
321    // =========================================================================
322
323    /// Arcsine using polynomial approximation (Abramowitz & Stegun 4.4.46).
324    #[track_caller]
325    pub fn asin(&self) -> Result<Tensor> {
326        let coefficients = [
327            -0.0012624911,
328            0.0066700901,
329            -0.0170881256,
330            0.0308918810,
331            -0.0501743046,
332            0.0889789874,
333            -0.2145988016,
334            1.5707963050,
335        ];
336        let abs_x = self.try_abs()?;
337        let one = self.one()?;
338        let half_pi = self.broadcast_scalar(ConstValue::Float(std::f64::consts::FRAC_PI_2))?;
339        let sqrt_part = one.try_sub(&abs_x)?.try_sqrt()?;
340        let poly = poly_n(&abs_x, &coefficients)?;
341        let x = half_pi.try_sub(&sqrt_part.try_mul(&poly)?)?;
342        self.sign()?.try_mul(&x)
343    }
344
345    /// Arccosine: `π/2 - asin(x)`.
346    #[track_caller]
347    pub fn acos(&self) -> Result<Tensor> {
348        let half_pi = self.broadcast_scalar(ConstValue::Float(std::f64::consts::FRAC_PI_2))?;
349        half_pi.try_sub(&self.asin()?)
350    }
351
352    /// Arctangent: `asin(x / sqrt(1 + x²))`.
353    #[track_caller]
354    pub fn atan(&self) -> Result<Tensor> {
355        let one = self.one()?;
356        let denom = one.try_add(&self.square()?)?.try_sqrt()?;
357        self.try_div(&denom)?.asin()
358    }
359
360    // =========================================================================
361    // Shrinkage / Thresholding
362    // =========================================================================
363
364    /// Shrinkage operator: applies soft/hard thresholding.
365    ///
366    /// `(x < -λ)*(x+bias) + (x > λ)*(x-bias)`
367    #[track_caller]
368    pub fn shrink(&self, bias: f64, lambd: f64) -> Result<Tensor> {
369        let dtype = self.uop().dtype();
370        let neg_lambd = Tensor::const_(-lambd, dtype.clone());
371        let pos_lambd = Tensor::const_(lambd, dtype.clone());
372        let bias_t = Tensor::const_(bias, dtype.clone());
373        let neg_bias = Tensor::const_(-bias, dtype.clone());
374        let neg_part = self.try_lt(&neg_lambd)?.cast(dtype.clone())?.try_mul(&self.try_add(&bias_t)?)?;
375        let pos_part = self.try_gt(&pos_lambd)?.cast(dtype)?.try_mul(&self.try_add(&neg_bias)?)?;
376        neg_part.try_add(&pos_part)
377    }
378
379    // =========================================================================
380    // Linear Algebra
381    // =========================================================================
382
383    /// Matrix determinant via LU decomposition with partial pivoting.
384    ///
385    /// Input shape: `[..., n, n]`. Output shape: `[...]`.
386    /// Batch dimensions are preserved. Uses O(n³) computation with O(n)
387    /// graph construction steps (unrolled at compile time).
388    #[track_caller]
389    pub fn det(&self) -> Result<Tensor> {
390        let shape = self.shape()?;
391        let ndim = shape.len();
392        snafu::ensure!(
393            ndim >= 2,
394            crate::error::ShapeMismatchSnafu {
395                context: "det",
396                expected: "at least 2D".to_string(),
397                actual: format!("{ndim}D"),
398            }
399        );
400        let n = shape[ndim - 1].as_const().unwrap();
401        let m = shape[ndim - 2].as_const().unwrap();
402        snafu::ensure!(
403            n == m,
404            crate::error::ShapeMismatchSnafu {
405                context: "det",
406                expected: format!("square last two dims, got [{m}, {n}]"),
407                actual: format!("[{m}, {n}]"),
408            }
409        );
410
411        let dtype = self.uop().dtype();
412        let float_dt = if dtype.is_float() { dtype.clone() } else { DType::Float32 };
413
414        if n == 0 {
415            let batch: Vec<usize> = shape[..ndim - 2].iter().map(|s| s.as_const().unwrap()).collect();
416            return if batch.is_empty() {
417                Ok(Tensor::const_(1.0, float_dt))
418            } else {
419                Tensor::full(&batch, 1.0, float_dt)
420            };
421        }
422
423        // Cast to float for correct division in Gaussian elimination
424        let mut a = if dtype.is_float() { self.clone() } else { self.cast(float_dt.clone())? };
425        let mut det_val: Option<Tensor> = None;
426        let neg_one = Tensor::const_(-1.0, float_dt.clone());
427        let one = Tensor::const_(1.0, float_dt.clone());
428        let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
429
430        for step in 0..n {
431            let cur_n = n - step;
432            let cni = cur_n as isize;
433
434            if cur_n > 1 {
435                // Partial pivoting: find row with max |a[..., :, 0]|
436                let col0 = shrink_last2(&a, ndim, (0, cni), (0, 1))?;
437                let max_idx = col0.try_abs()?.argmax_with().axis(Some(-2)).keepdim(true).call()?;
438
439                // Extract max_row via gather
440                let mut gather_shape: Vec<isize> = vec![-1; ndim - 2];
441                gather_shape.push(1);
442                gather_shape.push(cur_n as isize);
443                let max_idx_gather = max_idx.try_expand(&gather_shape)?;
444                let max_row = a.gather(-2, &max_idx_gather)?;
445
446                // Extract row 0
447                let row_0 = shrink_last2(&a, ndim, (0, 1), (0, cni))?;
448
449                // Build row-index mask: shape [1, ..., 1, cur_n, 1] for broadcasting
450                let mut row_idx = Tensor::arange(0, Some(cur_n as i64), None)?.try_unsqueeze(-1)?;
451                for _ in 0..ndim - 2 {
452                    row_idx = row_idx.try_unsqueeze(0)?;
453                }
454                let mask_0 = row_idx.try_eq(&zero_i)?;
455                let mask_max = row_idx.try_eq(&max_idx)?;
456
457                // Swap: row_0 → max_idx position, max_row → row 0 position
458                let temp = row_0.where_(&mask_max, &a)?;
459                a = max_row.where_(&mask_0, &temp)?;
460
461                // Track sign: flip when a swap actually happened
462                let max_idx_scalar = max_idx.try_squeeze(Some(-1))?.try_squeeze(Some(-1))?;
463                let swapped = max_idx_scalar.try_ne(&zero_i)?;
464                let swap_sign = neg_one.where_(&swapped, &one)?;
465                det_val = Some(match det_val {
466                    None => swap_sign,
467                    Some(d) => d.try_mul(&swap_sign)?,
468                });
469            }
470
471            // Extract pivot a[..., 0, 0] and accumulate
472            let pivot = shrink_last2(&a, ndim, (0, 1), (0, 1))?;
473            let pivot_scalar = pivot.try_squeeze(Some(-1))?.try_squeeze(Some(-1))?;
474            det_val = Some(match det_val {
475                None => pivot_scalar,
476                Some(d) => d.try_mul(&pivot_scalar)?,
477            });
478
479            if cur_n <= 1 {
480                break;
481            }
482
483            // Gaussian elimination on the submatrix.
484            // Use safe pivot: replace 0 with 1 to avoid div-by-zero NaN.
485            // When pivot is 0 the matrix is singular (det=0), already captured
486            // in det_val; the elimination result doesn't matter.
487            let pivot_is_zero = pivot.try_eq(&Tensor::const_(0.0, float_dt.clone()))?;
488            let pivot_safe = one.where_(&pivot_is_zero, &pivot)?;
489            let col_below = shrink_last2(&a, ndim, (1, cni), (0, 1))?;
490            let factors = col_below.try_div(&pivot_safe)?;
491            let row_0_rest = shrink_last2(&a, ndim, (0, 1), (1, cni))?;
492            let sub = shrink_last2(&a, ndim, (1, cni), (1, cni))?;
493            a = sub.try_sub(&factors.try_mul(&row_0_rest)?)?;
494        }
495
496        Ok(det_val.unwrap())
497    }
498}
499
500/// Shrink only the last two dimensions of a tensor, preserving batch dims.
501fn shrink_last2(tensor: &Tensor, ndim: usize, row_range: (isize, isize), col_range: (isize, isize)) -> Result<Tensor> {
502    let shape = tensor.shape()?;
503    let mut ranges: Vec<(isize, isize)> =
504        shape[..ndim - 2].iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
505    ranges.push(row_range);
506    ranges.push(col_range);
507    tensor.try_shrink(&ranges)
508}