zenu_layer/layers/rnn/inner.rs

use std::collections::HashMap;
#[cfg(feature = "nvidia")]
use std::{cell::RefCell, rc::Rc};

use zenu_autograd::{
    nn::rnns::weights::{CellType, RNNLayerWeights, RNNWeights},
    Variable,
};
#[cfg(feature = "nvidia")]
use zenu_matrix::{
    device::nvidia::Nvidia,
    nn::rnn::{RNNDescriptor, RNNWeights as RNNWeightsMat},
};

use zenu_matrix::{device::Device, num::Num};

use crate::Parameters;

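/// Point-wise activation applied by the vanilla RNN cell: either `ReLU` or `Tanh`.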
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    ReLU,
    Tanh,
}

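/// Backend-agnostic core shared by the RNN layer types.
///
/// Weights live either as per-layer `RNNLayerWeights` (the portable path) or,
/// when the `nvidia` feature is enabled and `is_cudnn` is set, as a single
/// flat cuDNN weight buffer described by `desc`.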
#[expect(clippy::module_name_repetitions)]
pub struct RNNInner<T: Num, D: Device, C: CellType> {
    pub(super) weights: Option<Vec<RNNLayerWeights<T, D, C>>>,
    #[cfg(feature = "nvidia")]
    pub(super) desc: Option<Rc<RefCell<RNNDescriptor<T>>>>,
    #[cfg(feature = "nvidia")]
    pub(super) cudnn_weights: Option<Variable<T, Nvidia>>,
    #[cfg(feature = "nvidia")]
    pub(super) is_cudnn: bool,
    pub(super) is_bidirectional: bool,
    pub(super) activation: Option<Activation>,
    #[cfg(feature = "nvidia")]
    pub(super) is_training: bool,
}
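
/// Infers the number of layers from parameter keys of the form
/// `<cell>.<layer>.<direction>.<name>`: the largest layer index seen, plus one.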
fn get_num_layers<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
) -> usize {
    let mut num_layers = 0;
    for key in parameters.keys() {
        if key.starts_with(format!("{}.", C::name()).as_str()) {
            let layer_num = key.split('.').nth(1).unwrap().parse::<usize>().unwrap();
            num_layers = num_layers.max(layer_num);
        }
    }
    num_layers + 1
}

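/// Returns `true` if any key for this cell type names a `reverse` direction,
/// i.e. the parameters were produced by a bidirectional RNN.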
fn is_bidirectional<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
) -> bool {
    for key in parameters.keys() {
        if key.starts_with(format!("{}.", C::name()).as_str()) && key.contains("reverse") {
            return true;
        }
    }
    false
}

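/// Assembles the weights of layer `idx` from the parameter map.
///
/// Panics if an expected key is missing, or if `is_bidirectional` is set but
/// the `reverse` entries are absent.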
fn get_nth_weights<T: Num, D: Device, C: CellType>(
    parameters: &HashMap<String, Variable<T, D>>,
    idx: usize,
    is_bidirectional: bool,
) -> RNNLayerWeights<T, D, C> {
    let cell_name = C::name();
    let forward_input = parameters
        .get(&format!("{cell_name}.{idx}.forward.weight_input"))
        .unwrap()
        .clone();
    let forward_hidden = parameters
        .get(&format!("{cell_name}.{idx}.forward.weight_hidden"))
        .unwrap()
        .clone();
    let forward_bias_input = parameters
        .get(&format!("{cell_name}.{idx}.forward.bias_input"))
        .unwrap()
        .clone();
    let forward_bias_hidden = parameters
        .get(&format!("{cell_name}.{idx}.forward.bias_hidden"))
        .unwrap()
        .clone();

    let forward = RNNWeights::new(
        forward_input,
        forward_hidden,
        forward_bias_input,
        forward_bias_hidden,
    );

    if is_bidirectional {
        let reverse_input = parameters
            .get(&format!("{cell_name}.{idx}.reverse.weight_input"))
            .unwrap()
            .clone();
        let reverse_hidden = parameters
            .get(&format!("{cell_name}.{idx}.reverse.weight_hidden"))
            .unwrap()
            .clone();
        let reverse_bias_input = parameters
            .get(&format!("{cell_name}.{idx}.reverse.bias_input"))
            .unwrap()
            .clone();
        let reverse_bias_hidden = parameters
            .get(&format!("{cell_name}.{idx}.reverse.bias_hidden"))
            .unwrap()
            .clone();

        let backward = RNNWeights::new(
            reverse_input,
            reverse_hidden,
            reverse_bias_input,
            reverse_bias_hidden,
        );

        RNNLayerWeights::new(forward, Some(backward))
    } else {
        RNNLayerWeights::new(forward, None)
    }
}

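// Parameter keys follow the pattern
// `<C::name()>.<layer>.<forward|reverse>.<weight_input|weight_hidden|bias_input|bias_hidden>`.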
impl<T: Num, D: Device, C: CellType> Parameters<T, D> for RNNInner<T, D, C> {
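    /// Collects the weight matrices (not the biases) of every layer under the
    /// key scheme above. On the cuDNN path the flat buffer is first unpacked
    /// into per-layer weights.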
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        let cell_name = C::name();
        #[cfg(feature = "nvidia")]
        let weights = if self.is_cudnn {
            self.cudnn_weights_to_layer_weights()
        } else {
            self.weights.as_ref().unwrap().clone()
        };

        #[cfg(not(feature = "nvidia"))]
        let weights = self.weights.as_ref().unwrap().clone();

        let mut parameters = HashMap::new();

        for (idx, weight) in weights.iter().enumerate() {
            let forward = weight.forward.clone();
            let backward = weight.backward.clone();

            let forward_input = forward.weight_input.clone();
            let forward_hidden = forward.weight_hidden.clone();

            parameters.insert(
                format!("{cell_name}.{idx}.forward.weight_input"),
                forward_input.to(),
            );
            parameters.insert(
                format!("{cell_name}.{idx}.forward.weight_hidden"),
                forward_hidden.to(),
            );

            if self.is_bidirectional {
                let reverse = backward.unwrap();
                let reverse_input = reverse.weight_input.clone();
                let reverse_hidden = reverse.weight_hidden.clone();

                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.weight_input"),
                    reverse_input.to(),
                );
                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.weight_hidden"),
                    reverse_hidden.to(),
                );
            }
        }

        parameters
    }

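    /// Collects the bias vectors of every layer; the bias counterpart of
    /// `weights`.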
    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        let cell_name = C::name();
        #[cfg(feature = "nvidia")]
        let weights = if self.is_cudnn {
            self.cudnn_weights_to_layer_weights()
        } else {
            self.weights.as_ref().unwrap().clone()
        };

        #[cfg(not(feature = "nvidia"))]
        let weights = self.weights.as_ref().unwrap().clone();

        let mut parameters = HashMap::new();

        for (idx, weight) in weights.iter().enumerate() {
            let forward = weight.forward.clone();
            let backward = weight.backward.clone();

            let forward_input = forward.bias_input.clone();
            let forward_hidden = forward.bias_hidden.clone();

            parameters.insert(
                format!("{cell_name}.{idx}.forward.bias_input"),
                forward_input.to(),
            );
            parameters.insert(
                format!("{cell_name}.{idx}.forward.bias_hidden"),
                forward_hidden.to(),
            );

            if self.is_bidirectional {
                let reverse = backward.unwrap();
                let reverse_input = reverse.bias_input.clone();
                let reverse_hidden = reverse.bias_hidden.clone();

                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.bias_input"),
                    reverse_input.to(),
                );
                parameters.insert(
                    format!("{cell_name}.{idx}.reverse.bias_hidden"),
                    reverse_hidden.to(),
                );
            }
        }

        parameters
    }

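    /// Rebuilds the per-layer weights from a parameter map shaped like the
    /// output of `weights`/`biases`. On the cuDNN path the result is then
    /// packed into the flat cuDNN buffer.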
    fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
        let num_layers = get_num_layers::<T, D, C>(&parameters);
        let is_bidirectional = is_bidirectional::<T, D, C>(&parameters);

        let mut weights = Vec::new();

        for idx in 0..num_layers {
            let weight = get_nth_weights(&parameters, idx, is_bidirectional);
            weights.push(weight);
        }

        self.weights = Some(weights.clone());

        #[cfg(feature = "nvidia")]
        if self.is_cudnn {
            let desc = self.desc.as_ref().unwrap();
            let cudnn_weights = self.cudnn_weights.as_ref().unwrap();

            let weights = rnn_weights_to_desc(weights, self.is_bidirectional);

            // Pack the per-layer weights into the flat cuDNN buffer in place.
            desc.borrow()
                .load_rnn_weights(cudnn_weights.get_as_mut().as_mut_ptr().cast(), weights)
                .unwrap();

            // The cuDNN buffer is now the single source of truth, so drop the
            // per-layer copies.
            self.weights = None;
        }
    }
}

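/// Flattens per-layer weights into the order cuDNN expects: the forward
/// weights of each layer, immediately followed by its reverse weights when the
/// RNN is bidirectional.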
#[cfg(feature = "nvidia")]
pub(super) fn rnn_weights_to_desc<T: Num, D: Device, C: CellType>(
    weights: Vec<RNNLayerWeights<T, D, C>>,
    is_bidirectional: bool,
) -> Vec<RNNWeightsMat<T, D>> {
    let mut rnn_weights = Vec::new();

    for weight in weights {
        let forward_weights = weight.forward;
        let weight_input = forward_weights.weight_input.get_as_ref();
        let weight_hidden = forward_weights.weight_hidden.get_as_ref();
        let bias_input = forward_weights.bias_input.get_as_ref();
        let bias_hidden = forward_weights.bias_hidden.get_as_ref();

        let weights = RNNWeightsMat::new(
            weight_input.new_matrix(),
            weight_hidden.new_matrix(),
            bias_input.new_matrix(),
            bias_hidden.new_matrix(),
        );

        rnn_weights.push(weights);

        if is_bidirectional {
            let backward_weights = weight.backward.unwrap();
            let weight_input = backward_weights.weight_input.get_as_ref();
            let weight_hidden = backward_weights.weight_hidden.get_as_ref();
            let bias_input = backward_weights.bias_input.get_as_ref();
            let bias_hidden = backward_weights.bias_hidden.get_as_ref();

            let weights = RNNWeightsMat::new(
                weight_input.new_matrix(),
                weight_hidden.new_matrix(),
                bias_input.new_matrix(),
                bias_hidden.new_matrix(),
            );

            rnn_weights.push(weights);
        }
    }

    rnn_weights
}

impl<T: Num, D: Device, C: CellType> RNNInner<T, D, C> {
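    /// Reads the flat cuDNN buffer back into per-layer weights; the inverse of
    /// `rnn_weights_to_desc`.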
    #[cfg(feature = "nvidia")]
    fn cudnn_weights_to_layer_weights(&self) -> Vec<RNNLayerWeights<T, D, C>> {
        let desc = self.desc.as_ref().unwrap().clone();
        let cudnn_weights_ptr = self
            .cudnn_weights
            .as_ref()
            .unwrap()
            .get_as_mut()
            .as_mut_ptr();
        let weights = desc
            .borrow()
            .store_rnn_weights::<D>(cudnn_weights_ptr.cast());

        let weights = weights
            .into_iter()
            .map(RNNWeights::from)
            .collect::<Vec<_>>();

        if self.is_bidirectional {
            // The flat buffer interleaves directions per layer (forward,
            // reverse, forward, reverse, ...), matching `rnn_weights_to_desc`,
            // so layer `i` occupies indices `2 * i` and `2 * i + 1`.
            let mut layer_weights = Vec::new();
            for i in 0..weights.len() / 2 {
                let forward = weights[2 * i].clone();
                let backward = weights[2 * i + 1].clone();
                layer_weights.push(RNNLayerWeights::new(forward, Some(backward)));
            }
            return layer_weights;
        }
        weights
            .into_iter()
            .map(|w| RNNLayerWeights::new(w, None))
            .collect()
    }
}