zenu_layer/layers/
dropout.rs

1use std::{cell::RefCell, collections::HashMap, rc::Rc};
2
3use zenu_autograd::{
4    nn::dropout::{dropout, DropoutConfig},
5    Variable,
6};
7use zenu_matrix::{
8    device::Device,
9    dim::{DimDyn, DimTrait},
10    num::Num,
11};
12
13use crate::{Module, Parameters};
14
15pub struct Dropout<T: Num, D: Device> {
16    config: DropoutConfig<T, D>,
17    input_shape: Option<Rc<RefCell<DimDyn>>>,
18    raio: f32,
19}
20
21impl<T: Num, D: Device> Module<T, D> for Dropout<T, D> {
22    type Input = Variable<T, D>;
23    type Output = Variable<T, D>;
24    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
25        if self.input_shape.as_ref().unwrap().borrow().slice() != input.get_shape().slice() {
26            todo!();
27        }
28        dropout(input, self.raio, Some(self.config.clone()))
29    }
30}
31
32impl<T: Num, D: Device> Dropout<T, D> {
33    #[must_use]
34    pub fn new(rate: f32) -> Self {
35        let config = DropoutConfig::new(rate);
36        Self {
37            config,
38            input_shape: None,
39            raio: rate,
40        }
41    }
42
43    pub fn gpu_init(&self, shape: DimDyn) {
44        self.config.gpu_init(shape);
45    }
46}
47
48impl<T: Num, D: Device> Parameters<T, D> for Dropout<T, D> {
49    fn weights(&self) -> HashMap<String, Variable<T, D>> {
50        HashMap::new()
51    }
52
53    fn biases(&self) -> HashMap<String, Variable<T, D>> {
54        HashMap::new()
55    }
56}