zenu_layer/
lib.rs

1use std::{collections::HashMap, hash::BuildHasher};
2
3use zenu_autograd::Variable;
4use zenu_matrix::{device::Device, num::Num};
5
6pub mod layers;
7
8pub trait ModuleParameters<T: Num, D: Device> {}
9
10impl<T: Num, D: Device> ModuleParameters<T, D> for () {}
11
12impl<T: Num, D: Device> ModuleParameters<T, D> for Variable<T, D> {}
13
14impl<T: Num, D: Device> ModuleParameters<T, D> for Vec<Variable<T, D>> {}
15
16impl<T: Num, D: Device, K, S: BuildHasher> ModuleParameters<T, D>
17    for HashMap<K, Variable<T, D>, S>
18{
19}
20
21pub trait Module<T: Num, D: Device> {
22    type Input: ModuleParameters<T, D>;
23    type Output: ModuleParameters<T, D>;
24    fn call(&self, input: Self::Input) -> Self::Output;
25}
26
27pub trait Parameters<T: Num, D: Device> {
28    fn weights(&self) -> HashMap<String, Variable<T, D>>;
29    fn biases(&self) -> HashMap<String, Variable<T, D>>;
30    fn parameters(&self) -> HashMap<String, Variable<T, D>> {
31        let weights = self.weights();
32        let biases = self.biases();
33        let mut parameters = HashMap::new();
34        for (key, value) in weights {
35            parameters.insert(key.clone(), value.clone());
36        }
37        for (key, value) in biases {
38            parameters.insert(key.clone(), value.clone());
39        }
40        parameters
41    }
42    fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
43        for (self_key, self_value) in self.parameters() {
44            if let Some(value) = parameters.get(&self_key) {
45                self_value.get_as_mut().copy_from(&value.get_as_ref());
46            } else {
47                panic!("Failed to load model missing key: {self_key}");
48            }
49        }
50    }
51}
52
53impl<T: Num, D: Device> Parameters<T, D> for () {
54    fn weights(&self) -> HashMap<String, Variable<T, D>> {
55        HashMap::new()
56    }
57
58    fn biases(&self) -> HashMap<String, Variable<T, D>> {
59        HashMap::new()
60    }
61}
62impl<T: Num, D: Device, P: Parameters<T, D>> Parameters<T, D> for Vec<P> {
63    fn weights(&self) -> HashMap<String, Variable<T, D>> {
64        let mut weights = HashMap::new();
65        for (idx, param) in self.iter().enumerate() {
66            for (key, value) in param.weights() {
67                weights.insert(format!("{idx}.{key}"), value.clone());
68            }
69        }
70        weights
71    }
72
73    fn biases(&self) -> HashMap<String, Variable<T, D>> {
74        let mut biases = HashMap::new();
75        for (idx, param) in self.iter().enumerate() {
76            for (key, value) in param.biases() {
77                biases.insert(format!("{idx}.{key}",), value.clone());
78            }
79        }
80        biases
81    }
82}
83
84impl<T: Num, D: Device> Parameters<T, D> for Box<dyn Parameters<T, D>> {
85    fn weights(&self) -> HashMap<String, Variable<T, D>> {
86        self.as_ref().weights()
87    }
88
89    fn biases(&self) -> HashMap<String, Variable<T, D>> {
90        self.as_ref().biases()
91    }
92}
93
94impl<T: Num, D: Device, P: Parameters<T, D>, S: ::std::hash::BuildHasher> Parameters<T, D>
95    for HashMap<String, P, S>
96{
97    fn weights(&self) -> HashMap<String, Variable<T, D>> {
98        let mut weights = HashMap::new();
99        for (key, param) in self {
100            for (sub_key, value) in param.weights() {
101                weights.insert(format!("{key}.{sub_key}"), value.clone());
102            }
103        }
104        weights
105    }
106
107    fn biases(&self) -> HashMap<String, Variable<T, D>> {
108        let mut biases = HashMap::new();
109        for (key, param) in self {
110            for (sub_key, value) in param.biases() {
111                biases.insert(format!("{key}.{sub_key}"), value.clone());
112            }
113        }
114        biases
115    }
116}