1use crate::Value;
2use rand::{distributions::Uniform, Rng};
3use std::fmt::{self, Debug};
4
/// A single neuron: a weighted sum of its inputs plus a bias, optionally
/// passed through ReLU (see `Neuron::forward`).
pub struct Neuron {
    /// One weight per input.
    w: Vec<Value>,
    /// Bias added to the weighted sum.
    b: Value,
    /// `true` applies ReLU to the output; `false` leaves the neuron linear.
    nonlin: bool,
}
10
11impl Debug for Neuron {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 let name = if self.nonlin { "ReLU" } else { "Linear" };
14 write!(f, "{}({})", name, self.w.len())
15 }
16}
17
18impl Neuron {
19 pub fn new(nin: i32, nonlin: bool) -> Neuron {
20 let mut rng = rand::thread_rng();
21 let range = Uniform::<f64>::new(-1.0, 1.0);
22
23 Neuron {
24 w: (0..nin).map(|_| Value::from(rng.sample(range))).collect(),
25 b: Value::from(0.0),
26 nonlin,
27 }
28 }
29
30 pub fn from(nin: i32) -> Neuron {
31 Neuron::new(nin, true)
32 }
33
34 pub fn forward(&self, x: &Vec<Value>) -> Value {
35 let wixi_sum: Value = self.w.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
36 let out = wixi_sum + &self.b;
37
38 if self.nonlin {
39 return out.relu();
40 }
41 out
42 }
43
44 pub fn parameters(&self) -> Vec<Value> {
45 let mut out = self.w.clone();
46 out.insert(0, self.b.clone());
47 out
48 }
49}