// ruv_fann/neuron.rs

1use crate::{ActivationFunction, Connection};
2use num_traits::Float;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
/// Represents a single neuron in the neural network
///
/// Holds both configuration (activation function and steepness, incoming
/// connections, bias flag) and transient evaluation state (`sum`, `value`)
/// that is recomputed on every forward pass.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Neuron<T: Float> {
    /// The sum of inputs multiplied by weights (pre-activation)
    pub sum: T,

    /// The output value after applying the activation function
    /// (bias neurons keep this pinned at 1.0).
    /// NOTE(review): `calculate` currently copies the raw `sum` here;
    /// activation is applied elsewhere (activation module).
    pub value: T,

    /// The steepness parameter for the activation function
    pub activation_steepness: T,

    /// The activation function to use
    pub activation_function: ActivationFunction,

    /// Incoming connections to this neuron
    pub connections: Vec<Connection<T>>,

    /// Whether this is a bias neuron (constant output of 1.0, ignores inputs)
    pub is_bias: bool,
}
28
29impl<T: Float> Neuron<T> {
30    /// Creates a new neuron with the specified activation function and steepness
31    ///
32    /// # Arguments
33    /// * `activation_function` - The activation function to use
34    /// * `activation_steepness` - The steepness parameter for the activation function
35    ///
36    /// # Example
37    /// ```
38    /// use ruv_fann::{Neuron, ActivationFunction};
39    ///
40    /// let neuron = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
41    /// assert_eq!(neuron.activation_function, ActivationFunction::Sigmoid);
42    /// ```
43    pub fn new(activation_function: ActivationFunction, activation_steepness: T) -> Self {
44        Neuron {
45            sum: T::zero(),
46            value: T::zero(),
47            activation_steepness,
48            activation_function,
49            connections: Vec::new(),
50            is_bias: false,
51        }
52    }
53
54    /// Creates a new bias neuron with a constant output value of 1.0
55    pub fn new_bias() -> Self {
56        let one = T::one();
57        Neuron {
58            sum: one,
59            value: one,
60            activation_steepness: one,
61            activation_function: ActivationFunction::Linear,
62            connections: Vec::new(),
63            is_bias: true,
64        }
65    }
66
67    /// Adds a connection from another neuron to this neuron
68    ///
69    /// # Arguments
70    /// * `from_neuron` - Index of the source neuron
71    /// * `weight` - Initial weight of the connection
72    pub fn add_connection(&mut self, from_neuron: usize, weight: T) {
73        let neuron_index = self.connections.len();
74        self.connections
75            .push(Connection::new(from_neuron, neuron_index, weight));
76    }
77
78    /// Clears all connections
79    pub fn clear_connections(&mut self) {
80        self.connections.clear();
81    }
82
83    /// Resets the neuron's sum and value to zero
84    /// (except for bias neurons which maintain value = 1.0)
85    pub fn reset(&mut self) {
86        if self.is_bias {
87            self.sum = T::one();
88            self.value = T::one();
89        } else {
90            self.sum = T::zero();
91            self.value = T::zero();
92        }
93    }
94
95    /// Calculates the neuron's output based on inputs and weights
96    ///
97    /// # Arguments
98    /// * `inputs` - Values from neurons in the previous layer
99    pub fn calculate(&mut self, inputs: &[T]) {
100        if self.is_bias {
101            // Bias neurons always output 1.0
102            return;
103        }
104
105        // Calculate weighted sum
106        self.sum = T::zero();
107        for connection in &self.connections {
108            if connection.from_neuron < inputs.len() {
109                self.sum = self.sum + inputs[connection.from_neuron] * connection.weight;
110            }
111        }
112
113        // Apply activation function (will be implemented in activation module)
114        // For now, just store the sum as the value
115        self.value = self.sum;
116    }
117
118    /// Sets the neuron's output value directly (used for input neurons)
119    pub fn set_value(&mut self, value: T) {
120        if !self.is_bias {
121            self.value = value;
122            self.sum = value;
123        }
124    }
125
126    /// Gets the weight of a specific connection by index
127    pub fn get_connection_weight(&self, index: usize) -> Option<T> {
128        self.connections.get(index).map(|c| c.weight)
129    }
130
131    /// Sets the weight of a specific connection by index
132    pub fn set_connection_weight(&mut self, index: usize, weight: T) -> Result<(), &'static str> {
133        if let Some(connection) = self.connections.get_mut(index) {
134            connection.set_weight(weight);
135            Ok(())
136        } else {
137            Err("Connection index out of bounds")
138        }
139    }
140}
141
142impl<T: Float> PartialEq for Neuron<T> {
143    fn eq(&self, other: &Self) -> bool {
144        self.activation_function == other.activation_function
145            && self.activation_steepness == other.activation_steepness
146            && self.is_bias == other.is_bias
147            && self.connections == other.connections
148    }
149}
150
151#[cfg(test)]
152mod tests {
153    use super::*;
154
155    #[test]
156    fn test_neuron_creation() {
157        let neuron = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
158        assert_eq!(neuron.activation_function, ActivationFunction::Sigmoid);
159        assert_eq!(neuron.activation_steepness, 1.0);
160        assert_eq!(neuron.sum, 0.0);
161        assert_eq!(neuron.value, 0.0);
162        assert!(!neuron.is_bias);
163        assert!(neuron.connections.is_empty());
164    }
165
166    #[test]
167    fn test_bias_neuron() {
168        let bias = Neuron::<f32>::new_bias();
169        assert!(bias.is_bias);
170        assert_eq!(bias.value, 1.0);
171        assert_eq!(bias.sum, 1.0);
172    }
173
174    #[test]
175    fn test_add_connection() {
176        let mut neuron = Neuron::<f32>::new(ActivationFunction::ReLU, 1.0);
177        neuron.add_connection(0, 0.5);
178        neuron.add_connection(1, -0.3);
179
180        assert_eq!(neuron.connections.len(), 2);
181        assert_eq!(neuron.connections[0].from_neuron, 0);
182        assert_eq!(neuron.connections[0].weight, 0.5);
183        assert_eq!(neuron.connections[1].from_neuron, 1);
184        assert_eq!(neuron.connections[1].weight, -0.3);
185    }
186
187    #[test]
188    fn test_reset_neuron() {
189        let mut neuron = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
190        neuron.sum = 5.0;
191        neuron.value = 2.5;
192
193        neuron.reset();
194        assert_eq!(neuron.sum, 0.0);
195        assert_eq!(neuron.value, 0.0);
196    }
197
198    #[test]
199    fn test_reset_bias_neuron() {
200        let mut bias = Neuron::<f32>::new_bias();
201        bias.sum = 5.0;
202        bias.value = 2.5;
203
204        bias.reset();
205        assert_eq!(bias.sum, 1.0);
206        assert_eq!(bias.value, 1.0);
207    }
208
209    #[test]
210    fn test_set_value() {
211        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
212        // Use std::f32::consts::PI instead of hardcoded value
213        neuron.set_value(std::f32::consts::PI);
214        assert_eq!(neuron.value, std::f32::consts::PI);
215        assert_eq!(neuron.sum, std::f32::consts::PI);
216    }
217
218    #[test]
219    fn test_calculate() {
220        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
221        neuron.add_connection(0, 0.5);
222        neuron.add_connection(1, -0.3);
223        neuron.add_connection(2, 0.2);
224
225        let inputs = vec![1.0, 2.0, -1.0];
226        neuron.calculate(&inputs);
227
228        // 1.0 * 0.5 + 2.0 * -0.3 + -1.0 * 0.2 = 0.5 - 0.6 - 0.2 = -0.3
229        assert_eq!(neuron.sum, -0.3);
230    }
231}