scirs2_linalg/gradient/mod.rs
//! Gradient calculation utilities for neural networks
//!
//! This module provides utilities for calculating gradients in the context of
//! neural network training, focusing on efficiency and numerical stability.

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;

use crate::error::{LinalgError, LinalgResult};

/// Calculate the gradient of mean squared error with respect to predictions
///
/// Computes the gradient of the MSE loss function with respect to predictions.
/// This is a common gradient calculation in regression tasks.
///
/// # Arguments
///
/// * `predictions` - Predicted values
/// * `targets` - Target values (ground truth)
///
/// # Returns
///
/// * The gradient of MSE with respect to predictions
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::mse_gradient;
/// use approx::assert_relative_eq;
///
/// let predictions = array![3.0, 1.0, 2.0];
/// let targets = array![2.5, 0.5, 2.0];
///
/// let gradient = mse_gradient(&predictions.view(), &targets.view()).unwrap();
///
/// // gradients = 2 * (predictions - targets) / n
/// //           = 2 * ([3.0, 1.0, 2.0] - [2.5, 0.5, 2.0]) / 3
/// //           = 2 * [0.5, 0.5, 0.0] / 3
/// //           = [0.333..., 0.333..., 0.0]
/// assert_relative_eq!(gradient[0], 1.0/3.0, epsilon = 1e-10);
/// assert_relative_eq!(gradient[1], 1.0/3.0, epsilon = 1e-10);
/// assert_relative_eq!(gradient[2], 0.0, epsilon = 1e-10);
/// ```
#[allow(dead_code)]
pub fn mse_gradient<F>(
    predictions: &ArrayView1<F>,
    targets: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
{
    // Check dimension compatibility
    if predictions.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for mse_gradient: predictions has shape {:?} but targets has shape {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n = F::from(predictions.len()).unwrap();
    let two = F::from(2.0).unwrap();

    // Compute (predictions - targets) * 2/n,
    // the gradient of MSE with respect to predictions
    let scale = two / n;
    let gradient = (predictions - targets) * scale;

    Ok(gradient)
}

/// Calculate the gradient of binary cross-entropy with respect to predictions
///
/// Computes the gradient of the binary cross-entropy loss function with respect to
/// predictions. This is a common gradient calculation in binary classification tasks.
///
/// # Arguments
///
/// * `predictions` - Predicted probabilities (must be strictly between 0 and 1)
/// * `targets` - Target values (ground truth, must be 0 or 1)
///
/// # Returns
///
/// * The gradient of binary cross-entropy with respect to predictions
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::binary_crossentropy_gradient;
/// use approx::assert_relative_eq;
///
/// let predictions = array![0.7, 0.3, 0.9];
/// let targets = array![1.0, 0.0, 1.0];
///
/// let gradient = binary_crossentropy_gradient(&predictions.view(), &targets.view()).unwrap();
///
/// // gradients = -targets/predictions + (1-targets)/(1-predictions)
/// //           = -[1.0, 0.0, 1.0]/[0.7, 0.3, 0.9] + [0.0, 1.0, 0.0]/[0.3, 0.7, 0.1]
/// //           = [-1.428..., 0.0, -1.111...] + [0.0, 1.428..., 0.0]
/// //           = [-1.428..., 1.428..., -1.111...]
///
/// assert_relative_eq!(gradient[0], -1.428571, epsilon = 1e-6);
/// assert_relative_eq!(gradient[1], 1.428571, epsilon = 1e-6);
/// assert_relative_eq!(gradient[2], -1.111111, epsilon = 1e-6);
/// ```
#[allow(dead_code)]
pub fn binary_crossentropy_gradient<F>(
    predictions: &ArrayView1<F>,
    targets: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
{
    // Check dimension compatibility
    if predictions.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for binary_crossentropy_gradient: predictions has shape {:?} but targets has shape {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    // Check that predictions are strictly between 0 and 1
    for &p in predictions.iter() {
        if p <= F::zero() || p >= F::one() {
            return Err(LinalgError::InvalidInputError(
                "Predictions must be between 0 and 1 for binary cross-entropy".to_string(),
            ));
        }
    }

    // Check that targets are either 0 or 1
    for &t in targets.iter() {
        if (t - F::zero()).abs() > F::epsilon() && (t - F::one()).abs() > F::epsilon() {
            return Err(LinalgError::InvalidInputError(
                "Targets must be 0 or 1 for binary cross-entropy".to_string(),
            ));
        }
    }

    let one = F::one();
    let eps = F::from(1e-15).unwrap(); // Small epsilon to guard against division by zero

    // Compute -targets/(predictions+eps) + (1-targets)/(1-predictions+eps),
    // the gradient of binary cross-entropy with respect to predictions
    let mut gradient = Array1::zeros(predictions.len());
    for i in 0..predictions.len() {
        let p = predictions[i];
        let t = targets[i];
        let term1 = if t > F::epsilon() {
            -t / (p + eps)
        } else {
            F::zero()
        };
        let term2 = if (one - t) > F::epsilon() {
            (one - t) / (one - p + eps)
        } else {
            F::zero()
        };
        gradient[i] = term1 + term2;
    }

    Ok(gradient)
}

/// Calculate the gradient of softmax cross-entropy with respect to logits
///
/// Computes the gradient of softmax + cross-entropy loss with respect to the logits
/// (pre-softmax values). This is a common gradient calculation in multi-class
/// classification tasks.
///
/// # Arguments
///
/// * `softmax_output` - The output of the softmax function (probabilities that sum to 1)
/// * `targets` - Target one-hot encoded vectors
///
/// # Returns
///
/// * The gradient of softmax cross-entropy with respect to logits
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::gradient::softmax_crossentropy_gradient;
/// use approx::assert_relative_eq;
///
/// // Softmax outputs (probabilities)
/// let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
/// // One-hot encoded targets
/// let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
///
/// let gradient = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view()).unwrap();
///
/// // For each example, gradient = (softmax_output - targets) / batchsize
/// // = ([0.7, 0.2, 0.1] - [1.0, 0.0, 0.0]) / 2
/// // = [-0.15, 0.1, 0.05]
/// // For the second example:
/// // = ([0.3, 0.6, 0.1] - [0.0, 1.0, 0.0]) / 2
/// // = [0.15, -0.2, 0.05]
///
/// assert_relative_eq!(gradient[[0, 0]], -0.15, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[0, 1]], 0.1, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[0, 2]], 0.05, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 0]], 0.15, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 1]], -0.2, epsilon = 1e-10);
/// assert_relative_eq!(gradient[[1, 2]], 0.05, epsilon = 1e-10);
/// ```
#[allow(dead_code)]
pub fn softmax_crossentropy_gradient<F>(
    softmax_output: &ArrayView2<F>,
    targets: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + std::fmt::Display,
{
    // Check dimension compatibility
    if softmax_output.shape() != targets.shape() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch for softmax_crossentropy_gradient: softmax_output has shape {:?} but targets has shape {:?}",
            softmax_output.shape(),
            targets.shape()
        )));
    }

    // Check that softmax outputs sum to 1 for each example
    let (batchsize, _num_classes) = softmax_output.dim();
    for i in 0..batchsize {
        let row_sum = softmax_output.slice(s![i, ..]).sum();
        if (row_sum - F::one()).abs() > F::from(1e-5).unwrap() {
            return Err(LinalgError::InvalidInputError(format!(
                "softmax_output row {i} does not sum to 1: sum is {row_sum}"
            )));
        }
    }

    // Check that targets are valid one-hot vectors
    for i in 0..batchsize {
        let row_sum = targets.slice(s![i, ..]).sum();
        if (row_sum - F::one()).abs() > F::from(1e-6).unwrap() {
            return Err(LinalgError::InvalidInputError(format!(
                "targets row {i} is not a valid one-hot vector: sum is {row_sum}"
            )));
        }

        // Check that exactly one element is (close to) 1 and the rest are (close to) 0
        let mut has_one = false;
        for val in targets.slice(s![i, ..]).iter() {
            if (*val - F::one()).abs() < F::from(1e-6).unwrap() {
                if has_one {
                    // More than one value close to 1
                    return Err(LinalgError::InvalidInputError(format!(
                        "targets row {i} is not a valid one-hot vector: multiple entries close to 1"
                    )));
                }
                has_one = true;
            } else if (*val).abs() > F::from(1e-6).unwrap() {
                // Value is not close to 0 or 1 (the absolute value also rejects negative entries)
                return Err(LinalgError::InvalidInputError(format!(
                    "targets row {i} is not a valid one-hot vector: contains value {} not close to 0 or 1", *val
                )));
            }
        }

        if !has_one {
            return Err(LinalgError::InvalidInputError(format!(
                "targets row {i} is not a valid one-hot vector: no entry close to 1"
            )));
        }
    }

    let batchsize_f = F::from(batchsize).unwrap();

    // Compute softmax_output - targets
    let mut gradient = softmax_output.to_owned() - targets;

    // Scale by 1/batchsize
    gradient /= batchsize_f;

    Ok(gradient)
}

/// Calculate the Jacobian matrix for a function that maps from R^n to R^m
///
/// Computes a numerical approximation of the Jacobian matrix for a function
/// that takes an n-dimensional input and produces an m-dimensional output.
///
/// # Arguments
///
/// * `f` - A function that maps from R^n to R^m
/// * `x` - The point at which to evaluate the Jacobian
/// * `epsilon` - The step size for the finite difference approximation
///
/// # Returns
///
/// * The Jacobian matrix of shape (m, n)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array1};
/// use scirs2_linalg::gradient::jacobian;
///
/// // Define a simple function R^2 -> R^3
/// // f(x,y) = [x^2 + y, 2*x + y^2, x*y]
/// let f = |v: &Array1<f64>| -> Array1<f64> {
///     let x = v[0];
///     let y = v[1];
///     array![x*x + y, 2.0*x + y*y, x*y]
/// };
///
/// let x = array![2.0, 3.0]; // Point at which to evaluate the Jacobian
/// let epsilon = 1e-5;
///
/// let jac = jacobian(&f, &x, epsilon).unwrap();
///
/// // Analytical Jacobian at (2,3) is:
/// // [2x,  1]   [4, 1]
/// // [ 2, 2y] = [2, 6]
/// // [ y,  x]   [3, 2]
///
/// assert!((jac[[0, 0]] - 4.0).abs() < 1e-4);
/// assert!((jac[[0, 1]] - 1.0).abs() < 1e-4);
/// assert!((jac[[1, 0]] - 2.0).abs() < 1e-4);
/// assert!((jac[[1, 1]] - 6.0).abs() < 1e-4);
/// assert!((jac[[2, 0]] - 3.0).abs() < 1e-4);
/// assert!((jac[[2, 1]] - 2.0).abs() < 1e-4);
/// ```
#[allow(dead_code)]
pub fn jacobian<F, G>(f: &G, x: &Array1<F>, epsilon: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
    G: Fn(&Array1<F>) -> Array1<F>,
{
    let n = x.len();

    // Evaluate function at the given point
    let f_x = f(x);
    let m = f_x.len();

    let mut jacobian = Array2::zeros((m, n));
    let two_epsilon = F::from(2.0).unwrap() * epsilon;

    // Compute each column of the Jacobian by central finite differences
    for j in 0..n {
        // Create forward and backward perturbations for central differences
        let mut x_forward = x.clone();
        let mut x_backward = x.clone();

        // Perturb jth component in both directions
        x_forward[j] = x[j] + epsilon;
        x_backward[j] = x[j] - epsilon;

        // Evaluate function at perturbed points
        let f_forward = f(&x_forward);
        let f_backward = f(&x_backward);

        // Compute jth column of Jacobian by central difference formula
        // This is more accurate than forward or backward differences
        for i in 0..m {
            jacobian[[i, j]] = (f_forward[i] - f_backward[i]) / two_epsilon;
        }
    }

    Ok(jacobian)
}

/// Calculate the Hessian matrix for a scalar-valued function
///
/// Computes a numerical approximation of the Hessian matrix (second derivatives)
/// for a function that takes an n-dimensional input and produces a scalar output.
///
/// # Arguments
///
/// * `f` - A function that maps from R^n to R
/// * `x` - The point at which to evaluate the Hessian
/// * `epsilon` - The step size for the finite difference approximation
///
/// # Returns
///
/// * The Hessian matrix of shape (n, n)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array1};
/// use scirs2_linalg::gradient::hessian;
///
/// // Define a simple quadratic function: f(x,y) = x^2 + xy + 2y^2
/// let f = |v: &Array1<f64>| -> f64 {
///     let x = v[0];
///     let y = v[1];
///     x*x + x*y + 2.0*y*y
/// };
///
/// let x = array![1.0, 2.0]; // Point at which to evaluate the Hessian
/// let epsilon = 1e-5;
///
/// let hess = hessian(&f, &x, epsilon).unwrap();
///
/// // Analytical Hessian is:
/// // [∂²f/∂x²,  ∂²f/∂x∂y]   [2, 1]
/// // [∂²f/∂y∂x, ∂²f/∂y² ] = [1, 4]
///
/// assert!((hess[[0, 0]] - 2.0).abs() < 1e-4);
/// assert!((hess[[0, 1]] - 1.0).abs() < 1e-4);
/// assert!((hess[[1, 0]] - 1.0).abs() < 1e-4);
/// assert!((hess[[1, 1]] - 4.0).abs() < 1e-4);
/// ```
#[allow(dead_code)]
pub fn hessian<F, G>(f: &G, x: &Array1<F>, epsilon: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand,
    G: Fn(&Array1<F>) -> F,
{
    let n = x.len();
    let mut hessian = Array2::zeros((n, n));

    let two = F::from(2.0).unwrap();
    let epsilon_squared = epsilon * epsilon;

    // Use central difference method for better accuracy
    let f_x = f(x);

    // Use symmetry: compute only the lower triangle, then mirror it
    for i in 0..n {
        for j in 0..=i {
            if i == j {
                // Diagonal elements: use central difference formula for second derivative
                let mut x_plus = x.clone();
                let mut x_minus = x.clone();

                x_plus[i] = x[i] + epsilon;
                x_minus[i] = x[i] - epsilon;

                let f_plus = f(&x_plus);
                let f_minus = f(&x_minus);

                // Central difference formula for second derivative:
                // f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
                let h_ii = (f_plus - two * f_x + f_minus) / epsilon_squared;
                hessian[[i, i]] = h_ii;
            } else {
                // Off-diagonal elements (mixed partial derivatives): use central difference
                let mut x_plus_plus = x.clone();
                let mut x_plus_minus = x.clone();
                let mut x_minus_plus = x.clone();
                let mut x_minus_minus = x.clone();

                // (i+,j+): Both variables increased by epsilon
                x_plus_plus[i] = x[i] + epsilon;
                x_plus_plus[j] = x[j] + epsilon;

                // (i+,j-): First variable increased, second decreased
                x_plus_minus[i] = x[i] + epsilon;
                x_plus_minus[j] = x[j] - epsilon;

                // (i-,j+): First variable decreased, second increased
                x_minus_plus[i] = x[i] - epsilon;
                x_minus_plus[j] = x[j] + epsilon;

                // (i-,j-): Both variables decreased by epsilon
                x_minus_minus[i] = x[i] - epsilon;
                x_minus_minus[j] = x[j] - epsilon;

                // Evaluate function at all these points
                let f_plus_plus = f(&x_plus_plus);
                let f_plus_minus = f(&x_plus_minus);
                let f_minus_plus = f(&x_minus_plus);
                let f_minus_minus = f(&x_minus_minus);

                // Mixed partial derivative using central difference:
                // ∂²f/∂x∂y ≈ (f(x+h,y+h) - f(x+h,y-h) - f(x-h,y+h) + f(x-h,y-h)) / (4h²)
                let four = F::from(4.0).unwrap();
                let h_ij = (f_plus_plus - f_plus_minus - f_minus_plus + f_minus_minus)
                    / (four * epsilon_squared);

                hessian[[i, j]] = h_ij;
                hessian[[j, i]] = h_ij; // Hessian is symmetric
            }
        }
    }

    Ok(hessian)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mse_gradient() {
        // Test with a simple case
        let predictions = array![3.0, 1.0, 2.0];
        let targets = array![2.5, 0.5, 2.0];

        let gradient = mse_gradient(&predictions.view(), &targets.view()).unwrap();

        // gradients = 2 * (predictions - targets) / n
        //           = 2 * ([3.0, 1.0, 2.0] - [2.5, 0.5, 2.0]) / 3
        //           = 2 * [0.5, 0.5, 0.0] / 3
        //           = [0.333..., 0.333..., 0.0]
        assert_relative_eq!(gradient[0], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(gradient[1], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(gradient[2], 0.0, epsilon = 1e-10);
    }
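
    // A minimal sketch of how mse_gradient plugs into training: one gradient-descent
    // step should move predictions toward the targets. The learning rate of 0.1 and
    // the update rule written inline here are illustrative assumptions, not part of
    // this module's API.
    #[test]
    fn test_mse_gradient_descent_step() {
        let predictions = array![3.0, 1.0, 2.0];
        let targets = array![2.5, 0.5, 2.0];

        let gradient = mse_gradient(&predictions.view(), &targets.view()).unwrap();

        // Hypothetical update rule: predictions - lr * gradient
        let lr = 0.1;
        let updated = &predictions - &(gradient * lr);

        // The per-component squared error must not increase after the step
        for i in 0..predictions.len() {
            let before = (predictions[i] - targets[i]).powi(2);
            let after = (updated[i] - targets[i]).powi(2);
            assert!(after <= before);
        }
    }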

    #[test]
    fn test_binary_crossentropy_gradient() {
        // Test with a simple case
        let predictions = array![0.7, 0.3, 0.9];
        let targets = array![1.0, 0.0, 1.0];

        let gradient = binary_crossentropy_gradient(&predictions.view(), &targets.view()).unwrap();

        // gradients = -targets/predictions + (1-targets)/(1-predictions)
        //           = -[1.0, 0.0, 1.0]/[0.7, 0.3, 0.9] + [0.0, 1.0, 0.0]/[0.3, 0.7, 0.1]
        //           = [-1.428..., 0.0, -1.111...] + [0.0, 1.428..., 0.0]
        //           = [-1.428..., 1.428..., -1.111...]
        assert_relative_eq!(gradient[0], -1.428571, epsilon = 1e-6);
        assert_relative_eq!(gradient[1], 1.428571, epsilon = 1e-6);
        assert_relative_eq!(gradient[2], -1.111111, epsilon = 1e-6);
    }
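
    // Cross-check binary_crossentropy_gradient against a central finite difference of
    // the elementwise loss -t*ln(p) - (1-t)*ln(1-p). A sketch of a numerical sanity
    // check; the step size h = 1e-6 and the 1e-5 tolerance are assumptions chosen
    // for f64 accuracy.
    #[test]
    fn test_binary_crossentropy_gradient_matches_finite_difference() {
        let predictions = array![0.7, 0.3, 0.9];
        let targets = array![1.0, 0.0, 1.0];

        let gradient = binary_crossentropy_gradient(&predictions.view(), &targets.view()).unwrap();

        // Elementwise binary cross-entropy loss
        let loss = |p: f64, t: f64| -> f64 { -t * p.ln() - (1.0 - t) * (1.0 - p).ln() };

        let h = 1e-6;
        for i in 0..predictions.len() {
            let numeric =
                (loss(predictions[i] + h, targets[i]) - loss(predictions[i] - h, targets[i]))
                    / (2.0 * h);
            assert_relative_eq!(gradient[i], numeric, epsilon = 1e-5);
        }
    }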

    #[test]
    fn test_softmax_crossentropy_gradient() {
        // Test with a simple case
        let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let gradient =
            softmax_crossentropy_gradient(&softmax_output.view(), &targets.view()).unwrap();

        // For each example, gradient = (softmax_output - targets) / batchsize
        // = ([0.7, 0.2, 0.1] - [1.0, 0.0, 0.0]) / 2
        // = [-0.15, 0.1, 0.05]
        // For the second example:
        // = ([0.3, 0.6, 0.1] - [0.0, 1.0, 0.0]) / 2
        // = [0.15, -0.2, 0.05]
        assert_relative_eq!(gradient[[0, 0]], -0.15, epsilon = 1e-10);
        assert_relative_eq!(gradient[[0, 1]], 0.1, epsilon = 1e-10);
        assert_relative_eq!(gradient[[0, 2]], 0.05, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 0]], 0.15, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 1]], -0.2, epsilon = 1e-10);
        assert_relative_eq!(gradient[[1, 2]], 0.05, epsilon = 1e-10);
    }
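
    // A quick property sketch rather than a documented API guarantee: because each
    // softmax row and each one-hot row sums to 1, every row of the returned
    // (softmax_output - targets) / batchsize should sum to 0.
    #[test]
    fn test_softmax_crossentropy_gradient_rows_sum_to_zero() {
        let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let gradient =
            softmax_crossentropy_gradient(&softmax_output.view(), &targets.view()).unwrap();

        let (batch, classes) = gradient.dim();
        for i in 0..batch {
            let row_sum: f64 = (0..classes).map(|j| gradient[[i, j]]).sum();
            assert_relative_eq!(row_sum, 0.0, epsilon = 1e-10);
        }
    }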

    #[test]
    fn test_jacobian() {
        // Define a simple function R^2 -> R^3
        // f(x,y) = [x^2 + y, 2*x + y^2, x*y]
        let f = |v: &Array1<f64>| -> Array1<f64> {
            let x = v[0];
            let y = v[1];
            array![x * x + y, 2.0 * x + y * y, x * y]
        };

        let x = array![2.0, 3.0]; // Point at which to evaluate the Jacobian
        let epsilon = 1e-5;

        let jac = jacobian(&f, &x, epsilon).unwrap();

        // Analytical Jacobian at (2,3) is:
        // [2x,  1]   [4, 1]
        // [ 2, 2y] = [2, 6]
        // [ y,  x]   [3, 2]
        assert_relative_eq!(jac[[0, 0]], 4.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[0, 1]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[1, 0]], 2.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[1, 1]], 6.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[2, 0]], 3.0, epsilon = 1e-4);
        assert_relative_eq!(jac[[2, 1]], 2.0, epsilon = 1e-4);
    }
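
    // For a linear map f(v) = A·v the Jacobian is exactly A, so central differences
    // should recover A almost to machine precision. A small sketch; the 2x2 matrix
    // below is an arbitrary test assumption.
    #[test]
    fn test_jacobian_linear_map() {
        // A = [[1.0, -2.0], [3.0, 0.5]]
        let f = |v: &Array1<f64>| -> Array1<f64> {
            array![1.0 * v[0] - 2.0 * v[1], 3.0 * v[0] + 0.5 * v[1]]
        };

        let x = array![0.7, -1.3];
        let jac = jacobian(&f, &x, 1e-5).unwrap();

        assert_relative_eq!(jac[[0, 0]], 1.0, epsilon = 1e-8);
        assert_relative_eq!(jac[[0, 1]], -2.0, epsilon = 1e-8);
        assert_relative_eq!(jac[[1, 0]], 3.0, epsilon = 1e-8);
        assert_relative_eq!(jac[[1, 1]], 0.5, epsilon = 1e-8);
    }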

    #[test]
    fn test_hessian() {
        // A very simple quadratic function: f(x) = 2x²
        // has the constant second derivative f''(x) = 4
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            2.0 * x * x
        };

        let x = array![0.5];
        let epsilon = 1e-4;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // The Hessian (second derivative) of f(x) = 2x² is 4
        assert_relative_eq!(hess[[0, 0]], 4.0, epsilon = 1e-2);
    }

    #[test]
    fn test_hessian_multidimensional() {
        // Multivariable function: f(x,y,z) = x²y + y²z + z²x
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            let y = v[1];
            let z = v[2];
            x * x * y + y * y * z + z * z * x
        };

        let x = array![1.0, 1.0, 1.0];
        let epsilon = 1e-4;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // Analytical Hessian at (1,1,1) for f(x,y,z) = x²y + y²z + z²x:
        // First-order derivatives:
        // ∂f/∂x = 2xy + z²
        // ∂f/∂y = x² + 2yz
        // ∂f/∂z = y² + 2zx
        //
        // Second-order derivatives at the point [1,1,1]:
        // ∂²f/∂x² = 2y = 2
        // ∂²f/∂y² = 2z = 2
        // ∂²f/∂z² = 2x = 2
        // ∂²f/∂x∂y = ∂²f/∂y∂x = 2x = 2
        // ∂²f/∂y∂z = ∂²f/∂z∂y = 2y = 2
        // ∂²f/∂z∂x = ∂²f/∂x∂z = 2z = 2

        // Diagonal elements
        assert_relative_eq!(hess[[0, 0]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 2]], 2.0, epsilon = 1e-2);

        // Off-diagonal elements
        assert_relative_eq!(hess[[0, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 0]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[1, 2]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 1]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[0, 2]], 2.0, epsilon = 1e-2);
        assert_relative_eq!(hess[[2, 0]], 2.0, epsilon = 1e-2);
    }

    #[test]
    fn test_hessian_quadratic_form() {
        // Quadratic form: f(x,y) = x² + xy + 2y²
        let f = |v: &Array1<f64>| -> f64 {
            let x = v[0];
            let y = v[1];
            x * x + x * y + 2.0 * y * y
        };

        let x = array![1.0, 2.0];
        let epsilon = 1e-5;

        let hess = hessian(&f, &x, epsilon).unwrap();

        // Analytical Hessian is:
        // [∂²f/∂x²,  ∂²f/∂x∂y]   [2, 1]
        // [∂²f/∂y∂x, ∂²f/∂y² ] = [1, 4]
        assert_relative_eq!(hess[[0, 0]], 2.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[0, 1]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 0]], 1.0, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 1]], 4.0, epsilon = 1e-4);
    }
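
    // The polynomial tests above have little or no truncation error under central
    // differences, so as an extra sketch we probe a non-polynomial surface:
    // f(x,y) = sin(x)·e^y, whose analytical Hessian is
    // [[-sin(x)e^y, cos(x)e^y], [cos(x)e^y, sin(x)e^y]].
    // The evaluation point and tolerance are illustrative assumptions.
    #[test]
    fn test_hessian_non_polynomial() {
        let f = |v: &Array1<f64>| -> f64 { v[0].sin() * v[1].exp() };

        let x = array![0.3, 0.1];
        let hess = hessian(&f, &x, 1e-4).unwrap();

        let (sin_x, cos_x, exp_y) = (0.3_f64.sin(), 0.3_f64.cos(), 0.1_f64.exp());
        assert_relative_eq!(hess[[0, 0]], -sin_x * exp_y, epsilon = 1e-4);
        assert_relative_eq!(hess[[0, 1]], cos_x * exp_y, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 0]], cos_x * exp_y, epsilon = 1e-4);
        assert_relative_eq!(hess[[1, 1]], sin_x * exp_y, epsilon = 1e-4);
    }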

    #[test]
    fn test_mse_gradient_dimension_error() {
        let predictions = array![1.0, 2.0, 3.0];
        let targets = array![1.0, 2.0];

        let result = mse_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_crossentropy_gradient_invalid_predictions() {
        let predictions = array![0.5, 1.2, 0.3]; // Contains value > 1
        let targets = array![1.0, 0.0, 1.0];

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());

        let predictions = array![0.5, -0.1, 0.3]; // Contains value < 0
        let targets = array![1.0, 0.0, 1.0];

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_crossentropy_gradient_invalid_targets() {
        let predictions = array![0.5, 0.7, 0.3];
        let targets = array![1.0, 0.5, 1.0]; // Contains value neither 0 nor 1

        let result = binary_crossentropy_gradient(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_softmax_crossentropy_gradient_invalid_softmax() {
        let softmax_output = array![[0.7, 0.2, 0.2], [0.3, 0.6, 0.1]]; // First row sums to 1.1
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let result = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_softmax_crossentropy_gradient_invalid_targets() {
        let softmax_output = array![[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.3, 0.3, 0.4]]; // Second row sums to 1 but is not one-hot

        let result = softmax_crossentropy_gradient(&softmax_output.view(), &targets.view());
        assert!(result.is_err());
    }
}