// torsh_functional/numerical.rs
1//! Numerical integration and differentiation operations
2//!
3//! This module provides numerical methods for integration and differentiation,
4//! including various quadrature rules and finite difference methods.
5
6use torsh_core::{Result as TorshResult, TorshError};
7use torsh_tensor::Tensor;
8
9// ============================================================================
10// Numerical Integration
11// ============================================================================
12
/// Integration methods
///
/// Identifies the quadrature scheme to use.
/// NOTE(review): none of the functions in this module currently accept this
/// enum; `trapz`, `simps`, `gaussian_quad` and `adaptive_quad` are exposed as
/// separate functions instead. Kept for API completeness / future dispatch.
#[derive(Debug, Clone, Copy)]
pub enum IntegrationMethod {
    /// Trapezoidal rule
    Trapezoidal,
    /// Simpson's rule (requires odd number of points)
    Simpson,
    /// Simpson's 3/8 rule
    Simpson38,
    /// Gaussian quadrature
    Gaussian,
    /// Romberg integration
    Romberg,
    /// Adaptive quadrature
    Adaptive,
}
29
30/// Numerical integration using trapezoidal rule
31///
32/// Integrates a 1D tensor representing function values at equally spaced points.
33///
34/// # Arguments
35/// * `y` - Tensor of function values
36/// * `dx` - Step size between points
37pub fn trapz(y: &Tensor, dx: Option<f32>) -> TorshResult<Tensor> {
38    let data = y.data()?;
39    let dx = dx.unwrap_or(1.0);
40
41    if data.len() < 2 {
42        return Err(TorshError::InvalidArgument(
43            "Need at least 2 points for trapezoidal integration".to_string(),
44        ));
45    }
46
47    let mut sum = 0.5 * (data[0] + data[data.len() - 1]);
48    for i in 1..data.len() - 1 {
49        sum += data[i];
50    }
51
52    let result = sum * dx;
53    Tensor::from_data(vec![result], vec![], y.device())
54}
55
56/// Cumulative integration using trapezoidal rule
57///
58/// Returns cumulative integral at each point.
59///
60/// # Arguments
61/// * `y` - Tensor of function values
62/// * `dx` - Step size between points
63pub fn cumtrapz(y: &Tensor, dx: Option<f32>) -> TorshResult<Tensor> {
64    let data = y.data()?;
65    let dx = dx.unwrap_or(1.0);
66
67    if data.len() < 2 {
68        return Err(TorshError::InvalidArgument(
69            "Need at least 2 points for cumulative integration".to_string(),
70        ));
71    }
72
73    let mut result = Vec::with_capacity(data.len());
74    result.push(0.0); // First point has zero integral
75
76    for i in 1..data.len() {
77        let integral = result[i - 1] + 0.5 * (data[i - 1] + data[i]) * dx;
78        result.push(integral);
79    }
80
81    Tensor::from_data(result, y.shape().dims().to_vec(), y.device())
82}
83
84/// Simpson's rule integration
85///
86/// Requires odd number of points (even number of intervals).
87///
88/// # Arguments
89/// * `y` - Tensor of function values
90/// * `dx` - Step size between points
91pub fn simps(y: &Tensor, dx: Option<f32>) -> TorshResult<Tensor> {
92    let data = y.data()?;
93    let dx = dx.unwrap_or(1.0);
94    let n = data.len();
95
96    if n < 3 {
97        return Err(TorshError::InvalidArgument(
98            "Need at least 3 points for Simpson's rule".to_string(),
99        ));
100    }
101
102    if n % 2 == 0 {
103        return Err(TorshError::InvalidArgument(
104            "Simpson's rule requires odd number of points".to_string(),
105        ));
106    }
107
108    let mut sum = data[0] + data[n - 1];
109
110    // Add even indices with coefficient 4
111    for i in (1..n - 1).step_by(2) {
112        sum += 4.0 * data[i];
113    }
114
115    // Add odd indices with coefficient 2
116    for i in (2..n - 1).step_by(2) {
117        sum += 2.0 * data[i];
118    }
119
120    let result = sum * dx / 3.0;
121    Tensor::from_data(vec![result], vec![], y.device())
122}
123
124/// Gaussian quadrature integration
125///
126/// Uses Gauss-Legendre quadrature for integration over [-1, 1].
127/// For other intervals, use change of variables.
128///
129/// # Arguments
130/// * `func` - Function to integrate (closure)
131/// * `n_points` - Number of quadrature points (2-10 supported)
132pub fn gaussian_quad<F>(func: F, n_points: usize) -> TorshResult<f32>
133where
134    F: Fn(f32) -> f32,
135{
136    let (nodes, weights) = match n_points {
137        2 => (vec![-0.5773502692, 0.5773502692], vec![1.0, 1.0]),
138        3 => (
139            vec![-0.7745966692, 0.0, 0.7745966692],
140            vec![0.5555555556, 0.8888888889, 0.5555555556],
141        ),
142        4 => (
143            vec![-0.8611363116, -0.3399810436, 0.3399810436, 0.8611363116],
144            vec![0.3478548451, 0.6521451549, 0.6521451549, 0.3478548451],
145        ),
146        5 => (
147            vec![
148                -0.9061798459,
149                -0.5384693101,
150                0.0,
151                0.5384693101,
152                0.9061798459,
153            ],
154            vec![
155                0.2369268851,
156                0.4786286705,
157                0.5688888889,
158                0.4786286705,
159                0.2369268851,
160            ],
161        ),
162        _ => {
163            return Err(TorshError::InvalidArgument(
164                "Gaussian quadrature supports 2-5 points".to_string(),
165            ))
166        }
167    };
168
169    let mut integral = 0.0;
170    for (i, &x) in nodes.iter().enumerate() {
171        integral += weights[i] * func(x);
172    }
173
174    Ok(integral)
175}
176
177/// Adaptive quadrature integration
178///
179/// Uses recursive subdivision to achieve desired accuracy.
180///
181/// # Arguments
182/// * `func` - Function to integrate
183/// * `a` - Lower bound
184/// * `b` - Upper bound
185/// * `tol` - Tolerance for convergence
186/// * `max_depth` - Maximum recursion depth
187pub fn adaptive_quad<F>(
188    func: F,
189    a: f32,
190    b: f32,
191    tol: Option<f32>,
192    max_depth: Option<usize>,
193) -> TorshResult<f32>
194where
195    F: Fn(f32) -> f32 + Clone,
196{
197    let tol = tol.unwrap_or(1e-6);
198    let max_depth = max_depth.unwrap_or(10);
199
200    fn adaptive_simpson<F>(
201        func: &F,
202        a: f32,
203        b: f32,
204        tol: f32,
205        depth: usize,
206        max_depth: usize,
207    ) -> f32
208    where
209        F: Fn(f32) -> f32,
210    {
211        let h = b - a;
212        let c = (a + b) / 2.0;
213
214        let fa = func(a);
215        let fb = func(b);
216        let fc = func(c);
217
218        let s1 = h * (fa + 4.0 * fc + fb) / 6.0; // Simpson's rule
219
220        let fd = func((a + c) / 2.0);
221        let fe = func((c + b) / 2.0);
222
223        let s2 = h * (fa + 4.0 * fd + 2.0 * fc + 4.0 * fe + fb) / 12.0;
224
225        if depth >= max_depth || (s2 - s1).abs() < 15.0 * tol {
226            s2 + (s2 - s1) / 15.0
227        } else {
228            adaptive_simpson(func, a, c, tol / 2.0, depth + 1, max_depth)
229                + adaptive_simpson(func, c, b, tol / 2.0, depth + 1, max_depth)
230        }
231    }
232
233    Ok(adaptive_simpson(&func, a, b, tol, 0, max_depth))
234}
235
236// ============================================================================
237// Numerical Differentiation
238// ============================================================================
239
/// Differentiation methods
///
/// Selects the finite-difference scheme used by [`gradient`].
#[derive(Debug, Clone, Copy)]
pub enum DifferentiationMethod {
    /// Forward finite difference
    Forward,
    /// Backward finite difference
    Backward,
    /// Central finite difference
    Central,
    /// Higher-order finite difference
    HigherOrder,
}
252
/// Gradient computation using finite differences
///
/// Computes gradient of a 1D tensor using specified method. The output has
/// the same length and shape as the input; endpoints fall back to one-sided
/// differences where the chosen scheme cannot be applied.
///
/// # Arguments
/// * `y` - Tensor of function values
/// * `dx` - Step size (defaults to 1.0)
/// * `method` - Differentiation method (defaults to `Central`)
///
/// # Errors
/// Returns `InvalidArgument` when fewer than 2 points are supplied, or fewer
/// than 5 for `HigherOrder`.
pub fn gradient(
    y: &Tensor,
    dx: Option<f32>,
    method: Option<DifferentiationMethod>,
) -> TorshResult<Tensor> {
    let data = y.data()?;
    let dx = dx.unwrap_or(1.0);
    let method = method.unwrap_or(DifferentiationMethod::Central);

    if data.len() < 2 {
        return Err(TorshError::InvalidArgument(
            "Need at least 2 points for differentiation".to_string(),
        ));
    }

    let mut grad = Vec::with_capacity(data.len());

    match method {
        DifferentiationMethod::Forward => {
            for i in 0..data.len() - 1 {
                grad.push((data[i + 1] - data[i]) / dx);
            }
            // Last point uses backward difference (no forward neighbor).
            grad.push((data[data.len() - 1] - data[data.len() - 2]) / dx);
        }
        DifferentiationMethod::Backward => {
            // First point uses forward difference (no backward neighbor).
            grad.push((data[1] - data[0]) / dx);
            for i in 1..data.len() {
                grad.push((data[i] - data[i - 1]) / dx);
            }
        }
        DifferentiationMethod::Central => {
            // First point uses forward difference
            grad.push((data[1] - data[0]) / dx);
            // Central difference for interior points
            for i in 1..data.len() - 1 {
                grad.push((data[i + 1] - data[i - 1]) / (2.0 * dx));
            }
            // Last point uses backward difference
            grad.push((data[data.len() - 1] - data[data.len() - 2]) / dx);
        }
        DifferentiationMethod::HigherOrder => {
            // Five-point stencils (fourth-order accurate) need >= 5 samples.
            if data.len() < 5 {
                return Err(TorshError::InvalidArgument(
                    "Need at least 5 points for higher-order differentiation".to_string(),
                ));
            }

            // First two points use one-sided forward five-point stencils.
            grad.push(
                (-25.0 * data[0] + 48.0 * data[1] - 36.0 * data[2] + 16.0 * data[3]
                    - 3.0 * data[4])
                    / (12.0 * dx),
            );
            grad.push(
                (-3.0 * data[0] - 10.0 * data[1] + 18.0 * data[2] - 6.0 * data[3] + data[4])
                    / (12.0 * dx),
            );

            // Symmetric five-point central differences for interior points.
            for i in 2..data.len() - 2 {
                grad.push(
                    (data[i - 2] - 8.0 * data[i - 1] + 8.0 * data[i + 1] - data[i + 2])
                        / (12.0 * dx),
                );
            }

            // Last two points use the mirrored backward stencils
            // (second-to-last first, so output order matches input order).
            let n = data.len();
            grad.push(
                (3.0 * data[n - 1] + 10.0 * data[n - 2] - 18.0 * data[n - 3] + 6.0 * data[n - 4]
                    - data[n - 5])
                    / (12.0 * dx),
            );
            grad.push(
                (25.0 * data[n - 1] - 48.0 * data[n - 2] + 36.0 * data[n - 3] - 16.0 * data[n - 4]
                    + 3.0 * data[n - 5])
                    / (12.0 * dx),
            );
        }
    }

    Tensor::from_data(grad, y.shape().dims().to_vec(), y.device())
}
346
347/// Second derivative using finite differences
348///
349/// # Arguments
350/// * `y` - Tensor of function values
351/// * `dx` - Step size
352pub fn second_derivative(y: &Tensor, dx: Option<f32>) -> TorshResult<Tensor> {
353    let data = y.data()?;
354    let dx = dx.unwrap_or(1.0);
355
356    if data.len() < 3 {
357        return Err(TorshError::InvalidArgument(
358            "Need at least 3 points for second derivative".to_string(),
359        ));
360    }
361
362    let mut second_deriv = Vec::with_capacity(data.len());
363    let dx2 = dx * dx;
364
365    // First point (forward difference)
366    if data.len() >= 4 {
367        second_deriv.push((2.0 * data[0] - 5.0 * data[1] + 4.0 * data[2] - data[3]) / dx2);
368    } else {
369        // For small arrays, use simple forward difference
370        second_deriv.push((data[2] - 2.0 * data[1] + data[0]) / dx2);
371    }
372
373    // Interior points (central difference)
374    for i in 1..data.len() - 1 {
375        second_deriv.push((data[i - 1] - 2.0 * data[i] + data[i + 1]) / dx2);
376    }
377
378    // Last point (backward difference)
379    let n = data.len();
380    if n >= 4 {
381        second_deriv
382            .push((2.0 * data[n - 1] - 5.0 * data[n - 2] + 4.0 * data[n - 3] - data[n - 4]) / dx2);
383    } else {
384        // For small arrays, use simple backward difference
385        second_deriv.push((data[n - 1] - 2.0 * data[n - 2] + data[n - 3]) / dx2);
386    }
387
388    Tensor::from_data(second_deriv, y.shape().dims().to_vec(), y.device())
389}
390
391/// Partial derivatives for multi-dimensional tensors
392///
393/// Computes partial derivative along specified axis.
394///
395/// # Arguments
396/// * `tensor` - Input tensor
397/// * `axis` - Axis along which to compute derivative
398/// * `dx` - Step size
399pub fn partial_derivative(tensor: &Tensor, axis: usize, dx: Option<f32>) -> TorshResult<Tensor> {
400    let dx = dx.unwrap_or(1.0);
401    let binding = tensor.shape();
402    let shape = binding.dims();
403
404    if axis >= shape.len() {
405        return Err(TorshError::InvalidArgument(format!(
406            "Axis {} out of bounds for tensor with {} dimensions",
407            axis,
408            shape.len()
409        )));
410    }
411
412    if shape[axis] < 2 {
413        return Err(TorshError::InvalidArgument(
414            "Need at least 2 points along differentiation axis".to_string(),
415        ));
416    }
417
418    // For now, implement simple central difference
419    // This is a simplified implementation - a full implementation would need
420    // proper multi-dimensional indexing
421    gradient(tensor, Some(dx), Some(DifferentiationMethod::Central))
422}
423
424// ============================================================================
425// Optimization and Root Finding
426// ============================================================================
427
428/// Find roots using Newton-Raphson method
429///
430/// # Arguments
431/// * `func` - Function for which to find roots
432/// * `dfunc` - Derivative of the function
433/// * `x0` - Initial guess
434/// * `tol` - Tolerance for convergence
435/// * `max_iter` - Maximum number of iterations
436pub fn newton_raphson<F, DF>(
437    func: F,
438    dfunc: DF,
439    x0: f32,
440    tol: Option<f32>,
441    max_iter: Option<usize>,
442) -> TorshResult<f32>
443where
444    F: Fn(f32) -> f32,
445    DF: Fn(f32) -> f32,
446{
447    let tol = tol.unwrap_or(1e-6);
448    let max_iter = max_iter.unwrap_or(100);
449
450    let mut x = x0;
451
452    for _ in 0..max_iter {
453        let fx = func(x);
454        let dfx = dfunc(x);
455
456        if dfx.abs() < 1e-12 {
457            return Err(TorshError::ComputeError(
458                "Derivative is zero, Newton-Raphson method failed".to_string(),
459            ));
460        }
461
462        let x_new = x - fx / dfx;
463
464        if (x_new - x).abs() < tol {
465            return Ok(x_new);
466        }
467
468        x = x_new;
469    }
470
471    Err(TorshError::ComputeError(
472        "Newton-Raphson method did not converge".to_string(),
473    ))
474}
475
476/// Find roots using bisection method
477///
478/// # Arguments
479/// * `func` - Function for which to find roots
480/// * `a` - Lower bound (func(a) and func(b) should have opposite signs)
481/// * `b` - Upper bound
482/// * `tol` - Tolerance for convergence
483/// * `max_iter` - Maximum number of iterations
484pub fn bisection<F>(
485    func: F,
486    a: f32,
487    b: f32,
488    tol: Option<f32>,
489    max_iter: Option<usize>,
490) -> TorshResult<f32>
491where
492    F: Fn(f32) -> f32,
493{
494    let tol = tol.unwrap_or(1e-6);
495    let max_iter = max_iter.unwrap_or(100);
496
497    let fa = func(a);
498    let fb = func(b);
499
500    if fa * fb > 0.0 {
501        return Err(TorshError::InvalidArgument(
502            "Function values at endpoints must have opposite signs".to_string(),
503        ));
504    }
505
506    let mut a = a;
507    let mut b = b;
508
509    for _ in 0..max_iter {
510        let c = (a + b) / 2.0;
511        let fc = func(c);
512
513        if fc.abs() < tol || (b - a) / 2.0 < tol {
514            return Ok(c);
515        }
516
517        if fa * fc < 0.0 {
518            b = c;
519        } else {
520            a = c;
521        }
522    }
523
524    Ok((a + b) / 2.0)
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::*;

    #[test]
    fn test_trapz() {
        // Test integration of x^2 from 0 to 1, sampled at 11 points (step 0.1).
        let x: Vec<f32> = (0..11).map(|i| i as f32 / 10.0).collect();
        let y: Vec<f32> = x.iter().map(|&xi| xi * xi).collect();
        let tensor = from_vec(y, &[11], DeviceType::Cpu).unwrap();

        let result = trapz(&tensor, Some(0.1)).unwrap();
        let result_val = result.data().expect("tensor should have data")[0];

        // Analytical result for integral of x^2 from 0 to 1 is 1/3.
        // Loose epsilon: trapezoidal rule has O(dx^2) discretization error.
        assert_relative_eq!(result_val, 1.0 / 3.0, epsilon = 0.01);
    }

    #[test]
    fn test_gradient() {
        // Test gradient of x^2 sampled on [0, 1] with step 0.1.
        let x: Vec<f32> = (0..11).map(|i| i as f32 / 10.0).collect();
        let y: Vec<f32> = x.iter().map(|&xi| xi * xi).collect();
        let tensor = from_vec(y, &[11], DeviceType::Cpu).unwrap();

        let grad = gradient(&tensor, Some(0.1), Some(DifferentiationMethod::Central)).unwrap();
        let grad_data = grad.data().expect("tensor should have data");

        // Analytical gradient of x^2 is 2x.
        // Check interior points only (endpoints use one-sided differences).
        for i in 1..grad_data.len() - 1 {
            let expected = 2.0 * x[i];
            assert_relative_eq!(grad_data[i], expected, epsilon = 0.1);
        }
    }

    #[test]
    fn test_simps() {
        // Test Simpson's rule on x^2 from 0 to 1 (11 points = even interval count).
        let x: Vec<f32> = (0..11).map(|i| i as f32 / 10.0).collect();
        let y: Vec<f32> = x.iter().map(|&xi| xi * xi).collect();
        let tensor = from_vec(y, &[11], DeviceType::Cpu).unwrap();

        let result = simps(&tensor, Some(0.1)).unwrap();
        let result_val = result.data().expect("tensor should have data")[0];

        // Should be more accurate than trapezoidal rule (tighter epsilon).
        assert_relative_eq!(result_val, 1.0 / 3.0, epsilon = 0.001);
    }

    #[test]
    fn test_newton_raphson() {
        // Find root of x^2 - 2 = 0 (should be sqrt(2)).
        let func = |x: f32| x * x - 2.0;
        let dfunc = |x: f32| 2.0 * x;

        let root = newton_raphson(func, dfunc, 1.0, None, None).unwrap();
        assert_relative_eq!(root, 2.0_f32.sqrt(), epsilon = 1e-6);
    }

    #[test]
    fn test_bisection() {
        // Find root of x^2 - 2 = 0; [0, 2] brackets the sign change.
        let func = |x: f32| x * x - 2.0;

        let root = bisection(func, 0.0, 2.0, None, None).unwrap();
        assert_relative_eq!(root, 2.0_f32.sqrt(), epsilon = 1e-6);
    }
}