activations_example/
activations_example.rs

1use scirs2_core::ndarray::{Array, Array1, Array2};
2use scirs2_neural::activations_minimal::{Activation, ReLU, Sigmoid, Tanh, GELU};
3
4#[allow(dead_code)]
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6    println!("Activation Functions Demonstration");
7    // Create a set of input values
8    let x_values: Vec<f64> = (-50..=50).map(|i| i as f64 / 10.0).collect();
9    let x = Array1::from(x_values.clone());
10    let x_dyn = x.clone().into_dyn();
11    // Initialize all activation functions
12    let relu = ReLU::new();
13    let leaky_relu = ReLU::leaky(0.1);
14    let sigmoid = Sigmoid::new();
15    let tanh = Tanh::new();
16    let gelu = GELU::new();
17    let gelu_fast = GELU::fast();
18    // Compute outputs for each activation function
19    let relu_output = relu.forward(&x_dyn)?;
20    let leaky_relu_output = leaky_relu.forward(&x_dyn)?;
21    let sigmoid_output = sigmoid.forward(&x_dyn)?;
22    let tanh_output = tanh.forward(&x_dyn)?;
23    let gelu_output = gelu.forward(&x_dyn)?;
24    let gelu_fast_output = gelu_fast.forward(&x_dyn)?;
25    // Note: Swish and Mish are not available in the minimal activation set
26    // Print sample values for each activation
27    println!("Sample activation values for input x = -2.0, -1.0, 0.0, 1.0, 2.0:");
28    let indices = [5, 40, 50, 60, 95]; // Corresponding to x = -2, -1, 0, 1, 2
29    println!(
30        "| {:<10} | {:<10} | {:<10} | {:<10} | {:<10} | {:<10} |",
31        "x", "-2.0", "-1.0", "0.0", "1.0", "2.0"
32    );
33    println!(
34        "|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|",
35        "", "", "", "", "", ""
36    );
37    println!(
38        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
39        "ReLU",
40        relu_output[[indices[0]]],
41        relu_output[[indices[1]]],
42        relu_output[[indices[2]]],
43        relu_output[[indices[3]]],
44        relu_output[[indices[4]]]
45    );
46    println!(
47        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
48        "LeakyReLU",
49        leaky_relu_output[[indices[0]]],
50        leaky_relu_output[[indices[1]]],
51        leaky_relu_output[[indices[2]]],
52        leaky_relu_output[[indices[3]]],
53        leaky_relu_output[[indices[4]]]
54    );
55    println!(
56        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
57        "Sigmoid",
58        sigmoid_output[[indices[0]]],
59        sigmoid_output[[indices[1]]],
60        sigmoid_output[[indices[2]]],
61        sigmoid_output[[indices[3]]],
62        sigmoid_output[[indices[4]]]
63    );
64    println!(
65        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
66        "Tanh",
67        tanh_output[[indices[0]]],
68        tanh_output[[indices[1]]],
69        tanh_output[[indices[2]]],
70        tanh_output[[indices[3]]],
71        tanh_output[[indices[4]]]
72    );
73    println!(
74        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
75        "GELU",
76        gelu_output[[indices[0]]],
77        gelu_output[[indices[1]]],
78        gelu_output[[indices[2]]],
79        gelu_output[[indices[3]]],
80        gelu_output[[indices[4]]]
81    );
82    println!(
83        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
84        "GELU Fast",
85        gelu_fast_output[[indices[0]]],
86        gelu_fast_output[[indices[1]]],
87        gelu_fast_output[[indices[2]]],
88        gelu_fast_output[[indices[3]]],
89        gelu_fast_output[[indices[4]]]
90    );
91    // Swish and Mish not available in minimal activation set
92    // Now test the backward pass with some dummy gradient output
93    println!("\nTesting backward pass...");
94    // Create a dummy gradient output
95    let dummy_grad = Array1::<f64>::ones(x.len()).into_dyn();
96    // Compute gradients for each activation function
97    let _relu_grad = relu.backward(&dummy_grad, &relu_output)?;
98    let _leaky_relu_grad = leaky_relu.backward(&dummy_grad, &leaky_relu_output)?;
99    let _sigmoid_grad = sigmoid.backward(&dummy_grad, &sigmoid_output)?;
100    let _tanh_grad = tanh.backward(&dummy_grad, &tanh_output)?;
101    let _gelu_grad = gelu.backward(&dummy_grad, &gelu_output)?;
102    let _gelu_fast_grad = gelu_fast.backward(&dummy_grad, &gelu_fast_output)?;
103    // Note: Swish and Mish gradients not available in minimal set
104    println!("Backward pass completed successfully.");
105    // Test with matrix input instead of vector
106    println!("\nTesting with matrix input...");
107    // Create a 3x4 matrix
108    let mut matrix = Array2::<f64>::zeros((3, 4));
109    for i in 0..3 {
110        for j in 0..4 {
111            matrix[[i, j]] = -2.0 + (i as f64 * 4.0 + j as f64) * 0.5;
112        }
113    }
114    // Print input matrix
115    println!("Input matrix:");
116    for i in 0..3 {
117        print!("[ ");
118        for j in 0..4 {
119            print!("{:6.2} ", matrix[[i, j]]);
120        }
121        println!("]");
122    }
123    // Apply GELU activation to the matrix
124    let gelu_matrix_output = gelu.forward(&matrix.into_dyn())?;
125    // Print output matrix
126    println!("\nAfter GELU activation:");
127    for i in 0..3 {
128        print!("[ ");
129        for j in 0..4 {
130            print!("{:6.2} ", gelu_matrix_output[[i, j]]);
131        }
132        println!("]");
133    }
134    println!("\nActivation functions demonstration completed successfully!");
135    // Note about visualization
136    println!("\nFor visualization of activation functions:");
137    println!("1. You can use external plotting libraries like plotly or matplotlib");
138    println!("2. To visualize these functions, you would plot the x_values against");
139    println!("   the output values for each activation function");
140    println!("3. The data from this example can be exported for plotting as needed");
141    // Example of how to access the data for plotting
142    println!("\nExample data points for plotting ReLU:");
143    for i in 0..5 {
144        let idx = i * 20; // Sample every 20th point
145        if idx < x_values.len() {
146            println!(
147                "x: {:.2}, y: {:.6}",
148                x_values[idx],
149                convert_to_vec(&relu_output)[idx]
150            );
151        }
152    }
153    Ok(())
154}
155#[allow(dead_code)]
156fn convert_to_vec<F: Clone>(array: &Array<F, scirs2_core::ndarray::IxDyn>) -> Vec<F> {
157    array.iter().cloned().collect()
158}