rustygrad/
neuron.rs

1use crate::Value;
2use rand::{distributions::Uniform, Rng};
3use std::fmt::{self, Debug};
4
/// A single neuron: a weighted sum of its inputs plus a bias, optionally
/// passed through `relu` (see `Neuron::forward`).
pub struct Neuron {
    /// Input weights, one per input; `Neuron::new` creates `nin` of them,
    /// each sampled from `Uniform::new(-1.0, 1.0)`.
    w: Vec<Value>,
    /// Bias term added to the weighted sum; `Neuron::new` initialises it to `0.0`.
    b: Value,
    /// When `true`, `forward` applies `relu` to its result; when `false` the
    /// neuron is purely linear.
    nonlin: bool,
}
10
11impl Debug for Neuron {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        let name = if self.nonlin { "ReLU" } else { "Linear" };
14        write!(f, "{}({})", name, self.w.len())
15    }
16}
17
18impl Neuron {
19    pub fn new(nin: i32, nonlin: bool) -> Neuron {
20        let mut rng = rand::thread_rng();
21        let range = Uniform::<f64>::new(-1.0, 1.0);
22
23        Neuron {
24            w: (0..nin).map(|_| Value::from(rng.sample(range))).collect(),
25            b: Value::from(0.0),
26            nonlin,
27        }
28    }
29
30    pub fn from(nin: i32) -> Neuron {
31        Neuron::new(nin, true)
32    }
33
34    pub fn forward(&self, x: &Vec<Value>) -> Value {
35        let wixi_sum: Value = self.w.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
36        let out = wixi_sum + &self.b;
37
38        if self.nonlin {
39            return out.relu();
40        }
41        out
42    }
43
44    pub fn parameters(&self) -> Vec<Value> {
45        let mut out = self.w.clone();
46        out.insert(0, self.b.clone());
47        out
48    }
49}