pub struct Swish { /* private fields */ }
Expand description
Swish activation function.
Swish is defined as: f(x) = x * sigmoid(β * x)
where β is a trainable parameter. When β = 1, it’s called Swish-1 and is commonly used. This activation was introduced in “Searching for Activation Functions” by Ramachandran et al.
§Examples
use scirs2_neural::activations::Swish;
use scirs2_neural::activations::Activation;
use ndarray::Array;
// β = 1.0 selects the standard Swish-1 variant.
let swish = Swish::new(1.0);
// Build the 1-D sample first, then convert it to a dynamic-dimensional array,
// which is the form passed to `forward` below.
let values = vec![1.0, -1.0, 2.0, -2.0];
let input = Array::from_vec(values).into_dyn();
// `unwrap` panics if the forward pass reports an error.
let output = swish.forward(&input).unwrap();
Implementations§
Source§impl Swish
impl Swish
Sourcepub fn new(beta: f64) -> Self
pub fn new(beta: f64) -> Self
Create a new Swish activation function with given beta parameter.
§Arguments
`beta` - Parameter controlling the shape of the Swish function. β = 1.0 gives the standard Swish-1 function.
Examples found in repository?
examples/activations_example.rs (line 19)
4fn main() -> Result<(), Box<dyn std::error::Error>> {
5 println!("Activation Functions Demonstration");
6
7 // Create a set of input values
8 let x_values: Vec<f64> = (-50..=50).map(|i| i as f64 / 10.0).collect();
9 let x = Array1::from(x_values.clone());
10 let x_dyn = x.clone().into_dyn();
11
12 // Initialize all activation functions
13 let relu = ReLU::new();
14 let leaky_relu = ReLU::leaky(0.1);
15 let sigmoid = Sigmoid::new();
16 let tanh = Tanh::new();
17 let gelu = GELU::new();
18 let gelu_fast = GELU::fast();
19 let swish = Swish::new(1.0);
20 let mish = Mish::new();
21
22 // Compute outputs for each activation function
23 let relu_output = relu.forward(&x_dyn)?;
24 let leaky_relu_output = leaky_relu.forward(&x_dyn)?;
25 let sigmoid_output = sigmoid.forward(&x_dyn)?;
26 let tanh_output = tanh.forward(&x_dyn)?;
27 let gelu_output = gelu.forward(&x_dyn)?;
28 let gelu_fast_output = gelu_fast.forward(&x_dyn)?;
29 let swish_output = swish.forward(&x_dyn)?;
30 let mish_output = mish.forward(&x_dyn)?;
31
32 // Print sample values for each activation
33 println!("Sample activation values for input x = -2.0, -1.0, 0.0, 1.0, 2.0:");
34
35 let indices = [5, 40, 50, 60, 95]; // Corresponding to x = -2, -1, 0, 1, 2
36
37 println!(
38 "| {:<10} | {:<10} | {:<10} | {:<10} | {:<10} | {:<10} |",
39 "x", "-2.0", "-1.0", "0.0", "1.0", "2.0"
40 );
41 println!(
42 "|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|{:-<12}|",
43 "", "", "", "", "", ""
44 );
45
46 println!(
47 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
48 "ReLU",
49 relu_output[[indices[0]]],
50 relu_output[[indices[1]]],
51 relu_output[[indices[2]]],
52 relu_output[[indices[3]]],
53 relu_output[[indices[4]]]
54 );
55
56 println!(
57 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
58 "LeakyReLU",
59 leaky_relu_output[[indices[0]]],
60 leaky_relu_output[[indices[1]]],
61 leaky_relu_output[[indices[2]]],
62 leaky_relu_output[[indices[3]]],
63 leaky_relu_output[[indices[4]]]
64 );
65
66 println!(
67 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
68 "Sigmoid",
69 sigmoid_output[[indices[0]]],
70 sigmoid_output[[indices[1]]],
71 sigmoid_output[[indices[2]]],
72 sigmoid_output[[indices[3]]],
73 sigmoid_output[[indices[4]]]
74 );
75
76 println!(
77 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
78 "Tanh",
79 tanh_output[[indices[0]]],
80 tanh_output[[indices[1]]],
81 tanh_output[[indices[2]]],
82 tanh_output[[indices[3]]],
83 tanh_output[[indices[4]]]
84 );
85
86 println!(
87 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
88 "GELU",
89 gelu_output[[indices[0]]],
90 gelu_output[[indices[1]]],
91 gelu_output[[indices[2]]],
92 gelu_output[[indices[3]]],
93 gelu_output[[indices[4]]]
94 );
95
96 println!(
97 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
98 "GELU Fast",
99 gelu_fast_output[[indices[0]]],
100 gelu_fast_output[[indices[1]]],
101 gelu_fast_output[[indices[2]]],
102 gelu_fast_output[[indices[3]]],
103 gelu_fast_output[[indices[4]]]
104 );
105
106 println!(
107 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
108 "Swish",
109 swish_output[[indices[0]]],
110 swish_output[[indices[1]]],
111 swish_output[[indices[2]]],
112 swish_output[[indices[3]]],
113 swish_output[[indices[4]]]
114 );
115
116 println!(
117 "| {:<10} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} | {:<10.6} |",
118 "Mish",
119 mish_output[[indices[0]]],
120 mish_output[[indices[1]]],
121 mish_output[[indices[2]]],
122 mish_output[[indices[3]]],
123 mish_output[[indices[4]]]
124 );
125
126 // Now test the backward pass with some dummy gradient output
127 println!("\nTesting backward pass...");
128
129 // Create a dummy gradient output
130 let dummy_grad = Array1::<f64>::ones(x.len()).into_dyn();
131
132 // Compute gradients for each activation function
133 let _relu_grad = relu.backward(&dummy_grad, &relu_output)?;
134 let _leaky_relu_grad = leaky_relu.backward(&dummy_grad, &leaky_relu_output)?;
135 let _sigmoid_grad = sigmoid.backward(&dummy_grad, &sigmoid_output)?;
136 let _tanh_grad = tanh.backward(&dummy_grad, &tanh_output)?;
137 let _gelu_grad = gelu.backward(&dummy_grad, &gelu_output)?;
138 let _gelu_fast_grad = gelu_fast.backward(&dummy_grad, &gelu_fast_output)?;
139 let _swish_grad = swish.backward(&dummy_grad, &swish_output)?;
140 let _mish_grad = mish.backward(&dummy_grad, &mish_output)?;
141
142 println!("Backward pass completed successfully.");
143
144 // Test with matrix input instead of vector
145 println!("\nTesting with matrix input...");
146
147 // Create a 3x4 matrix
148 let mut matrix = Array2::<f64>::zeros((3, 4));
149 for i in 0..3 {
150 for j in 0..4 {
151 matrix[[i, j]] = -2.0 + (i as f64 * 4.0 + j as f64) * 0.5;
152 }
153 }
154
155 // Print input matrix
156 println!("Input matrix:");
157 for i in 0..3 {
158 print!("[ ");
159 for j in 0..4 {
160 print!("{:6.2} ", matrix[[i, j]]);
161 }
162 println!("]");
163 }
164
165 // Apply GELU activation to the matrix
166 let gelu_matrix_output = gelu.forward(&matrix.into_dyn())?;
167
168 // Print output matrix
169 println!("\nAfter GELU activation:");
170 for i in 0..3 {
171 print!("[ ");
172 for j in 0..4 {
173 print!("{:6.2} ", gelu_matrix_output[[i, j]]);
174 }
175 println!("]");
176 }
177
178 println!("\nActivation functions demonstration completed successfully!");
179
180 // Note about visualization
181 println!("\nFor visualization of activation functions:");
182 println!("1. You can use external plotting libraries like plotly or matplotlib");
183 println!("2. To visualize these functions, you would plot the x_values against");
184 println!(" the output values for each activation function");
185 println!("3. The data from this example can be exported for plotting as needed");
186
187 // Example of how to access the data for plotting
188 println!("\nExample data points for plotting ReLU:");
189 for i in 0..5 {
190 let idx = i * 20; // Sample every 20th point
191 if idx < x_values.len() {
192 println!(
193 "x: {:.2}, y: {:.6}",
194 x_values[idx],
195 convert_to_vec(&relu_output)[idx]
196 );
197 }
198 }
199
200 Ok(())
201}
Trait Implementations§
Source§impl<F: Float + Debug> Activation<F> for Swish
impl<F: Float + Debug> Activation<F> for Swish
impl Copy for Swish
Auto Trait Implementations§
impl Freeze for Swish
impl RefUnwindSafe for Swish
impl Send for Swish
impl Sync for Swish
impl Unpin for Swish
impl UnwindSafe for Swish
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more