1use std::{collections::HashMap, hash::BuildHasher};
2
3use zenu_autograd::Variable;
4use zenu_matrix::{device::Device, num::Num};
5
6pub mod layers;
7
8pub trait ModuleParameters<T: Num, D: Device> {}
9
10impl<T: Num, D: Device> ModuleParameters<T, D> for () {}
11
12impl<T: Num, D: Device> ModuleParameters<T, D> for Variable<T, D> {}
13
14impl<T: Num, D: Device> ModuleParameters<T, D> for Vec<Variable<T, D>> {}
15
16impl<T: Num, D: Device, K, S: BuildHasher> ModuleParameters<T, D>
17 for HashMap<K, Variable<T, D>, S>
18{
19}
20
21pub trait Module<T: Num, D: Device> {
22 type Input: ModuleParameters<T, D>;
23 type Output: ModuleParameters<T, D>;
24 fn call(&self, input: Self::Input) -> Self::Output;
25}
26
27pub trait Parameters<T: Num, D: Device> {
28 fn weights(&self) -> HashMap<String, Variable<T, D>>;
29 fn biases(&self) -> HashMap<String, Variable<T, D>>;
30 fn parameters(&self) -> HashMap<String, Variable<T, D>> {
31 let weights = self.weights();
32 let biases = self.biases();
33 let mut parameters = HashMap::new();
34 for (key, value) in weights {
35 parameters.insert(key.clone(), value.clone());
36 }
37 for (key, value) in biases {
38 parameters.insert(key.clone(), value.clone());
39 }
40 parameters
41 }
42 fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
43 for (self_key, self_value) in self.parameters() {
44 if let Some(value) = parameters.get(&self_key) {
45 self_value.get_as_mut().copy_from(&value.get_as_ref());
46 } else {
47 panic!("Failed to load model missing key: {self_key}");
48 }
49 }
50 }
51}
52
53impl<T: Num, D: Device> Parameters<T, D> for () {
54 fn weights(&self) -> HashMap<String, Variable<T, D>> {
55 HashMap::new()
56 }
57
58 fn biases(&self) -> HashMap<String, Variable<T, D>> {
59 HashMap::new()
60 }
61}
62impl<T: Num, D: Device, P: Parameters<T, D>> Parameters<T, D> for Vec<P> {
63 fn weights(&self) -> HashMap<String, Variable<T, D>> {
64 let mut weights = HashMap::new();
65 for (idx, param) in self.iter().enumerate() {
66 for (key, value) in param.weights() {
67 weights.insert(format!("{idx}.{key}"), value.clone());
68 }
69 }
70 weights
71 }
72
73 fn biases(&self) -> HashMap<String, Variable<T, D>> {
74 let mut biases = HashMap::new();
75 for (idx, param) in self.iter().enumerate() {
76 for (key, value) in param.biases() {
77 biases.insert(format!("{idx}.{key}",), value.clone());
78 }
79 }
80 biases
81 }
82}
83
84impl<T: Num, D: Device> Parameters<T, D> for Box<dyn Parameters<T, D>> {
85 fn weights(&self) -> HashMap<String, Variable<T, D>> {
86 self.as_ref().weights()
87 }
88
89 fn biases(&self) -> HashMap<String, Variable<T, D>> {
90 self.as_ref().biases()
91 }
92}
93
94impl<T: Num, D: Device, P: Parameters<T, D>, S: ::std::hash::BuildHasher> Parameters<T, D>
95 for HashMap<String, P, S>
96{
97 fn weights(&self) -> HashMap<String, Variable<T, D>> {
98 let mut weights = HashMap::new();
99 for (key, param) in self {
100 for (sub_key, value) in param.weights() {
101 weights.insert(format!("{key}.{sub_key}"), value.clone());
102 }
103 }
104 weights
105 }
106
107 fn biases(&self) -> HashMap<String, Variable<T, D>> {
108 let mut biases = HashMap::new();
109 for (key, param) in self {
110 for (sub_key, value) in param.biases() {
111 biases.insert(format!("{key}.{sub_key}"), value.clone());
112 }
113 }
114 biases
115 }
116}