scirs2_linalg/gradient/mod.rs

//! Gradient calculation utilities for neural networks
//!
//! This module provides utilities for calculating gradients in the context of
//! neural network training, focusing on efficiency and numerical stability.

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;

use crate::error::{LinalgError, LinalgResult};

/// Calculate the gradient of mean squared error with respect to predictions
///
/// Computes the gradient of the MSE loss function with respect to predictions.
/// This is a common gradient calculation in regression tasks.
///
/// # Arguments
///
/// * `predictions` - Predicted values
/// * `targets` - Target values (ground truth)
///
/// # Returns
///
/// * The gradient of MSE with respect to predictions
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::mse_gradient;
/// use approx::assert_relative_eq;
///
/// let predictions = array![3.0, 1.0, 2.0];
/// let targets = array![2.5, 0.5, 2.0];
///
/// let gradient = mse_gradient(&predictions.view(), &targets.view()).unwrap();
///
/// // gradient = 2 * (predictions - targets) / n
/// //          = 2 * ([3.0, 1.0, 2.0] - [2.5, 0.5, 2.0]) / 3
/// //          = 2 * [0.5, 0.5, 0.0] / 3
/// //          = [0.333..., 0.333..., 0.0]
/// assert_relative_eq!(gradient[0], 1.0/3.0, epsilon = 1e-10);
/// assert_relative_eq!(gradient[1], 1.0/3.0, epsilon = 1e-10);
/// assert_relative_eq!(gradient[2], 0.0, epsilon = 1e-10);
/// ```
#[allow(dead_code)]
pub fn mse_gradient<F>(
    predictions: &ArrayView1<F>,
    targets: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
{
    // Check dimension compatibility
    if predictions.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for mse_gradient: predictions has shape {:?} but targets has shape {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n = F::from(predictions.len()).unwrap();
    let two = F::from(2.0).unwrap();

    // Compute (predictions - targets) * 2/n,
    // the gradient of MSE with respect to predictions
    let scale = two / n;
    let gradient = (predictions - targets) * scale;

    Ok(gradient)
}

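// A minimal sketch of how `mse_gradient` might drive one plain gradient-descent
// update for a linear model. This helper is illustrative only: the names
// `mse_descent_step_sketch`, `weights`, and `learning_rate`, and the linear
// forward pass, are assumptions for the example, not part of this module's API.
#[allow(dead_code)]
fn mse_descent_step_sketch(
    weights: &mut Array1<f64>,
    inputs: &Array2<f64>,
    targets: &ArrayView1<f64>,
    learning_rate: f64,
) -> LinalgResult<()> {
    // Forward pass: predictions = X · w
    let predictions = inputs.dot(&*weights);
    // dL/d(predictions), computed by this module
    let grad_pred = mse_gradient(&predictions.view(), targets)?;
    // Chain rule back to the weights: dL/dw = Xᵀ · dL/d(predictions)
    let grad_w = inputs.t().dot(&grad_pred);
    // Plain gradient-descent update
    *weights -= &(grad_w * learning_rate);
    Ok(())
}
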
/// Calculate the gradient of binary cross-entropy with respect to predictions
///
/// Computes the gradient of the binary cross-entropy loss function with respect
/// to predictions. This is a common gradient calculation in binary classification tasks.
///
/// # Arguments
///
/// * `predictions` - Predicted probabilities (must be strictly between 0 and 1)
/// * `targets` - Target values (ground truth, must be 0 or 1)
///
/// # Returns
///
/// * The gradient of binary cross-entropy with respect to predictions
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::binary_crossentropy_gradient;
/// use approx::assert_relative_eq;
///
/// let predictions = array![0.7, 0.3, 0.9];
/// let targets = array![1.0, 0.0, 1.0];
///
/// let gradient = binary_crossentropy_gradient(&predictions.view(), &targets.view()).unwrap();
///
/// // gradient = -targets/predictions + (1-targets)/(1-predictions)
/// //          = -[1.0, 0.0, 1.0]/[0.7, 0.3, 0.9] + [0.0, 1.0, 0.0]/[0.3, 0.7, 0.1]
/// //          = [-1.428..., 0.0, -1.111...] + [0.0, 1.428..., 0.0]
/// //          = [-1.428..., 1.428..., -1.111...]
///
/// assert_relative_eq!(gradient[0], -1.428571, epsilon = 1e-6);
/// assert_relative_eq!(gradient[1], 1.428571, epsilon = 1e-6);
/// assert_relative_eq!(gradient[2], -1.111111, epsilon = 1e-6);
/// ```
#[allow(dead_code)]
pub fn binary_crossentropy_gradient<F>(
    predictions: &ArrayView1<F>,
    targets: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
{
    // Check dimension compatibility
    if predictions.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for binary_crossentropy_gradient: predictions has shape {:?} but targets has shape {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    // Check that predictions are strictly between 0 and 1
    for &p in predictions.iter() {
        if p <= F::zero() || p >= F::one() {
            return Err(LinalgError::InvalidInputError(
                "Predictions must be between 0 and 1 for binary cross-entropy".to_string(),
            ));
        }
    }

    // Check that targets are either 0 or 1
    for &t in targets.iter() {
        if (t - F::zero()).abs() > F::epsilon() && (t - F::one()).abs() > F::epsilon() {
            return Err(LinalgError::InvalidInputError(
                "Targets must be 0 or 1 for binary cross-entropy".to_string(),
            ));
        }
    }

    let one = F::one();
    let eps = F::from(1e-15).unwrap(); // Small epsilon to prevent division by zero

    // Compute -targets/(predictions+eps) + (1-targets)/(1-predictions+eps),
    // the gradient of binary cross-entropy with respect to predictions
    let mut gradient = Array1::zeros(predictions.len());
    for i in 0..predictions.len() {
        let p = predictions[i];
        let t = targets[i];
        let term1 = if t > F::epsilon() {
            -t / (p + eps)
        } else {
            F::zero()
        };
        let term2 = if (one - t) > F::epsilon() {
            (one - t) / (one - p + eps)
        } else {
            F::zero()
        };
        gradient[i] = term1 + term2;
    }

    Ok(gradient)
}

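// An illustrative sketch, not part of this module's API: when predictions come
// from a sigmoid, chaining this gradient through the sigmoid derivative
// p * (1 - p) collapses analytically to the familiar (p - t) gradient with
// respect to the logits. The name `bce_logit_gradient_sketch` is an assumption
// for the example.
#[allow(dead_code)]
fn bce_logit_gradient_sketch(
    predictions: &ArrayView1<f64>,
    targets: &ArrayView1<f64>,
) -> LinalgResult<Array1<f64>> {
    // dL/dp from this module
    let grad_p = binary_crossentropy_gradient(predictions, targets)?;
    // dp/dz = p * (1 - p) for the sigmoid; dL/dz = dL/dp * dp/dz,
    // which simplifies to p - t
    let sigmoid_deriv = predictions.mapv(|p| p * (1.0 - p));
    Ok(grad_p * &sigmoid_deriv)
}
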
/// Calculate the gradient of softmax cross-entropy with respect to logits
///
/// Computes the gradient of the softmax + cross-entropy loss with respect to
/// the logits (pre-softmax values). This is a common gradient calculation in
/// multi-class classification tasks.
///
/// # Arguments
///
/// * `softmax_output` - The output of the softmax function (probabilities that sum to 1)
/// * `targets` - Target one-hot encoded vectors
///
/// # Returns
///
/// * The gradient of softmax cross-entropy with respect to logits
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::softmax_crossentropy_gradient;
/// use approx::assert_relative_eq;
///
/// // Softmax outputs (probabilities)
/// let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
/// // One-hot encoded targets
/// let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
///
/// let gradient = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view()).unwrap();
///
/// // For each example, gradient = (softmax_output - targets) / batchsize
/// // First example:  ([0.7, 0.2, 0.1] - [1.0, 0.0, 0.0]) / 2 = [-0.15, 0.1, 0.05]
/// // Second example: ([0.3, 0.6, 0.1] - [0.0, 1.0, 0.0]) / 2 = [0.15, -0.2, 0.05]
///
/// assert_relative_eq!(gradient[[0, 0]], -0.15, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[0, 1]], 0.1, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[0, 2]], 0.05, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 0]], 0.15, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 1]], -0.2, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 2]], 0.05, epsilon = 1e-10);
/// ```
#[allow(dead_code)]
pub fn softmax_crossentropy_gradient<F>(
    softmax_output: &ArrayView2<F>,
    targets: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + std::fmt::Display,
{
    // Check dimension compatibility
    if softmax_output.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for softmax_crossentropy_gradient: softmax_output has shape {:?} but targets has shape {:?}",
            softmax_output.shape(),
            targets.shape()
        )));
    }

    // Check that softmax outputs sum to 1 for each example
    let (batchsize, _num_classes) = softmax_output.dim();
    for i in 0..batchsize {
        let row_sum = softmax_output.slice(s![i, ..]).sum();
        if (row_sum - F::one()).abs() > F::from(1e-5).unwrap() {
            return Err(LinalgError::InvalidInputError(format!(
                "softmax_output row {i} does not sum to 1: sum is {row_sum}"
            )));
        }
    }

    // Check that targets are valid one-hot vectors
    for i in 0..batchsize {
        let row_sum = targets.slice(s![i, ..]).sum();
        if (row_sum - F::one()).abs() > F::from(1e-6).unwrap() {
            return Err(LinalgError::InvalidInputError(format!(
                "targets row {i} is not a valid one-hot vector: sum is {row_sum}"
            )));
        }

        // Check that exactly one element is (close to) 1 and the rest are 0
        let mut has_one = false;
        for val in targets.slice(s![i, ..]).iter() {
            if (*val - F::one()).abs() < F::from(1e-6).unwrap() {
                if has_one {
                    // More than one value close to 1
                    return Err(LinalgError::InvalidInputError(format!(
                        "targets row {i} is not a valid one-hot vector: multiple entries close to 1"
                    )));
                }
                has_one = true;
            } else if *val > F::from(1e-6).unwrap() {
                // Value is not close to 0 or 1
                return Err(LinalgError::InvalidInputError(format!(
                    "targets row {i} is not a valid one-hot vector: contains value {} not close to 0 or 1", *val
                )));
            }
        }

        if !has_one {
            return Err(LinalgError::InvalidInputError(format!(
                "targets row {i} is not a valid one-hot vector: no entry close to 1"
            )));
        }
    }

    let batchsize_f = F::from(batchsize).unwrap();

    // Compute softmax_output - targets
    let mut gradient = softmax_output.to_owned() - targets;

    // Scale by 1/batchsize
    gradient /= batchsize_f;

    Ok(gradient)
}

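// A minimal end-to-end sketch, not part of this module's API: compute a
// numerically stable row-wise softmax from raw logits, which can then be fed
// to `softmax_crossentropy_gradient` above. The name `softmax_rows_sketch` is
// an assumption for the example.
#[allow(dead_code)]
fn softmax_rows_sketch(logits: &ArrayView2<f64>) -> Array2<f64> {
    let mut out = logits.to_owned();
    for mut row in out.rows_mut() {
        // Subtract the row maximum before exponentiating for numerical stability
        let max = row.fold(f64::NEG_INFINITY, |m, &v| m.max(v));
        row.mapv_inplace(|v| (v - max).exp());
        // Normalize so the row sums to 1
        let sum = row.sum();
        row.mapv_inplace(|v| v / sum);
    }
    out
}
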
/// Calculate the Jacobian matrix for a function that maps from R^n to R^m
///
/// Computes a numerical approximation of the Jacobian matrix for a function
/// that takes an n-dimensional input and produces an m-dimensional output.
///
/// # Arguments
///
/// * `f` - A function that maps from R^n to R^m
/// * `x` - The point at which to evaluate the Jacobian
/// * `epsilon` - The step size for the finite difference approximation
///
/// # Returns
///
/// * The Jacobian matrix of shape (m, n)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array1};
/// use scirs2_linalg::gradient::jacobian;
///
/// // Define a simple function R^2 -> R^3
/// // f(x,y) = [x^2 + y, 2*x + y^2, x*y]
/// let f = |v: &Array1<f64>| -> Array1<f64> {
///     let x = v[0];
///     let y = v[1];
///     array![x*x + y, 2.0*x + y*y, x*y]
/// };
///
/// let x = array![2.0, 3.0];  // Point at which to evaluate the Jacobian
/// let epsilon = 1e-5;
///
/// let jac = jacobian(&f, &x, epsilon).unwrap();
///
/// // Analytical Jacobian at (2,3) is:
/// // [2x, 1]     [4, 1]
/// // [2, 2y]  =  [2, 6]
/// // [y, x]      [3, 2]
///
/// assert!((jac[[0, 0]] - 4.0).abs() < 1e-4);
/// assert!((jac[[0, 1]] - 1.0).abs() < 1e-4);
/// assert!((jac[[1, 0]] - 2.0).abs() < 1e-4);
/// assert!((jac[[1, 1]] - 6.0).abs() < 1e-4);
/// assert!((jac[[2, 0]] - 3.0).abs() < 1e-4);
/// assert!((jac[[2, 1]] - 2.0).abs() < 1e-4);
/// ```
#[allow(dead_code)]
pub fn jacobian<F, G>(f: &G, x: &Array1<F>, epsilon: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
    G: Fn(&Array1<F>) -> Array1<F>,
{
    let n = x.len();

    // Evaluate the function at the given point
    let f_x = f(x);
    let m = f_x.len();

    let mut jacobian = Array2::zeros((m, n));
    let two_epsilon = F::from(2.0).unwrap() * epsilon;

    // Compute each column of the Jacobian by central finite differences
    for j in 0..n {
        // Create forward and backward perturbations for central differences
        let mut x_forward = x.clone();
        let mut x_backward = x.clone();

        // Perturb the jth component in both directions
        x_forward[j] = x[j] + epsilon;
        x_backward[j] = x[j] - epsilon;

        // Evaluate the function at the perturbed points
        let f_forward = f(&x_forward);
        let f_backward = f(&x_backward);

        // Compute the jth column of the Jacobian by the central difference formula,
        // which is more accurate than forward or backward differences
        for i in 0..m {
            jacobian[[i, j]] = (f_forward[i] - f_backward[i]) / two_epsilon;
        }
    }

    Ok(jacobian)
}

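// A short sketch of a common use of `jacobian`: numerical gradient checking.
// Wrapping a scalar function as R^n -> R^1 makes the single Jacobian row the
// gradient, which can be compared against a hand-derived gradient. The name
// `gradient_check_sketch` and the test function are assumptions for the example.
#[allow(dead_code)]
fn gradient_check_sketch() -> LinalgResult<bool> {
    use scirs2_core::ndarray::array;

    // f(x, y) = x^2 + 3y, wrapped as a 1-element array so `jacobian` applies
    let f = |v: &Array1<f64>| -> Array1<f64> { array![v[0] * v[0] + 3.0 * v[1]] };
    let x = array![2.0, 1.0];
    let jac = jacobian(&f, &x, 1e-5)?;

    // Analytic gradient at (2, 1): [2x, 3] = [4, 3]
    let analytic = [4.0, 3.0];
    Ok((0..x.len()).all(|j| (jac[[0, j]] - analytic[j]).abs() < 1e-4))
}
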
/// Calculate the Hessian matrix for a scalar-valued function
///
/// Computes a numerical approximation of the Hessian matrix (second derivatives)
/// for a function that takes an n-dimensional input and produces a scalar output.
///
/// # Arguments
///
/// * `f` - A function that maps from R^n to R
/// * `x` - The point at which to evaluate the Hessian
/// * `epsilon` - The step size for the finite difference approximation
///
/// # Returns
///
/// * The Hessian matrix of shape (n, n)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array1};
/// use scirs2_linalg::gradient::hessian;
///
/// // Define a simple quadratic function: f(x,y) = x^2 + xy + 2y^2
/// let f = |v: &Array1<f64>| -> f64 {
///     let x = v[0];
///     let y = v[1];
///     x*x + x*y + 2.0*y*y
/// };
///
/// let x = array![1.0, 2.0];  // Point at which to evaluate the Hessian
/// let epsilon = 1e-5;
///
/// let hess = hessian(&f, &x, epsilon).unwrap();
///
/// // Analytical Hessian is:
/// // [∂²f/∂x², ∂²f/∂x∂y]   [2, 1]
/// // [∂²f/∂y∂x, ∂²f/∂y²] = [1, 4]
///
/// assert!((hess[[0, 0]] - 2.0).abs() < 1e-4);
/// assert!((hess[[0, 1]] - 1.0).abs() < 1e-4);
/// assert!((hess[[1, 0]] - 1.0).abs() < 1e-4);
/// assert!((hess[[1, 1]] - 4.0).abs() < 1e-4);
/// ```
#[allow(dead_code)]
pub fn hessian<F, G>(f: &G, x: &Array1<F>, epsilon: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
    G: Fn(&Array1<F>) -> F,
{
    let n = x.len();
    let mut hessian = Array2::zeros((n, n));

    let two = F::from(2.0).unwrap();
    let epsilon_squared = epsilon * epsilon;

    // f(x) is needed for the diagonal second-derivative formula below
    let f_x = f(x);

    // Iterate over the lower triangle; symmetry fills the upper triangle
    for i in 0..n {
        for j in 0..=i {
            if i == j {
                // Diagonal elements: central difference formula for the second derivative
                let mut x_plus = x.clone();
                let mut x_minus = x.clone();

                x_plus[i] = x[i] + epsilon;
                x_minus[i] = x[i] - epsilon;

                let f_plus = f(&x_plus);
                let f_minus = f(&x_minus);

                // Central difference formula for the second derivative:
                // f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
                let h_ii = (f_plus - two * f_x + f_minus) / epsilon_squared;
                hessian[[i, i]] = h_ii;
            } else {
                // Off-diagonal elements (mixed partial derivatives): central differences
                let mut x_plus_plus = x.clone();
                let mut x_plus_minus = x.clone();
                let mut x_minus_plus = x.clone();
                let mut x_minus_minus = x.clone();

                // (i+,j+): both variables increased by epsilon
                x_plus_plus[i] = x[i] + epsilon;
                x_plus_plus[j] = x[j] + epsilon;

                // (i+,j-): first variable increased, second decreased
                x_plus_minus[i] = x[i] + epsilon;
                x_plus_minus[j] = x[j] - epsilon;

                // (i-,j+): first variable decreased, second increased
                x_minus_plus[i] = x[i] - epsilon;
                x_minus_plus[j] = x[j] + epsilon;

                // (i-,j-): both variables decreased by epsilon
                x_minus_minus[i] = x[i] - epsilon;
                x_minus_minus[j] = x[j] - epsilon;

                // Evaluate the function at all four perturbed points
                let f_plus_plus = f(&x_plus_plus);
                let f_plus_minus = f(&x_plus_minus);
                let f_minus_plus = f(&x_minus_plus);
                let f_minus_minus = f(&x_minus_minus);

                // Mixed partial derivative using central differences:
                // ∂²f/∂x∂y ≈ (f(x+h,y+h) - f(x+h,y-h) - f(x-h,y+h) + f(x-h,y-h)) / (4h²)
                let four = F::from(4.0).unwrap();
                let h_ij = (f_plus_plus - f_plus_minus - f_minus_plus + f_minus_minus)
                    / (four * epsilon_squared);

                hessian[[i, j]] = h_ij;
                hessian[[j, i]] = h_ij; // The Hessian is symmetric
            }
        }
    }

    Ok(hessian)
}

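// A compact sketch of one use of `hessian`: a single one-dimensional Newton
// step, x1 = x0 - f'(x0) / f''(x0). For n > 1 the division becomes a linear
// solve against the Hessian matrix. The name `newton_step_1d_sketch` is an
// assumption for the example, not part of this module's API.
#[allow(dead_code)]
fn newton_step_1d_sketch() -> LinalgResult<f64> {
    use scirs2_core::ndarray::array;

    // f(x) = (x - 3)^2, minimized at x = 3
    let f = |v: &Array1<f64>| -> f64 { (v[0] - 3.0) * (v[0] - 3.0) };
    let x0 = array![0.0];
    let eps = 1e-5;

    // First derivative via the Jacobian of a 1-element wrapper
    let grad = jacobian(&|v: &Array1<f64>| array![f(v)], &x0, eps)?;
    let hess = hessian(&f, &x0, eps)?;

    // One Newton step; for a quadratic this lands on the minimum exactly
    Ok(x0[0] - grad[[0, 0]] / hess[[0, 0]])
}
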
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mse_gradient() {
        // Test with a simple case
        let predictions = array![3.0, 1.0, 2.0];
        let targets = array![2.5, 0.5, 2.0];

        let gradient = mse_gradient(&predictions.view(), &targets.view()).unwrap();

        // gradient = 2 * (predictions - targets) / n
        //          = 2 * ([3.0, 1.0, 2.0] - [2.5, 0.5, 2.0]) / 3
        //          = 2 * [0.5, 0.5, 0.0] / 3
        //          = [0.333..., 0.333..., 0.0]
        assert_relative_eq!(gradient[0], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(gradient[1], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(gradient[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binary_crossentropy_gradient() {
        // Test with a simple case
        let predictions = array![0.7, 0.3, 0.9];
        let targets = array![1.0, 0.0, 1.0];

        let gradient = binary_crossentropy_gradient(&predictions.view(), &targets.view()).unwrap();

        // gradient = -targets/predictions + (1-targets)/(1-predictions)
        //          = -[1.0, 0.0, 1.0]/[0.7, 0.3, 0.9] + [0.0, 1.0, 0.0]/[0.3, 0.7, 0.1]
        //          = [-1.428..., 0.0, -1.111...] + [0.0, 1.428..., 0.0]
        //          = [-1.428..., 1.428..., -1.111...]
        assert_relative_eq!(gradient[0], -1.428571, epsilon = 1e-6);
        assert_relative_eq!(gradient[1], 1.428571, epsilon = 1e-6);
        assert_relative_eq!(gradient[2], -1.111111, epsilon = 1e-6);
    }

    #[test]
    fn test_softmax_crossentropy_gradient() {
        // Test with a simple case
        let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let gradient =
            softmax_crossentropy_gradient(&softmax_output.view(), &targets.view()).unwrap();

        // For each example, gradient = (softmax_output - targets) / batchsize
        // First example:  ([0.7, 0.2, 0.1] - [1.0, 0.0, 0.0]) / 2 = [-0.15, 0.1, 0.05]
        // Second example: ([0.3, 0.6, 0.1] - [0.0, 1.0, 0.0]) / 2 = [0.15, -0.2, 0.05]
        assert_relative_eq!(gradient[[0, 0]], -0.15, epsilon = 1e-10);
        assert_relative_eq!(gradient[[0, 1]], 0.1, epsilon = 1e-10);
        assert_relative_eq!(gradient[[0, 2]], 0.05, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 0]], 0.15, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 1]], -0.2, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 2]], 0.05, epsilon = 1e-10);
    }

    #[test]
    fn test_jacobian() {
        // Define a simple function R^2 -> R^3
        // f(x,y) = [x^2 + y, 2*x + y^2, x*y]
        let f = |v: &Array1<f64>| -> Array1<f64> {
            let x = v[0];
            let y = v[1];
            array![x * x + y, 2.0 * x + y * y, x * y]
        };

        let x = array![2.0, 3.0]; // Point at which to evaluate the Jacobian
        let epsilon = 1e-5;

        let jac = jacobian(&f, &x, epsilon).unwrap();

        // Analytical Jacobian at (2,3) is:
        // [2x, 1]     [4, 1]
        // [2, 2y]  =  [2, 6]
        // [y, x]      [3, 2]

        assert_relative_eq!(jac[[0, 0]], 4.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[0, 1]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[1, 0]], 2.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[1, 1]], 6.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[2, 0]], 3.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[2, 1]], 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_hessian() {
        // A very simple quadratic function: f(x) = 2x²
        // Has constant second derivative: f''(x) = 4
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            2.0 * x * x
        };

        let x = array![0.5];
        let epsilon = 1e-4;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // The Hessian (second derivative) of f(x) = 2x² is 4
        assert_relative_eq!(hess[[0, 0]], 4.0, epsilon = 1e-2);
    }

    #[test]
    fn test_hessian_multidimensional() {
        // Multivariable function: f(x,y,z) = x²y + y²z + z²x
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            let y = v[1];
            let z = v[2];
            x * x * y + y * y * z + z * z * x
        };

        let x = array![1.0, 1.0, 1.0];
        let epsilon = 1e-4;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // Analytical Hessian at (1,1,1) for f(x,y,z) = x²y + y²z + z²x:
        // First-order derivatives:
        // ∂f/∂x = 2xy + z²
        // ∂f/∂y = x² + 2yz
        // ∂f/∂z = y² + 2zx
        //
        // Second-order derivatives:
        // ∂²f/∂x² = 2y = 2 (at point [1,1,1])
        // ∂²f/∂y² = 2z = 2 (at point [1,1,1])
        // ∂²f/∂z² = 2x = 2 (at point [1,1,1])
        // ∂²f/∂x∂y = ∂²f/∂y∂x = 2x = 2 (at point [1,1,1])
        // ∂²f/∂y∂z = ∂²f/∂z∂y = 2y = 2 (at point [1,1,1])
        // ∂²f/∂z∂x = ∂²f/∂x∂z = 2z = 2 (at point [1,1,1])

        // Diagonal elements
        assert_relative_eq!(hess[[0, 0]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 2]], 2.0, epsilon = 1e-2);

        // Off-diagonal elements
        assert_relative_eq!(hess[[0, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 0]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 2]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[0, 2]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 0]], 2.0, epsilon = 1e-2);
    }

    #[test]
    fn test_hessian_quadratic_form() {
        // Quadratic form: f(x,y) = x² + xy + 2y²
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            let y = v[1];
            x * x + x * y + 2.0 * y * y
        };

        let x = array![1.0, 2.0];
        let epsilon = 1e-5;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // Analytical Hessian is:
        // [∂²f/∂x², ∂²f/∂x∂y]   [2, 1]
        // [∂²f/∂y∂x, ∂²f/∂y²] = [1, 4]
        assert_relative_eq!(hess[[0, 0]], 2.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[0, 1]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 0]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 1]], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_mse_gradient_dimension_error() {
        let predictions = array![1.0, 2.0, 3.0];
        let targets = array![1.0, 2.0];

        let result = mse_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_crossentropy_gradient_invalid_predictions() {
        let predictions = array![0.5, 1.2, 0.3]; // Contains value > 1
        let targets = array![1.0, 0.0, 1.0];

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());

        let predictions = array![0.5, -0.1, 0.3]; // Contains value < 0
        let targets = array![1.0, 0.0, 1.0];

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_crossentropy_gradient_invalid_targets() {
        let predictions = array![0.5, 0.7, 0.3];
        let targets = array![1.0, 0.5, 1.0]; // Contains a value that is neither 0 nor 1

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_softmax_crossentropy_gradient_invalid_softmax() {
        let softmax_output = array![[0.7, 0.2, 0.2], [0.3, 0.6, 0.1]]; // First row sums to 1.1
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let result = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_softmax_crossentropy_gradient_invalid_targets() {
        let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.3, 0.3, 0.4]]; // Second row sums to 1 but is not one-hot

        let result = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view());
        assert!(result.is_err());
    }
}