1use crate::{ActivationFunction, Connection};
2use num_traits::Float;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub struct Neuron<T: Float> {
10 pub sum: T,
12
13 pub value: T,
15
16 pub activation_steepness: T,
18
19 pub activation_function: ActivationFunction,
21
22 pub connections: Vec<Connection<T>>,
24
25 pub is_bias: bool,
27}
28
29impl<T: Float> Neuron<T> {
30 pub fn new(activation_function: ActivationFunction, activation_steepness: T) -> Self {
44 Neuron {
45 sum: T::zero(),
46 value: T::zero(),
47 activation_steepness,
48 activation_function,
49 connections: Vec::new(),
50 is_bias: false,
51 }
52 }
53
54 pub fn new_bias() -> Self {
56 let one = T::one();
57 Neuron {
58 sum: one,
59 value: one,
60 activation_steepness: one,
61 activation_function: ActivationFunction::Linear,
62 connections: Vec::new(),
63 is_bias: true,
64 }
65 }
66
67 pub fn add_connection(&mut self, from_neuron: usize, weight: T) {
73 let neuron_index = self.connections.len();
74 self.connections
75 .push(Connection::new(from_neuron, neuron_index, weight));
76 }
77
78 pub fn clear_connections(&mut self) {
80 self.connections.clear();
81 }
82
83 pub fn reset(&mut self) {
86 if self.is_bias {
87 self.sum = T::one();
88 self.value = T::one();
89 } else {
90 self.sum = T::zero();
91 self.value = T::zero();
92 }
93 }
94
95 pub fn calculate(&mut self, inputs: &[T]) {
100 if self.is_bias {
101 return;
103 }
104
105 self.sum = T::zero();
107 for connection in &self.connections {
108 if connection.from_neuron < inputs.len() {
109 self.sum = self.sum + inputs[connection.from_neuron] * connection.weight;
110 }
111 }
112
113 self.value = self.sum;
116 }
117
118 pub fn set_value(&mut self, value: T) {
120 if !self.is_bias {
121 self.value = value;
122 self.sum = value;
123 }
124 }
125
126 pub fn get_connection_weight(&self, index: usize) -> Option<T> {
128 self.connections.get(index).map(|c| c.weight)
129 }
130
131 pub fn set_connection_weight(&mut self, index: usize, weight: T) -> Result<(), &'static str> {
133 if let Some(connection) = self.connections.get_mut(index) {
134 connection.set_weight(weight);
135 Ok(())
136 } else {
137 Err("Connection index out of bounds")
138 }
139 }
140}
141
142impl<T: Float> PartialEq for Neuron<T> {
143 fn eq(&self, other: &Self) -> bool {
144 self.activation_function == other.activation_function
145 && self.activation_steepness == other.activation_steepness
146 && self.is_bias == other.is_bias
147 && self.connections == other.connections
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing values that went through floating-point arithmetic.
    const EPSILON: f32 = 1e-6;

    /// Returns `true` when `a` and `b` differ by at most [`EPSILON`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() <= EPSILON
    }

    #[test]
    fn test_neuron_creation() {
        let neuron = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
        assert_eq!(neuron.activation_function, ActivationFunction::Sigmoid);
        assert_eq!(neuron.activation_steepness, 1.0);
        assert_eq!(neuron.sum, 0.0);
        assert_eq!(neuron.value, 0.0);
        assert!(!neuron.is_bias);
        assert!(neuron.connections.is_empty());
    }

    #[test]
    fn test_bias_neuron() {
        let bias = Neuron::<f32>::new_bias();
        assert!(bias.is_bias);
        assert_eq!(bias.value, 1.0);
        assert_eq!(bias.sum, 1.0);
    }

    #[test]
    fn test_add_connection() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::ReLU, 1.0);
        neuron.add_connection(0, 0.5);
        neuron.add_connection(1, -0.3);

        assert_eq!(neuron.connections.len(), 2);
        assert_eq!(neuron.connections[0].from_neuron, 0);
        assert_eq!(neuron.connections[0].weight, 0.5);
        assert_eq!(neuron.connections[1].from_neuron, 1);
        assert_eq!(neuron.connections[1].weight, -0.3);
    }

    #[test]
    fn test_reset_neuron() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
        neuron.sum = 5.0;
        neuron.value = 2.5;

        neuron.reset();
        assert_eq!(neuron.sum, 0.0);
        assert_eq!(neuron.value, 0.0);
    }

    #[test]
    fn test_reset_bias_neuron() {
        let mut bias = Neuron::<f32>::new_bias();
        bias.sum = 5.0;
        bias.value = 2.5;

        bias.reset();
        assert_eq!(bias.sum, 1.0);
        assert_eq!(bias.value, 1.0);
    }

    #[test]
    fn test_set_value() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
        neuron.set_value(std::f32::consts::PI);

        assert_eq!(neuron.value, std::f32::consts::PI);
        assert_eq!(neuron.sum, std::f32::consts::PI);
    }

    #[test]
    fn test_set_value_ignored_for_bias() {
        let mut bias = Neuron::<f32>::new_bias();
        bias.set_value(7.0);

        assert_eq!(bias.sum, 1.0);
        assert_eq!(bias.value, 1.0);
    }

    #[test]
    fn test_calculate() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
        neuron.add_connection(0, 0.5);
        neuron.add_connection(1, -0.3);
        neuron.add_connection(2, 0.2);

        let inputs = [1.0, 2.0, -1.0];
        neuron.calculate(&inputs);

        // 0.5*1 - 0.3*2 + 0.2*(-1) = -0.3 up to rounding; compare with a
        // tolerance instead of relying on the exact f32 result.
        assert!(approx_eq(neuron.sum, -0.3), "sum was {}", neuron.sum);
        // `calculate` copies the sum into `value` unchanged.
        assert_eq!(neuron.value, neuron.sum);
    }

    #[test]
    fn test_calculate_skips_out_of_range_inputs() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
        neuron.add_connection(0, 2.0);
        neuron.add_connection(5, 10.0);

        neuron.calculate(&[3.0]);
        assert_eq!(neuron.sum, 6.0);
        assert_eq!(neuron.value, 6.0);
    }

    #[test]
    fn test_calculate_leaves_bias_untouched() {
        let mut bias = Neuron::<f32>::new_bias();
        bias.add_connection(0, 5.0);

        bias.calculate(&[2.0]);
        assert_eq!(bias.sum, 1.0);
        assert_eq!(bias.value, 1.0);
    }

    #[test]
    fn test_connection_weight_bounds() {
        let mut neuron = Neuron::<f32>::new(ActivationFunction::Linear, 1.0);
        assert_eq!(neuron.get_connection_weight(0), None);
        assert_eq!(
            neuron.set_connection_weight(0, 1.0),
            Err("Connection index out of bounds")
        );

        neuron.add_connection(3, 0.5);
        assert_eq!(neuron.get_connection_weight(0), Some(0.5));
        assert_eq!(neuron.get_connection_weight(1), None);
    }

    #[test]
    fn test_equality_ignores_runtime_state() {
        let mut a = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
        let mut b = Neuron::<f32>::new(ActivationFunction::Sigmoid, 1.0);
        a.add_connection(0, 0.5);
        b.add_connection(0, 0.5);

        a.sum = 3.0;
        a.value = 4.0;
        assert!(a == b);

        b.add_connection(1, 0.1);
        assert!(a != b);
    }
}