Struct Swish

Source
pub struct Swish { /* private fields */ }
Expand description

Swish activation function.

Swish is defined as: f(x) = x * sigmoid(β * x)

where β is a trainable parameter. When β = 1, it’s called Swish-1 and is commonly used. This activation was introduced in “Searching for Activation Functions” by Ramachandran et al.

§Examples

use scirs2_neural::activations::Swish;
use scirs2_neural::activations::Activation;
use ndarray::Array;

// β = 1.0 selects the standard Swish-1 variant.
let swish = Swish::new(1.0);
// Build a 1-D input and convert it to a dynamic-dimension array, which is
// the array type `forward` accepts.
let input = Array::from_vec(vec![1.0, -1.0, 2.0, -2.0]).into_dyn();
// Apply the activation; `unwrap` panics if `forward` returns an error.
let output = swish.forward(&input).unwrap();

Implementations§

Source§

impl Swish

Source

pub fn new(beta: f64) -> Self

Create a new Swish activation function with given beta parameter.

§Arguments
  • beta - Parameter controlling the shape of the Swish function. β = 1.0 gives the standard Swish-1 function.
Examples found in repository?
examples/activations_example.rs (line 19)
4fn main() -> Result<(), Box<dyn std::error::Error>> {
5    println!("Activation Functions Demonstration");
6
7    // Create a set of input values
8    let x_values: Vec<f64> = (-50..=50).map(|i| i as f64 / 10.0).collect();
9    let x = Array1::from(x_values.clone());
10    let x_dyn = x.clone().into_dyn();
11
12    // Initialize all activation functions
13    let relu = ReLU::new();
14    let leaky_relu = ReLU::leaky(0.1);
15    let sigmoid = Sigmoid::new();
16    let tanh = Tanh::new();
17    let gelu = GELU::new();
18    let gelu_fast = GELU::fast();
19    let swish = Swish::new(1.0);
20    let mish = Mish::new();
21
22    // Compute outputs for each activation function
23    let relu_output = relu.forward(&x_dyn)?;
24    let leaky_relu_output = leaky_relu.forward(&x_dyn)?;
25    let sigmoid_output = sigmoid.forward(&x_dyn)?;
26    let tanh_output = tanh.forward(&x_dyn)?;
27    let gelu_output = gelu.forward(&x_dyn)?;
28    let gelu_fast_output = gelu_fast.forward(&x_dyn)?;
29    let swish_output = swish.forward(&x_dyn)?;
30    let mish_output = mish.forward(&x_dyn)?;
31
32    // Print sample values for each activation
33    println!("Sample activation values for input x = -2.0, -1.0, 0.0, 1.0, 2.0:");
34
35    let indices = [5, 40, 50, 60, 95]; // Corresponding to x = -2, -1, 0, 1, 2
36
37    println!(
38        "| {:<10} | {:<10} | {:<10} | {:<10} | {:<10} | {:<10} |",
39        "x", "-2.0", "-1.0", "0.0", "1.0", "2.0"
40    );
41    println!(
42        "|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|",
43        "", "", "", "", "", ""
44    );
45
46    println!(
47        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
48        "ReLU",
49        relu_output[[indices[0]]],
50        relu_output[[indices[1]]],
51        relu_output[[indices[2]]],
52        relu_output[[indices[3]]],
53        relu_output[[indices[4]]]
54    );
55
56    println!(
57        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
58        "LeakyReLU",
59        leaky_relu_output[[indices[0]]],
60        leaky_relu_output[[indices[1]]],
61        leaky_relu_output[[indices[2]]],
62        leaky_relu_output[[indices[3]]],
63        leaky_relu_output[[indices[4]]]
64    );
65
66    println!(
67        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
68        "Sigmoid",
69        sigmoid_output[[indices[0]]],
70        sigmoid_output[[indices[1]]],
71        sigmoid_output[[indices[2]]],
72        sigmoid_output[[indices[3]]],
73        sigmoid_output[[indices[4]]]
74    );
75
76    println!(
77        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
78        "Tanh",
79        tanh_output[[indices[0]]],
80        tanh_output[[indices[1]]],
81        tanh_output[[indices[2]]],
82        tanh_output[[indices[3]]],
83        tanh_output[[indices[4]]]
84    );
85
86    println!(
87        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
88        "GELU",
89        gelu_output[[indices[0]]],
90        gelu_output[[indices[1]]],
91        gelu_output[[indices[2]]],
92        gelu_output[[indices[3]]],
93        gelu_output[[indices[4]]]
94    );
95
96    println!(
97        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
98        "GELU Fast",
99        gelu_fast_output[[indices[0]]],
100        gelu_fast_output[[indices[1]]],
101        gelu_fast_output[[indices[2]]],
102        gelu_fast_output[[indices[3]]],
103        gelu_fast_output[[indices[4]]]
104    );
105
106    println!(
107        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
108        "Swish",
109        swish_output[[indices[0]]],
110        swish_output[[indices[1]]],
111        swish_output[[indices[2]]],
112        swish_output[[indices[3]]],
113        swish_output[[indices[4]]]
114    );
115
116    println!(
117        "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
118        "Mish",
119        mish_output[[indices[0]]],
120        mish_output[[indices[1]]],
121        mish_output[[indices[2]]],
122        mish_output[[indices[3]]],
123        mish_output[[indices[4]]]
124    );
125
126    // Now test the backward pass with some dummy gradient output
127    println!("\nTesting backward pass...");
128
129    // Create a dummy gradient output
130    let dummy_grad = Array1::<f64>::ones(x.len()).into_dyn();
131
132    // Compute gradients for each activation function
133    let _relu_grad = relu.backward(&dummy_grad, &relu_output)?;
134    let _leaky_relu_grad = leaky_relu.backward(&dummy_grad, &leaky_relu_output)?;
135    let _sigmoid_grad = sigmoid.backward(&dummy_grad, &sigmoid_output)?;
136    let _tanh_grad = tanh.backward(&dummy_grad, &tanh_output)?;
137    let _gelu_grad = gelu.backward(&dummy_grad, &gelu_output)?;
138    let _gelu_fast_grad = gelu_fast.backward(&dummy_grad, &gelu_fast_output)?;
139    let _swish_grad = swish.backward(&dummy_grad, &swish_output)?;
140    let _mish_grad = mish.backward(&dummy_grad, &mish_output)?;
141
142    println!("Backward pass completed successfully.");
143
144    // Test with matrix input instead of vector
145    println!("\nTesting with matrix input...");
146
147    // Create a 3x4 matrix
148    let mut matrix = Array2::<f64>::zeros((3, 4));
149    for i in 0..3 {
150        for j in 0..4 {
151            matrix[[i, j]] = -2.0 + (i as f64 * 4.0 + j as f64) * 0.5;
152        }
153    }
154
155    // Print input matrix
156    println!("Input matrix:");
157    for i in 0..3 {
158        print!("[ ");
159        for j in 0..4 {
160            print!("{:6.2} ", matrix[[i, j]]);
161        }
162        println!("]");
163    }
164
165    // Apply GELU activation to the matrix
166    let gelu_matrix_output = gelu.forward(&matrix.into_dyn())?;
167
168    // Print output matrix
169    println!("\nAfter GELU activation:");
170    for i in 0..3 {
171        print!("[ ");
172        for j in 0..4 {
173            print!("{:6.2} ", gelu_matrix_output[[i, j]]);
174        }
175        println!("]");
176    }
177
178    println!("\nActivation functions demonstration completed successfully!");
179
180    // Note about visualization
181    println!("\nFor visualization of activation functions:");
182    println!("1. You can use external plotting libraries like plotly or matplotlib");
183    println!("2. To visualize these functions, you would plot the x_values against");
184    println!("   the output values for each activation function");
185    println!("3. The data from this example can be exported for plotting as needed");
186
187    // Example of how to access the data for plotting
188    println!("\nExample data points for plotting ReLU:");
189    for i in 0..5 {
190        let idx = i * 20; // Sample every 20th point
191        if idx < x_values.len() {
192            println!(
193                "x: {:.2}, y: {:.6}",
194                x_values[idx],
195                convert_to_vec(&relu_output)[idx]
196            );
197        }
198    }
199
200    Ok(())
201}

Trait Implementations§

Source§

impl<F: Float + Debug> Activation<F> for Swish

Source§

fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Apply the activation function to the input. Read more
Source§

fn backward( &self, grad_output: &Array<F, IxDyn>, output: &Array<F, IxDyn>, ) -> Result<Array<F, IxDyn>>

Compute the derivative of the activation function with respect to the input. Read more
Source§

impl Clone for Swish

Source§

fn clone(&self) -> Swish

Returns a duplicate of the value. Read more
1.0.0 · Source§

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Swish

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Default for Swish

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl Copy for Swish

Auto Trait Implementations§

§

impl Freeze for Swish

§

impl RefUnwindSafe for Swish

§

impl Send for Swish

§

impl Sync for Swish

§

impl Unpin for Swish

§

impl UnwindSafe for Swish

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V