// yarnn/layer.rs

use crate::backend::Backend;
use crate::optimizer::{Optimizer, Optimizable};
use crate::tensor::{Tensor, TensorShape};

use core::marker::PhantomData;


pub trait Layer<N, B: Backend<N>> {
    type Config: Default;
    fn name(&self) -> &str;
    fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;

    #[inline]
    fn init(&mut self, _backend: &B) {}

    fn input_shape(&self) -> TensorShape;

    #[inline]
    fn output_shape(&self) -> TensorShape {
        self.input_shape()
    }

    fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor);
    fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, y: &B::Tensor);
}
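
// A minimal sketch of a `Layer` implementation: a stateless, element-wise
// activation. NOTE: `ExampleReluBackend` and its `relu`/`relu_grad` methods
// are hypothetical stand-ins for whatever ops a concrete `Backend` actually
// exposes; only the `Layer` plumbing comes from the trait above.
pub trait ExampleReluBackend<N>: Backend<N> {
    fn relu(&self, y: &mut Self::Tensor, x: &Self::Tensor);
    fn relu_grad(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor);
}

#[derive(Default)]
pub struct ExampleReluConfig;

pub struct ExampleRelu<N, B> {
    input_shape: TensorShape,
    _m: PhantomData<fn(N, B)>,
}

impl<N, B: ExampleReluBackend<N>> Layer<N, B> for ExampleRelu<N, B> {
    type Config = ExampleReluConfig;

    fn name(&self) -> &str {
        "ExampleReLU"
    }

    fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
        Self {
            input_shape,
            _m: PhantomData,
        }
    }

    fn input_shape(&self) -> TensorShape {
        self.input_shape.clone()
    }

    // `output_shape` keeps its default: activations preserve the input shape.

    fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
        backend.relu(y, x);
    }

    fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _y: &B::Tensor) {
        backend.relu_grad(dx, dy, x);
    }
}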

/// Temporary blanket impl until the type-inference issues with
/// specialization are resolved. The `default fn` methods require the
/// nightly `specialization` feature; layers without trainable
/// parameters simply keep these no-op defaults.
impl<T, N, B, O> Optimizable<N, B, O> for T
    where T: Layer<N, B>,
          B: Backend<N>,
          O: Optimizer<N, B>
{
    default fn calc_gradients(&mut self, _backend: &B, _inputs: &B::Tensor, _deltas: &B::Tensor) {}
    default fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
}
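
// A layer with trainable parameters overrides these specialized defaults.
// Sketch only: `ExampleLinear`, its fields, and the backend/optimizer calls
// in the bodies are hypothetical and elided:
//
//     impl<N, B, O> Optimizable<N, B, O> for ExampleLinear<N, B>
//         where B: Backend<N>,
//               O: Optimizer<N, B>,
//     {
//         fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor) {
//             // accumulate weight gradients from `inputs` and `deltas`
//         }
//
//         fn optimize(&mut self, backend: &B, optimizer: &O) {
//             // apply one optimizer step to the stored weights
//         }
//     }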

pub trait AbstractLayer<N, B: Backend<N>, O: Optimizer<N, B>>: core::fmt::Display {
    type Context: LayerContext<N, B>;

    fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context);
    fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context);
    fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context);

    #[inline]
    fn add_layer<L: Layer<N, B>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, LayerImpl<N, B, O, L>>
        where Self: Sized
    {
        crate::layers::Chain::new(
            self,
            LayerImpl::new(L::create(().into(), cfg)),
        )
    }
}
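
// Sketch of chaining layers through `add_layer`, assuming the hypothetical
// `ExampleRelu`/`ExampleReluConfig` above and a `first` value that already
// implements `AbstractLayer`; each call wraps another `LayerImpl` in a
// `crate::layers::Chain`:
//
//     let net = first
//         .add_layer::<ExampleRelu<_, _>>(ExampleReluConfig)
//         .add_layer::<ExampleRelu<_, _>>(ExampleReluConfig);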

pub trait LayerContext<N, B: Backend<N>>: Default {
    fn outputs(&self) -> &B::Tensor;
    fn deltas(&self) -> &B::Tensor;
}

pub struct CommonLayerContext<N, B>
    where B: Backend<N>,
{
    pub outputs: B::Tensor,
    pub deltas: B::Tensor,
}

impl<N, B> Default for CommonLayerContext<N, B>
    where B: Backend<N>,
{
    fn default() -> Self {
        Self {
            outputs: B::Tensor::new(()),
            deltas: B::Tensor::new(()),
        }
    }
}

impl<N, B> CommonLayerContext<N, B>
    where B: Backend<N>,
{
    /// Resizes `deltas` to `[bs, input_shape...]`, reallocating only
    /// when the resulting shape actually changed.
    pub fn update_deltas_bs(&mut self, bs: u32, input_shape: &TensorShape) {
        let mut new_deltas_shape = TensorShape::new1d(bs);
        new_deltas_shape.append(input_shape.clone());

        if self.deltas.shape() != &new_deltas_shape {
            self.deltas.resize(new_deltas_shape);
        }
    }

    /// Resizes `outputs` to `[bs, output_shape...]`, reallocating only
    /// when the resulting shape actually changed.
    pub fn update_outputs_bs(&mut self, bs: u32, output_shape: &TensorShape) {
        let mut new_output_shape = TensorShape::new1d(bs);
        new_output_shape.append(output_shape.clone());

        if self.outputs.shape() != &new_output_shape {
            self.outputs.resize(new_output_shape);
        }
    }
}
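
// Worked example, using only the shape API visible in this file: with
// `bs = 32` and a flattened input shape of `[784]`, `update_outputs_bs`
// resizes `outputs` to `[32, 784]`; the batch size is always prepended
// as dimension 0.
//
//     let mut expected = TensorShape::new1d(32);
//     expected.append(TensorShape::new1d(784));
//     ctx.update_outputs_bs(32, &TensorShape::new1d(784));
//     assert_eq!(ctx.outputs.shape(), &expected);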

impl<N, B> LayerContext<N, B> for CommonLayerContext<N, B>
    where B: Backend<N>,
{
    #[inline]
    fn outputs(&self) -> &B::Tensor {
        &self.outputs
    }

    #[inline]
    fn deltas(&self) -> &B::Tensor {
        &self.deltas
    }
}

pub struct LayerImpl<N, B, O, L>
    where B: Backend<N>,
          O: Optimizer<N, B>
{
    pub layer: L,
    initialized: bool,
    _m: PhantomData<fn(N, B, O)>,
}

impl<N, B, O, L> LayerImpl<N, B, O, L>
    where B: Backend<N>,
          O: Optimizer<N, B>,
          L: Layer<N, B> + Optimizable<N, B, O>
{
    pub fn new(layer: L) -> Self {
        Self {
            layer,
            initialized: false,
            _m: PhantomData,
        }
    }
}

impl<N, B, O, L> core::fmt::Display for LayerImpl<N, B, O, L>
    where B: Backend<N>,
          O: Optimizer<N, B>,
          L: Layer<N, B>
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(
            f,
            "{} -> {} -> {}",
            self.layer.input_shape(),
            self.layer.name(),
            self.layer.output_shape()
        )
    }
}

impl<N, B, O, L> AbstractLayer<N, B, O> for LayerImpl<N, B, O, L>
    where B: Backend<N>,
          O: Optimizer<N, B>,
          L: Layer<N, B> + Optimizable<N, B, O>
{
    type Context = CommonLayerContext<N, B>;

    #[inline]
    fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
        // Lazily initialize the layer on first use, once the backend is known.
        if !self.initialized {
            self.initialized = true;
            self.layer.init(backend);
        }

        ctx.update_outputs_bs(inputs.shape().get(0), &self.layer.output_shape());
        self.layer.forward(backend, &mut ctx.outputs, inputs);
    }

    #[inline]
    fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
        ctx.update_deltas_bs(deltas.shape().get(0), &self.layer.input_shape());
        self.layer.backward(backend, &mut ctx.deltas, deltas, inputs, &ctx.outputs);
    }

    #[inline]
    fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, _ctx: &mut Self::Context) {
        self.layer.calc_gradients(backend, inputs, deltas);
        self.layer.optimize(backend, optimizer);
    }
}
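
// Sketch of one training step driven through `AbstractLayer`, assuming a
// concrete `backend`, `optimizer`, `inputs`, and `deltas` already exist:
//
//     let mut ctx = CommonLayerContext::default();
//     layer.forward(&backend, &inputs, &mut ctx);            // fills ctx.outputs
//     layer.backward(&backend, &deltas, &inputs, &mut ctx);  // fills ctx.deltas
//     layer.update(&backend, &optimizer, &inputs, &deltas, &mut ctx);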