zenu_layer/layers/
dropout.rs1use std::{cell::RefCell, collections::HashMap, rc::Rc};
2
3use zenu_autograd::{
4 nn::dropout::{dropout, DropoutConfig},
5 Variable,
6};
7use zenu_matrix::{
8 device::Device,
9 dim::{DimDyn, DimTrait},
10 num::Num,
11};
12
13use crate::{Module, Parameters};
14
15pub struct Dropout<T: Num, D: Device> {
16 config: DropoutConfig<T, D>,
17 input_shape: Option<Rc<RefCell<DimDyn>>>,
18 raio: f32,
19}
20
21impl<T: Num, D: Device> Module<T, D> for Dropout<T, D> {
22 type Input = Variable<T, D>;
23 type Output = Variable<T, D>;
24 fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
25 if self.input_shape.as_ref().unwrap().borrow().slice() != input.get_shape().slice() {
26 todo!();
27 }
28 dropout(input, self.raio, Some(self.config.clone()))
29 }
30}
31
32impl<T: Num, D: Device> Dropout<T, D> {
33 #[must_use]
34 pub fn new(rate: f32) -> Self {
35 let config = DropoutConfig::new(rate);
36 Self {
37 config,
38 input_shape: None,
39 raio: rate,
40 }
41 }
42
43 pub fn gpu_init(&self, shape: DimDyn) {
44 self.config.gpu_init(shape);
45 }
46}
47
48impl<T: Num, D: Device> Parameters<T, D> for Dropout<T, D> {
49 fn weights(&self) -> HashMap<String, Variable<T, D>> {
50 HashMap::new()
51 }
52
53 fn biases(&self) -> HashMap<String, Variable<T, D>> {
54 HashMap::new()
55 }
56}