pub struct Softmax { /* private fields */ }
Expand description
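Softmax activation function, applied along the axis passed to `new` (for example, axis 1 of a 2-D input normalizes each row), exposing `forward` and `backward` passes.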
Implementations§
Source§impl Softmax
impl Softmax
Sourcepub fn new(axis: isize) -> Self
pub fn new(axis: isize) -> Self
Examples found in repository?
examples/test_softmax.rs (line 10)
5fn main() {
6 println!("Testing softmax implementation...\n");
7 // Test case 1: Simple 1D array
8 let input = arr1(&[1.0, 2.0, 3.0]);
9 println!("Input: {input:?}");
10 let softmax = Softmax::new(0);
11 let output = softmax.forward(&input.clone().into_dyn()).unwrap();
12 println!("Softmax output: {output:?}");
13 // Verify that output sums to 1
14 let sum: f64 = output.sum();
15 println!("Sum of softmax: {sum}");
16 assert!((sum - 1.0).abs() < 1e-6, "Softmax should sum to 1");
17 // Test case 2: 2D array (batch processing)
18 println!("\nTest case 2: 2D batch");
19 let input_2d = arr2(&[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [2.0, 2.0, 2.0]]);
20 println!("Input 2D:\n{input_2d:?}");
21 // Apply softmax along axis 1 (row-wise)
22 let softmax_2d = Softmax::new(1);
23 let output_2d = softmax_2d.forward(&input_2d.clone().into_dyn()).unwrap();
24 println!("Softmax output 2D:\n{output_2d:?}");
25 // Verify each row sums to 1
26 for i in 0..output_2d.shape()[0] {
27 let row_sum: f64 = output_2d.slice(scirs2_core::ndarray::s![i, ..]).sum();
28 println!("Row {i} sum: {row_sum}");
29 assert!((row_sum - 1.0).abs() < 1e-6, "Each row should sum to 1");
30 }
31 // Test case 3: Gradient computation
32 println!("\nTest case 3: Gradient computation");
33 let grad_output = arr1(&[0.1, 0.2, 0.3]).into_dyn();
34 let forward_output = softmax.forward(&input.clone().into_dyn()).unwrap();
35 let grad_input = softmax.backward(&grad_output, &forward_output).unwrap();
36 println!("Gradient input: {grad_input:?}");
37 println!("\nAll tests passed!");
38}
Trait Implementations§
impl Copy for Softmax
Auto Trait Implementations§
impl Freeze for Softmax
impl RefUnwindSafe for Softmax
impl Send for Softmax
impl Sync for Softmax
impl Unpin for Softmax
impl UnwindSafe for Softmax
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more