svod_tensor/math.rs
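//! Mathematical operations for tensors.
//!
//! This module provides:
//! - Trigonometric functions: sin, cos, tan
//! - Inverse trigonometric functions: asin, acos, atan
//! - Hyperbolic functions: sinh, cosh, and their inverses asinh, acosh, atanh
//! - Rounding functions: floor, ceil, round, trunc
//! - Advanced math: erf (error function), reciprocal, square, sign, lerp
//! - NaN / infinity detection: isnan, isinf
//! - Thresholding: shrink
//! - Linear algebra: det (determinant)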
7
8use snafu::ResultExt;
9use svod_ir::ConstValue;
10
11use super::*;
12
13/// Horner's method for polynomial evaluation: `coeffs[0]*x^(n-1) + ... + coeffs[n-1]`.
14fn poly_n(x: &Tensor, coefficients: &[f64]) -> Result<Tensor> {
15 let mut acc = x.broadcast_scalar(ConstValue::Float(coefficients[0]))?;
16 for &c in &coefficients[1..] {
17 let c_t = x.broadcast_scalar(ConstValue::Float(c))?;
18 acc = acc.try_mul(x)?.try_add(&c_t)?;
19 }
20 Ok(acc)
21}
22
23impl Tensor {
24 // =========================================================================
25 // Trigonometric Functions
26 // =========================================================================
27
28 /// Sine function: sin(x).
29 ///
30 /// Computes the sine of each element. Requires float dtype.
31 ///
32 /// # Examples
33 /// ```ignore
34 /// use std::f32::consts::PI;
35 /// let t = Tensor::from_slice(&[0.0f32, PI/2.0, PI]);
36 /// let result = t.sin()?; // [0, 1, 0]
37 /// ```
38 ///
39 /// # Errors
40 /// Returns error if dtype is not float.
41 #[track_caller]
42 pub fn sin(&self) -> Result<Tensor> {
43 self.uop().try_sin().map(Self::new).context(UOpSnafu)
44 }
45
46 /// Cosine function: cos(x).
47 ///
48 /// Computes the cosine of each element. Requires float dtype.
49 ///
50 /// # Examples
51 /// ```ignore
52 /// use std::f32::consts::PI;
53 /// let t = Tensor::from_slice(&[0.0f32, PI/2.0, PI]);
54 /// let result = t.cos()?; // [1, 0, -1]
55 /// ```
56 ///
57 /// # Errors
58 /// Returns error if dtype is not float.
59 #[track_caller]
60 pub fn cos(&self) -> Result<Tensor> {
61 self.uop().try_cos().map(Self::new).context(UOpSnafu)
62 }
63
64 /// Tangent function: tan(x).
65 ///
66 /// Computes the tangent of each element. Requires float dtype.
67 ///
68 /// # Examples
69 /// ```ignore
70 /// use std::f32::consts::PI;
71 /// let t = Tensor::from_slice(&[0.0f32, PI/4.0]);
72 /// let result = t.tan()?; // [0, 1]
73 /// ```
74 ///
75 /// # Errors
76 /// Returns error if dtype is not float.
77 #[track_caller]
78 pub fn tan(&self) -> Result<Tensor> {
79 self.uop().try_tan().map(Self::new).context(UOpSnafu)
80 }
81
82 // =========================================================================
83 // Rounding Functions
84 // =========================================================================
85
86 /// Floor function: round towards -∞.
87 ///
88 /// Returns the largest integer less than or equal to each element.
89 /// For integer dtypes, returns the tensor unchanged.
90 ///
91 /// # Examples
92 /// ```ignore
93 /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
94 /// let result = t.floor()?; // [1.0, -2.0, 2.0, -3.0]
95 /// ```
96 #[track_caller]
97 pub fn floor(&self) -> Result<Tensor> {
98 Ok(Self::new(UOp::floor(self.uop())))
99 }
100
101 /// Ceiling function: round towards +∞.
102 ///
103 /// Returns the smallest integer greater than or equal to each element.
104 /// For integer dtypes, returns the tensor unchanged.
105 ///
106 /// # Examples
107 /// ```ignore
108 /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
109 /// let result = t.ceil()?; // [2.0, -1.0, 3.0, -2.0]
110 /// ```
111 #[track_caller]
112 pub fn ceil(&self) -> Result<Tensor> {
113 Ok(Self::new(UOp::ceil(self.uop())))
114 }
115
116 /// Round function: round to nearest integer (half to even).
117 ///
118 /// Rounds each element to the nearest integer. Ties are rounded to the nearest even number.
119 /// For integer dtypes, returns the tensor unchanged.
120 ///
121 /// # Examples
122 /// ```ignore
123 /// let t = Tensor::from_slice(&[1.2f32, 1.5, 2.5, -1.5]);
124 /// let result = t.round()?; // [1.0, 2.0, 2.0, -2.0]
125 /// ```
126 #[track_caller]
127 pub fn round(&self) -> Result<Tensor> {
128 Ok(Self::new(UOp::round(self.uop())))
129 }
130
131 /// Truncate function: round towards zero.
132 ///
133 /// Removes the fractional part, rounding towards zero.
134 /// For integer dtypes, returns the tensor unchanged.
135 ///
136 /// # Examples
137 /// ```ignore
138 /// let t = Tensor::from_slice(&[1.2f32, -1.2, 2.8, -2.8]);
139 /// let result = t.trunc()?; // [1.0, -1.0, 2.0, -2.0]
140 /// ```
141 #[track_caller]
142 pub fn trunc(&self) -> Result<Tensor> {
143 Ok(Self::new(UOp::trunc(self.uop())))
144 }
145
146 // =========================================================================
147 // Advanced Math Functions
148 // =========================================================================
149
150 /// Error function: erf(x).
151 ///
152 /// Computes the error function (Gauss error function) of each element.
153 /// Requires float dtype. Critical for GELU activation.
154 ///
155 /// # Examples
156 /// ```ignore
157 /// let t = Tensor::from_slice(&[-1.0f32, 0.0, 1.0]);
158 /// let result = t.erf()?; // [-0.8427, 0, 0.8427]
159 /// ```
160 ///
161 /// # Errors
162 /// Returns error if dtype is not float.
163 #[track_caller]
164 pub fn erf(&self) -> Result<Tensor> {
165 self.uop().erf().map(Self::new).context(UOpSnafu)
166 }
167
168 /// Reciprocal: 1/x.
169 ///
170 /// Computes the reciprocal of each element.
171 ///
172 /// # Examples
173 /// ```ignore
174 /// let t = Tensor::from_slice(&[1.0f32, 2.0, 4.0]);
175 /// let result = t.reciprocal()?; // [1.0, 0.5, 0.25]
176 /// ```
177 #[track_caller]
178 pub fn reciprocal(&self) -> Result<Tensor> {
179 UOp::try_reciprocal(&self.uop()).map(Self::new).context(UOpSnafu)
180 }
181
182 /// Square: x².
183 ///
184 /// Computes the square of each element.
185 ///
186 /// # Examples
187 /// ```ignore
188 /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, -4.0]);
189 /// let result = t.square()?; // [1.0, 4.0, 9.0, 16.0]
190 /// ```
191 #[track_caller]
192 pub fn square(&self) -> Result<Tensor> {
193 Ok(Self::new(self.uop().square()))
194 }
195
196 /// Sign function: -1 for negative, 0 for zero, 1 for positive.
197 ///
198 /// Returns the sign of each element.
199 ///
200 /// # Examples
201 /// ```ignore
202 /// let t = Tensor::from_slice(&[-5.0f32, 0.0, 3.0, -0.0]);
203 /// let result = t.sign()?; // [-1.0, 0.0, 1.0, 0.0]
204 /// ```
205 #[track_caller]
206 pub fn sign(&self) -> Result<Tensor> {
207 Ok(Self::new(self.uop().sign()))
208 }
209
210 /// Linear interpolation: `self + (end - self) * weight`.
211 #[track_caller]
212 pub fn lerp(&self, end: &Tensor, weight: &Tensor) -> Result<Tensor> {
213 let diff = end.try_sub(self)?;
214 self.try_add(&diff.try_mul(weight)?)
215 }
216
217 // =========================================================================
218 // NaN / Infinity Detection
219 // =========================================================================
220
221 /// Returns `true` where elements are NaN: `self != self`.
222 #[track_caller]
223 pub fn isnan(&self) -> Result<Tensor> {
224 self.try_ne(self)
225 }
226
227 /// Returns `true` where elements are infinite.
228 ///
229 /// Detects ±∞ via bitcast to the corresponding unsigned integer type and a
230 /// bit-pattern compare. Operating in integer space sidesteps Svod's float
231 /// range analysis, which folds `x == ±inf` to false because `dtype_bounds`
232 /// returns finite ±max for floats. Tinygrad gets away with the float compare
233 /// because its `dtype.min/max` are ±inf.
234 #[track_caller]
235 pub fn isinf(&self, detect_positive: bool, detect_negative: bool) -> Result<Tensor> {
236 use svod_dtype::{DType, ScalarDType};
237 let dtype = self.uop().dtype();
238 // (uint_bitcast_dtype, +inf bit pattern, -inf bit pattern, abs-mask)
239 let (uint_dt, pos_bits, neg_bits, abs_mask): (DType, i64, i64, i64) = match dtype {
240 DType::Scalar(ScalarDType::Float16) => (DType::UInt16, 0x7C00, 0xFC00, 0x7FFF),
241 DType::Scalar(ScalarDType::BFloat16) => (DType::UInt16, 0x7F80, 0xFF80, 0x7FFF),
242 DType::Scalar(ScalarDType::Float32) => (DType::UInt32, 0x7F800000, 0xFF800000_u32 as i64, 0x7FFFFFFF),
243 DType::Scalar(ScalarDType::Float64) => {
244 (DType::UInt64, 0x7FF0000000000000, 0xFFF0000000000000_u64 as i64, 0x7FFFFFFFFFFFFFFF)
245 }
246 // Non-float dtypes never have inf.
247 _ => return self.zero()?.cast(DType::Bool),
248 };
249
250 let bits = self.bitcast(uint_dt)?;
251 let pos_pat = bits.broadcast_scalar(ConstValue::Int(pos_bits))?;
252 match (detect_positive, detect_negative) {
253 (true, true) => {
254 // (bits & abs_mask) == +inf bits → matches both +inf and -inf
255 let mask = bits.broadcast_scalar(ConstValue::Int(abs_mask))?;
256 bits.bitwise_and(&mask)?.try_eq(&pos_pat)
257 }
258 (true, false) => bits.try_eq(&pos_pat),
259 (false, true) => {
260 let neg_pat = bits.broadcast_scalar(ConstValue::Int(neg_bits))?;
261 bits.try_eq(&neg_pat)
262 }
263 (false, false) => self.zero()?.cast(DType::Bool),
264 }
265 }
266
267 // =========================================================================
268 // Hyperbolic Functions
269 // =========================================================================
270
271 /// Hyperbolic sine: `(exp(x) - exp(-x)) / 2`.
272 #[track_caller]
273 pub fn sinh(&self) -> Result<Tensor> {
274 let exp_pos = self.try_exp()?;
275 let exp_neg = self.try_neg()?.try_exp()?;
276 let two = self.broadcast_scalar(ConstValue::Int(2))?;
277 exp_pos.try_sub(&exp_neg)?.try_div(&two)
278 }
279
280 /// Hyperbolic cosine: `(exp(x) + exp(-x)) / 2`.
281 #[track_caller]
282 pub fn cosh(&self) -> Result<Tensor> {
283 let exp_pos = self.try_exp()?;
284 let exp_neg = self.try_neg()?.try_exp()?;
285 let two = self.broadcast_scalar(ConstValue::Int(2))?;
286 exp_pos.try_add(&exp_neg)?.try_div(&two)
287 }
288
289 // =========================================================================
290 // Inverse Hyperbolic Functions
291 // =========================================================================
292
293 /// Inverse hyperbolic sine: `log(x + sqrt(x² + 1))`.
294 #[track_caller]
295 pub fn asinh(&self) -> Result<Tensor> {
296 let one = self.one()?;
297 let inner = self.square()?.try_add(&one)?.try_sqrt()?;
298 self.try_add(&inner)?.try_log()
299 }
300
301 /// Inverse hyperbolic cosine: `log(x + sqrt(x² - 1))`.
302 #[track_caller]
303 pub fn acosh(&self) -> Result<Tensor> {
304 let one = self.one()?;
305 let inner = self.square()?.try_sub(&one)?.try_sqrt()?;
306 self.try_add(&inner)?.try_log()
307 }
308
309 /// Inverse hyperbolic tangent: `0.5 * log((1+x)/(1-x))`.
310 #[track_caller]
311 pub fn atanh(&self) -> Result<Tensor> {
312 let one = self.one()?;
313 let half = self.broadcast_scalar(ConstValue::Float(0.5))?;
314 let num = one.try_add(self)?;
315 let den = one.try_sub(self)?;
316 half.try_mul(&num.try_div(&den)?.try_log()?)
317 }
318
319 // =========================================================================
320 // Inverse Trigonometric Functions
321 // =========================================================================
322
323 /// Arcsine using polynomial approximation (Abramowitz & Stegun 4.4.46).
324 #[track_caller]
325 pub fn asin(&self) -> Result<Tensor> {
326 let coefficients = [
327 -0.0012624911,
328 0.0066700901,
329 -0.0170881256,
330 0.0308918810,
331 -0.0501743046,
332 0.0889789874,
333 -0.2145988016,
334 1.5707963050,
335 ];
336 let abs_x = self.try_abs()?;
337 let one = self.one()?;
338 let half_pi = self.broadcast_scalar(ConstValue::Float(std::f64::consts::FRAC_PI_2))?;
339 let sqrt_part = one.try_sub(&abs_x)?.try_sqrt()?;
340 let poly = poly_n(&abs_x, &coefficients)?;
341 let x = half_pi.try_sub(&sqrt_part.try_mul(&poly)?)?;
342 self.sign()?.try_mul(&x)
343 }
344
345 /// Arccosine: `π/2 - asin(x)`.
346 #[track_caller]
347 pub fn acos(&self) -> Result<Tensor> {
348 let half_pi = self.broadcast_scalar(ConstValue::Float(std::f64::consts::FRAC_PI_2))?;
349 half_pi.try_sub(&self.asin()?)
350 }
351
352 /// Arctangent: `asin(x / sqrt(1 + x²))`.
353 #[track_caller]
354 pub fn atan(&self) -> Result<Tensor> {
355 let one = self.one()?;
356 let denom = one.try_add(&self.square()?)?.try_sqrt()?;
357 self.try_div(&denom)?.asin()
358 }
359
360 // =========================================================================
361 // Shrinkage / Thresholding
362 // =========================================================================
363
364 /// Shrinkage operator: applies soft/hard thresholding.
365 ///
366 /// `(x < -λ)*(x+bias) + (x > λ)*(x-bias)`
367 #[track_caller]
368 pub fn shrink(&self, bias: f64, lambd: f64) -> Result<Tensor> {
369 let dtype = self.uop().dtype();
370 let neg_lambd = Tensor::const_(-lambd, dtype.clone());
371 let pos_lambd = Tensor::const_(lambd, dtype.clone());
372 let bias_t = Tensor::const_(bias, dtype.clone());
373 let neg_bias = Tensor::const_(-bias, dtype.clone());
374 let neg_part = self.try_lt(&neg_lambd)?.cast(dtype.clone())?.try_mul(&self.try_add(&bias_t)?)?;
375 let pos_part = self.try_gt(&pos_lambd)?.cast(dtype)?.try_mul(&self.try_add(&neg_bias)?)?;
376 neg_part.try_add(&pos_part)
377 }
378
379 // =========================================================================
380 // Linear Algebra
381 // =========================================================================
382
383 /// Matrix determinant via LU decomposition with partial pivoting.
384 ///
385 /// Input shape: `[..., n, n]`. Output shape: `[...]`.
386 /// Batch dimensions are preserved. Uses O(n³) computation with O(n)
387 /// graph construction steps (unrolled at compile time).
388 #[track_caller]
389 pub fn det(&self) -> Result<Tensor> {
390 let shape = self.shape()?;
391 let ndim = shape.len();
392 snafu::ensure!(
393 ndim >= 2,
394 crate::error::ShapeMismatchSnafu {
395 context: "det",
396 expected: "at least 2D".to_string(),
397 actual: format!("{ndim}D"),
398 }
399 );
400 let n = shape[ndim - 1].as_const().unwrap();
401 let m = shape[ndim - 2].as_const().unwrap();
402 snafu::ensure!(
403 n == m,
404 crate::error::ShapeMismatchSnafu {
405 context: "det",
406 expected: format!("square last two dims, got [{m}, {n}]"),
407 actual: format!("[{m}, {n}]"),
408 }
409 );
410
411 let dtype = self.uop().dtype();
412 let float_dt = if dtype.is_float() { dtype.clone() } else { DType::Float32 };
413
414 if n == 0 {
415 let batch: Vec<usize> = shape[..ndim - 2].iter().map(|s| s.as_const().unwrap()).collect();
416 return if batch.is_empty() {
417 Ok(Tensor::const_(1.0, float_dt))
418 } else {
419 Tensor::full(&batch, 1.0, float_dt)
420 };
421 }
422
423 // Cast to float for correct division in Gaussian elimination
424 let mut a = if dtype.is_float() { self.clone() } else { self.cast(float_dt.clone())? };
425 let mut det_val: Option<Tensor> = None;
426 let neg_one = Tensor::const_(-1.0, float_dt.clone());
427 let one = Tensor::const_(1.0, float_dt.clone());
428 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
429
430 for step in 0..n {
431 let cur_n = n - step;
432 let cni = cur_n as isize;
433
434 if cur_n > 1 {
435 // Partial pivoting: find row with max |a[..., :, 0]|
436 let col0 = shrink_last2(&a, ndim, (0, cni), (0, 1))?;
437 let max_idx = col0.try_abs()?.argmax_with().axis(Some(-2)).keepdim(true).call()?;
438
439 // Extract max_row via gather
440 let mut gather_shape: Vec<isize> = vec![-1; ndim - 2];
441 gather_shape.push(1);
442 gather_shape.push(cur_n as isize);
443 let max_idx_gather = max_idx.try_expand(&gather_shape)?;
444 let max_row = a.gather(-2, &max_idx_gather)?;
445
446 // Extract row 0
447 let row_0 = shrink_last2(&a, ndim, (0, 1), (0, cni))?;
448
449 // Build row-index mask: shape [1, ..., 1, cur_n, 1] for broadcasting
450 let mut row_idx = Tensor::arange(0, Some(cur_n as i64), None)?.try_unsqueeze(-1)?;
451 for _ in 0..ndim - 2 {
452 row_idx = row_idx.try_unsqueeze(0)?;
453 }
454 let mask_0 = row_idx.try_eq(&zero_i)?;
455 let mask_max = row_idx.try_eq(&max_idx)?;
456
457 // Swap: row_0 → max_idx position, max_row → row 0 position
458 let temp = row_0.where_(&mask_max, &a)?;
459 a = max_row.where_(&mask_0, &temp)?;
460
461 // Track sign: flip when a swap actually happened
462 let max_idx_scalar = max_idx.try_squeeze(Some(-1))?.try_squeeze(Some(-1))?;
463 let swapped = max_idx_scalar.try_ne(&zero_i)?;
464 let swap_sign = neg_one.where_(&swapped, &one)?;
465 det_val = Some(match det_val {
466 None => swap_sign,
467 Some(d) => d.try_mul(&swap_sign)?,
468 });
469 }
470
471 // Extract pivot a[..., 0, 0] and accumulate
472 let pivot = shrink_last2(&a, ndim, (0, 1), (0, 1))?;
473 let pivot_scalar = pivot.try_squeeze(Some(-1))?.try_squeeze(Some(-1))?;
474 det_val = Some(match det_val {
475 None => pivot_scalar,
476 Some(d) => d.try_mul(&pivot_scalar)?,
477 });
478
479 if cur_n <= 1 {
480 break;
481 }
482
483 // Gaussian elimination on the submatrix.
484 // Use safe pivot: replace 0 with 1 to avoid div-by-zero NaN.
485 // When pivot is 0 the matrix is singular (det=0), already captured
486 // in det_val; the elimination result doesn't matter.
487 let pivot_is_zero = pivot.try_eq(&Tensor::const_(0.0, float_dt.clone()))?;
488 let pivot_safe = one.where_(&pivot_is_zero, &pivot)?;
489 let col_below = shrink_last2(&a, ndim, (1, cni), (0, 1))?;
490 let factors = col_below.try_div(&pivot_safe)?;
491 let row_0_rest = shrink_last2(&a, ndim, (0, 1), (1, cni))?;
492 let sub = shrink_last2(&a, ndim, (1, cni), (1, cni))?;
493 a = sub.try_sub(&factors.try_mul(&row_0_rest)?)?;
494 }
495
496 Ok(det_val.unwrap())
497 }
498}
499
500/// Shrink only the last two dimensions of a tensor, preserving batch dims.
501fn shrink_last2(tensor: &Tensor, ndim: usize, row_range: (isize, isize), col_range: (isize, isize)) -> Result<Tensor> {
502 let shape = tensor.shape()?;
503 let mut ranges: Vec<(isize, isize)> =
504 shape[..ndim - 2].iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
505 ranges.push(row_range);
506 ranges.push(col_range);
507 tensor.try_shrink(&ranges)
508}