use std::collections::HashMap;
#[cfg(feature = "nvidia")]
use std::{cell::RefCell, rc::Rc};

use zenu_autograd::{
    nn::rnns::weights::{CellType, RNNLayerWeights, RNNWeights},
    Variable,
};
#[cfg(feature = "nvidia")]
use zenu_matrix::{
    device::nvidia::Nvidia,
    nn::rnn::{RNNDescriptor, RNNWeights as RNNWeightsMat},
};

use zenu_matrix::{device::Device, num::Num};

use crate::Parameters;

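/// Activation function applied inside the recurrent cell.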
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    ReLU,
    Tanh,
}

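/// Shared state for RNN-family layers, with the cell variant selected by `C`.
///
/// Weights are held either as per-layer `RNNLayerWeights`, or, when the
/// `nvidia` feature is active and `is_cudnn` is set, as a single flat cuDNN
/// weight buffer described by `desc`.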
#[expect(clippy::module_name_repetitions)]
pub struct RNNInner<T: Num, D: Device, C: CellType> {
    pub(super) weights: Option<Vec<RNNLayerWeights<T, D, C>>>,
    #[cfg(feature = "nvidia")]
    pub(super) desc: Option<Rc<RefCell<RNNDescriptor<T>>>>,
    #[cfg(feature = "nvidia")]
    pub(super) cudnn_weights: Option<Variable<T, Nvidia>>,
    #[cfg(feature = "nvidia")]
    pub(super) is_cudnn: bool,
    pub(super) is_bidirectional: bool,
    pub(super) activation: Option<Activation>,
    #[cfg(feature = "nvidia")]
    pub(super) is_training: bool,
}
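
/// Number of stacked layers encoded in a flat parameter map.
///
/// Layer indices are read from keys of the form `"<cell>.<idx>.…"`, so the
/// result is the largest index seen plus one. Panics if a layer index fails
/// to parse.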
fn get_num_layers<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
) -> usize {
    let prefix = format!("{}.", C::name());
    let mut num_layers = 0;
    for key in parameters.keys() {
        if key.starts_with(prefix.as_str()) {
            let layer_num = key.split('.').nth(1).unwrap().parse::<usize>().unwrap();
            num_layers = num_layers.max(layer_num);
        }
    }
    num_layers + 1
}

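/// Returns `true` if any key for this cell type names a `reverse` direction.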
fn is_bidirectional<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
) -> bool {
    let prefix = format!("{}.", C::name());
    parameters
        .keys()
        .any(|key| key.starts_with(prefix.as_str()) && key.contains("reverse"))
}

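/// Collects the weights and biases of layer `idx` from a flat parameter map.
///
/// Panics if any of the expected keys is missing.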
fn get_nth_weights<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
    idx: usize,
    is_bidirectional: bool,
) -> RNNLayerWeights<T, D, C> {
    let cell_name = C::name();
    let forward_input = parameters
        .get(&format!("{cell_name}.{idx}.forward.weight_input"))
        .unwrap()
        .clone();
    let forward_hidden = parameters
        .get(&format!("{cell_name}.{idx}.forward.weight_hidden"))
        .unwrap()
        .clone();
    let forward_bias_input = parameters
        .get(&format!("{cell_name}.{idx}.forward.bias_input"))
        .unwrap()
        .clone();
    let forward_bias_hidden = parameters
        .get(&format!("{cell_name}.{idx}.forward.bias_hidden"))
        .unwrap()
        .clone();

    let forward = RNNWeights::new(
        forward_input,
        forward_hidden,
        forward_bias_input,
        forward_bias_hidden,
    );

    if is_bidirectional {
        let reverse_input = parameters
            .get(&format!("{cell_name}.{idx}.reverse.weight_input"))
            .unwrap()
            .clone();
        let reverse_hidden = parameters
            .get(&format!("{cell_name}.{idx}.reverse.weight_hidden"))
            .unwrap()
            .clone();
        let reverse_bias_input = parameters
            .get(&format!("{cell_name}.{idx}.reverse.bias_input"))
            .unwrap()
            .clone();
        let reverse_bias_hidden = parameters
            .get(&format!("{cell_name}.{idx}.reverse.bias_hidden"))
            .unwrap()
            .clone();

        let backward = RNNWeights::new(
            reverse_input,
            reverse_hidden,
            reverse_bias_input,
            reverse_bias_hidden,
        );

        RNNLayerWeights::new(forward, Some(backward))
    } else {
        RNNLayerWeights::new(forward, None)
    }
}
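
// The flat key layout assumed by the helpers above, shown for layer 0 of a
// bidirectional stack. The "rnn" prefix is only illustrative; the real prefix
// is whatever `C::name()` returns:
//
//   rnn.0.forward.weight_input
//   rnn.0.forward.weight_hidden
//   rnn.0.forward.bias_input
//   rnn.0.forward.bias_hidden
//   rnn.0.reverse.weight_input    (bidirectional only)
//   rnn.0.reverse.weight_hidden   (bidirectional only)
//   rnn.0.reverse.bias_input      (bidirectional only)
//   rnn.0.reverse.bias_hidden     (bidirectional only)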

impl<T: Num, D: Device, C: CellType> Parameters<T, D> for RNNInner<T, D, C> {
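    /// Flattens the per-layer weight matrices into a key/value map using the
    /// key layout documented above. With the `nvidia` feature and cuDNN in
    /// use, the weights are first read back out of the flat cuDNN buffer.
    ///
    /// A minimal sketch of the save/load round trip (assuming an
    /// already-constructed `rnn`; the variable name is illustrative):
    ///
    /// ```ignore
    /// let mut params = rnn.weights();
    /// params.extend(rnn.biases());
    /// rnn.load_parameters(params);
    /// ```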
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        let cell_name = C::name();
        #[cfg(feature = "nvidia")]
        let weights = if self.is_cudnn {
            self.cudnn_weights_to_layer_weights()
        } else {
            self.weights.as_ref().unwrap().clone()
        };

        #[cfg(not(feature = "nvidia"))]
        let weights = self.weights.as_ref().unwrap().clone();

        let mut parameters = HashMap::new();

        for (idx, weight) in weights.iter().enumerate() {
            let forward = weight.forward.clone();
            let backward = weight.backward.clone();

            let forward_input = forward.weight_input.clone();
            let forward_hidden = forward.weight_hidden.clone();

            parameters.insert(
                format!("{cell_name}.{idx}.forward.weight_input"),
                forward_input.to(),
            );
            parameters.insert(
                format!("{cell_name}.{idx}.forward.weight_hidden"),
                forward_hidden.to(),
            );

            if self.is_bidirectional {
                let reverse = backward.unwrap();
                let reverse_input = reverse.weight_input.clone();
                let reverse_hidden = reverse.weight_hidden.clone();

                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.weight_input"),
                    reverse_input.to(),
                );
                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.weight_hidden"),
                    reverse_hidden.to(),
                );
            }
        }

        parameters
    }

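    /// Same as [`Self::weights`], but for the input and hidden bias vectors.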
    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        let cell_name = C::name();
        #[cfg(feature = "nvidia")]
        let weights = if self.is_cudnn {
            self.cudnn_weights_to_layer_weights()
        } else {
            self.weights.as_ref().unwrap().clone()
        };

        #[cfg(not(feature = "nvidia"))]
        let weights = self.weights.as_ref().unwrap().clone();

        let mut parameters = HashMap::new();

        for (idx, weight) in weights.iter().enumerate() {
            let forward = weight.forward.clone();
            let backward = weight.backward.clone();

            let forward_input = forward.bias_input.clone();
            let forward_hidden = forward.bias_hidden.clone();

            parameters.insert(
                format!("{cell_name}.{idx}.forward.bias_input"),
                forward_input.to(),
            );
            parameters.insert(
                format!("{cell_name}.{idx}.forward.bias_hidden"),
                forward_hidden.to(),
            );

            if self.is_bidirectional {
                let reverse = backward.unwrap();
                let reverse_input = reverse.bias_input.clone();
                let reverse_hidden = reverse.bias_hidden.clone();

                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.bias_input"),
                    reverse_input.to(),
                );
                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.bias_hidden"),
                    reverse_hidden.to(),
                );
            }
        }

        parameters
    }

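    /// Rebuilds per-layer weights from a flat parameter map. With cuDNN in
    /// use, the weights are copied into the flat cuDNN buffer instead and the
    /// per-layer copy is dropped.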
    fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
        let num_layers = get_num_layers::<T, D, C>(&parameters);
        let is_bidirectional = is_bidirectional::<T, D, C>(&parameters);

        let mut weights = Vec::new();

        for idx in 0..num_layers {
            let weight = get_nth_weights(&parameters, idx, is_bidirectional);
            weights.push(weight);
        }

        self.weights = Some(weights.clone());

        #[cfg(feature = "nvidia")]
        if self.is_cudnn {
            let desc = self.desc.as_ref().unwrap();
            let cudnn_weights = self.cudnn_weights.as_ref().unwrap();

            let weights = rnn_weights_to_desc(weights, self.is_bidirectional);

            desc.borrow()
                .load_rnn_weights(cudnn_weights.get_as_mut().as_mut_ptr().cast(), weights)
                .unwrap();

            self.cudnn_weights = Some(cudnn_weights.clone());
            self.weights = None;
        }
    }
}

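/// Flattens per-layer weights into the sequence cuDNN expects: one entry per
/// layer, with the forward weights immediately followed by the reverse
/// weights when the network is bidirectional.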
#[cfg(feature = "nvidia")]
pub(super) fn rnn_weights_to_desc<T: Num, D: Device, C: CellType>(
    weights: Vec<RNNLayerWeights<T, D, C>>,
    is_bidirectional: bool,
) -> Vec<RNNWeightsMat<T, D>> {
    let mut rnn_weights = Vec::new();

    for weight in weights {
        let forward_weights = weight.forward;
        let weight_input = forward_weights.weight_input.get_as_ref();
        let weight_hidden = forward_weights.weight_hidden.get_as_ref();
        let bias_input = forward_weights.bias_input.get_as_ref();
        let bias_hidden = forward_weights.bias_hidden.get_as_ref();

        let weights = RNNWeightsMat::new(
            weight_input.new_matrix(),
            weight_hidden.new_matrix(),
            bias_input.new_matrix(),
            bias_hidden.new_matrix(),
        );

        rnn_weights.push(weights);

        if is_bidirectional {
            let backward_weights = weight.backward.unwrap();
            let weight_input = backward_weights.weight_input.get_as_ref();
            let weight_hidden = backward_weights.weight_hidden.get_as_ref();
            let bias_input = backward_weights.bias_input.get_as_ref();
            let bias_hidden = backward_weights.bias_hidden.get_as_ref();

            let weights = RNNWeightsMat::new(
                weight_input.new_matrix(),
                weight_hidden.new_matrix(),
                bias_input.new_matrix(),
                bias_hidden.new_matrix(),
            );

            rnn_weights.push(weights);
        }
    }

    rnn_weights
}

impl<T: Num, D: Device, C: CellType> RNNInner<T, D, C> {
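    /// Reads the flat cuDNN weight buffer back into per-layer weights,
    /// undoing the interleaved `[forward, reverse, forward, reverse, ...]`
    /// layout produced by `rnn_weights_to_desc` for bidirectional networks.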
    #[cfg(feature = "nvidia")]
    fn cudnn_weights_to_layer_weights(&self) -> Vec<RNNLayerWeights<T, D, C>> {
        let desc = self.desc.as_ref().unwrap().clone();
        let cudnn_weights_ptr = self
            .cudnn_weights
            .as_ref()
            .unwrap()
            .get_as_mut()
            .as_mut_ptr();
        let weights = desc
            .borrow()
            .store_rnn_weights::<D>(cudnn_weights_ptr.cast());

        let weights = weights
            .into_iter()
            .map(RNNWeights::from)
            .collect::<Vec<_>>();

        if self.is_bidirectional {
            // Entries alternate forward/reverse per layer, so layer `i` owns
            // the pair at positions `2 * i` and `2 * i + 1`.
            let mut layer_weights = Vec::new();
            for i in 0..weights.len() / 2 {
                let forward = weights[2 * i].clone();
                let backward = weights[2 * i + 1].clone();
                layer_weights.push(RNNLayerWeights::new(forward, Some(backward)));
            }
            return layer_weights;
        }
        weights
            .into_iter()
            .map(|w| RNNLayerWeights::new(w, None))
            .collect()
    }
}