Skip to main content

runn/layer/
mod.rs

1use crate::{matrix::DMat, random::Randomizer, ActivationFunction, Optimizer, Regularization, SummaryWriter};
2
3pub mod dense_layer;
4
5#[typetag::serde]
6pub trait Layer: LayerClone + Send + Sync {
7    fn forward(&self, input: &DMat) -> (DMat, DMat);
8    fn backward(
9        &self, d_output: &DMat, input: &DMat, pre_activated_output: &DMat, activated_output: &DMat,
10    ) -> (DMat, DMat, DMat);
11    fn activation_function(&self) -> &dyn ActivationFunction;
12    fn regulate(&mut self, d_weights: &mut DMat, d_biases: &mut DMat, regularization: &dyn Regularization);
13    fn update(&mut self, d_weights: &DMat, d_biases: &DMat, epoch: usize);
14    fn summarize(&self, epoch: usize, summary_writer: &mut dyn SummaryWriter);
15    fn visualize(&self);
16    fn input_output_size(&self) -> (usize, usize);
17}
18
19pub trait LayerClone {
20    fn clone_box(&self) -> Box<dyn Layer>;
21}
22
23impl<T> LayerClone for T
24where
25    T: 'static + Layer + Clone,
26{
27    fn clone_box(&self) -> Box<dyn Layer> {
28        Box::new(self.clone())
29    }
30}
31
32impl Clone for Box<dyn Layer> {
33    fn clone(&self) -> Box<dyn Layer> {
34        self.clone_box()
35    }
36}
37
38pub trait LayerConfig {
39    fn size(&self) -> usize;
40    fn create_layer(
41        &mut self, name: String, input_size: usize, optimizer: Box<dyn Optimizer>, randomizer: &Randomizer,
42    ) -> Box<dyn Layer>;
43}
44
45impl LayerConfig for Box<dyn LayerConfig> {
46    fn size(&self) -> usize {
47        (**self).size() // Dereference the Box to call the method on the inner type
48    }
49    fn create_layer(
50        &mut self, name: String, input_size: usize, optimizer: Box<dyn Optimizer>, randomizer: &Randomizer,
51    ) -> Box<dyn Layer> {
52        (**self).create_layer(name, input_size, optimizer, randomizer) // Dereference the Box to call the method on the inner type
53    }
54}