1use crate::autograd::Variable;
5use crate::tensor::Tensor;
6use num_traits::Float;
7use rand_distr::{Distribution, Normal};
8use std::fmt::Debug;
9
/// Static configuration shared by the recurrent cell implementations.
///
/// The constructors [`RecurrentConfig::rnn`], [`RecurrentConfig::gru`] and
/// [`RecurrentConfig::lstm`] differ only in the number of gates they record.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecurrentConfig {
    /// Size of the input feature dimension.
    pub input_size: usize,

    /// Size of the hidden state.
    pub hidden_size: usize,

    /// Number of gates: `1` for `rnn`, `3` for `gru`, `4` for `lstm`.
    pub num_gates: usize,

    /// Whether bias terms are requested.
    pub bias: bool,

    /// Training flag; every constructor initialises it to `true`.
    pub training: bool,
}

impl RecurrentConfig {
    /// Builds a configuration with an explicit gate count and `training = true`.
    ///
    /// Shared by the public constructors so the struct literal lives in one place.
    fn with_num_gates(input_size: usize, hidden_size: usize, num_gates: usize, bias: bool) -> Self {
        Self {
            input_size,
            hidden_size,
            num_gates,
            bias,
            training: true,
        }
    }

    /// Configuration for a plain RNN cell (`num_gates = 1`).
    pub fn rnn(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_num_gates(input_size, hidden_size, 1, bias)
    }

    /// Configuration for a GRU cell (`num_gates = 3`).
    pub fn gru(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_num_gates(input_size, hidden_size, 3, bias)
    }

    /// Configuration for an LSTM cell (`num_gates = 4`).
    pub fn lstm(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_num_gates(input_size, hidden_size, 4, bias)
    }
}
72
73pub trait RecurrentCell<T: Float + Send + Sync + Debug + 'static> {
76 fn input_size(&self) -> usize;
79
80 fn hidden_size(&self) -> usize;
83
84 fn set_training(&mut self, training: bool);
87
88 fn is_training(&self) -> bool;
91
92 fn config(&self) -> &RecurrentConfig;
95}
96
/// Stateless namespace type for the recurrent helper functions.
///
/// The type carries no data, so it is trivially printable, copyable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct RecurrentOps;
100
101impl RecurrentOps {
102 pub fn init_weights<
105 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
106 >(
107 input_size: usize,
108 hidden_size: usize,
109 num_gates: usize,
110 ) -> (Variable<T>, Variable<T>) {
111 let mut rng = rand::thread_rng();
112 let normal = Normal::new(0.0, 0.1).unwrap();
113
114 let weight_ih_data: Vec<T> = (0..num_gates * hidden_size * input_size)
116 .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
117 .collect();
118 let weight_ih = Variable::new(
119 Tensor::from_vec(weight_ih_data, vec![num_gates * hidden_size, input_size]),
120 true,
121 );
122
123 let weight_hh_data: Vec<T> = (0..num_gates * hidden_size * hidden_size)
125 .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
126 .collect();
127 let weight_hh = Variable::new(
128 Tensor::from_vec(weight_hh_data, vec![num_gates * hidden_size, hidden_size]),
129 true,
130 );
131
132 (weight_ih, weight_hh)
133 }
134
135 pub fn init_bias<
138 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
139 >(
140 hidden_size: usize,
141 num_gates: usize,
142 ) -> (Option<Variable<T>>, Option<Variable<T>>) {
143 let mut rng = rand::thread_rng();
144 let normal = Normal::new(0.0, 0.1).unwrap();
145
146 let bias_ih_data: Vec<T> = (0..num_gates * hidden_size)
147 .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
148 .collect();
149 let bias_ih = Some(Variable::new(
150 Tensor::from_vec(bias_ih_data, vec![num_gates * hidden_size]),
151 true,
152 ));
153
154 let bias_hh_data: Vec<T> = (0..num_gates * hidden_size)
155 .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
156 .collect();
157 let bias_hh = Some(Variable::new(
158 Tensor::from_vec(bias_hh_data, vec![num_gates * hidden_size]),
159 true,
160 ));
161
162 (bias_ih, bias_hh)
163 }
164
165 pub fn linear_transform<
168 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
169 >(
170 input: &Variable<T>,
171 weight: &Variable<T>,
172 bias: Option<&Variable<T>>,
173 ) -> Variable<T> {
174 let output = Self::matmul_variables(input, &Self::transpose_variable(weight));
175
176 match bias {
177 Some(b) => Self::add_variables(&output, b),
178 None => output,
179 }
180 }
181
182 pub fn matmul_variables<
185 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
186 >(
187 a: &Variable<T>,
188 b: &Variable<T>,
189 ) -> Variable<T> {
190 a.matmul(b)
192 }
193
194 pub fn add_variables<
197 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
198 >(
199 a: &Variable<T>,
200 b: &Variable<T>,
201 ) -> Variable<T> {
202 a + b
204 }
205
206 pub fn multiply_variables<
209 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
210 >(
211 a: &Variable<T>,
212 b: &Variable<T>,
213 ) -> Variable<T> {
214 a * b
216 }
217
218 pub fn subtract_from_scalar<
221 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
222 >(
223 var: &Variable<T>,
224 scalar: T,
225 ) -> Variable<T> {
226 let var_binding = var.data();
227 let var_data = var_binding.read().unwrap();
228 let result_data = var_data.map(|x| scalar - x);
229 Variable::new(result_data, var.requires_grad())
230 }
231
232 pub fn transpose_variable<
235 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
236 >(
237 var: &Variable<T>,
238 ) -> Variable<T> {
239 let var_binding = var.data();
240 let var_data = var_binding.read().unwrap();
241 let transposed_data = var_data.transpose().unwrap();
242 Variable::new(transposed_data, var.requires_grad())
243 }
244
245 pub fn sigmoid<
248 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
249 >(
250 var: &Variable<T>,
251 ) -> Variable<T> {
252 let var_binding = var.data();
253 let var_data = var_binding.read().unwrap();
254 let sigmoid_data = var_data.map(|x| T::one() / (T::one() + (-x).exp()));
255 Variable::new(sigmoid_data, var.requires_grad())
256 }
257
258 pub fn tanh<
261 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
262 >(
263 var: &Variable<T>,
264 ) -> Variable<T> {
265 let var_binding = var.data();
266 let var_data = var_binding.read().unwrap();
267 let tanh_data = var_data.map(|x| x.tanh());
268 Variable::new(tanh_data, var.requires_grad())
269 }
270
271 pub fn slice_gates<
274 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
275 >(
276 gates: &Variable<T>,
277 gate_idx: usize,
278 hidden_size: usize,
279 ) -> Variable<T> {
280 let start_idx = gate_idx * hidden_size;
281 let end_idx = (gate_idx + 1) * hidden_size;
282
283 let gates_binding = gates.data();
285 let gates_data = gates_binding.read().unwrap();
286 let gate_data: Vec<T> = gates_data.as_slice().unwrap()[start_idx..end_idx].to_vec();
287 Variable::new(
288 Tensor::from_vec(gate_data, vec![gates_data.shape()[0], hidden_size]),
289 gates.requires_grad(),
290 )
291 }
292
293 pub fn zero_hidden_state<
296 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
297 >(
298 batch_size: usize,
299 hidden_size: usize,
300 ) -> Variable<T> {
301 Variable::new(Tensor::zeros(&[batch_size, hidden_size]), false)
302 }
303}
304
/// Two-state training switch, convertible to and from `bool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingMode {
    /// Corresponds to `true`.
    Train,
    /// Corresponds to `false`.
    Eval,
}

impl From<bool> for TrainingMode {
    /// Maps `true` to [`TrainingMode::Train`] and `false` to [`TrainingMode::Eval`].
    fn from(training: bool) -> Self {
        match training {
            true => Self::Train,
            false => Self::Eval,
        }
    }
}

impl From<TrainingMode> for bool {
    /// Returns `true` only for [`TrainingMode::Train`].
    fn from(mode: TrainingMode) -> Self {
        mode == TrainingMode::Train
    }
}
332
333pub fn collect_recurrent_parameters<
336 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
337>(
338 weight_ih: &Variable<T>,
339 weight_hh: &Variable<T>,
340 bias_ih: &Option<Variable<T>>,
341 bias_hh: &Option<Variable<T>>,
342) -> Vec<Variable<T>> {
343 let mut params = vec![weight_ih.clone(), weight_hh.clone()];
344
345 if let Some(ref bias) = bias_ih {
346 params.push(bias.clone());
347 }
348
349 if let Some(ref bias) = bias_hh {
350 params.push(bias.clone());
351 }
352
353 params
354}
355
/// Stateless namespace type for the sequence slicing and stacking helpers.
///
/// The type carries no data, so it is trivially printable, copyable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct MultiLayerUtils;
359
360impl MultiLayerUtils {
361 pub fn get_timestep_input<
364 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
365 >(
366 input: &Variable<T>,
367 timestep: usize,
368 ) -> Variable<T> {
369 let input_binding = input.data();
371 let input_data = input_binding.read().unwrap();
372 let batch_size = input_data.shape()[0];
373 let feature_size = input_data.shape()[2];
374
375 let timestep_data: Vec<T> = (0..batch_size * feature_size)
377 .map(|i| {
378 let batch_idx = i / feature_size;
379 let feat_idx = i % feature_size;
380 input_data.as_slice().unwrap()[batch_idx * input_data.shape()[1] * feature_size
381 + timestep * feature_size
382 + feat_idx]
383 })
384 .collect();
385
386 Variable::new(
387 Tensor::from_vec(timestep_data, vec![batch_size, feature_size]),
388 input.requires_grad(),
389 )
390 }
391
392 pub fn stack_outputs<
395 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
396 >(
397 outputs: &[Variable<T>],
398 ) -> Variable<T> {
399 let output_binding = outputs[0].data();
400 let output_data = output_binding.read().unwrap();
401 let batch_size = output_data.shape()[0];
402 let hidden_size = output_data.shape()[1];
403 let seq_len = outputs.len();
404
405 let mut stacked_data = Vec::new();
406
407 for batch_idx in 0..batch_size {
408 for t in 0..seq_len {
409 let output_binding = outputs[t].data();
410 let output_data = output_binding.read().unwrap();
411 let output_slice = output_data.as_slice().unwrap();
412 let start_idx = batch_idx * hidden_size;
413 let end_idx = start_idx + hidden_size;
414 stacked_data.extend_from_slice(&output_slice[start_idx..end_idx]);
415 }
416 }
417
418 Variable::new(
419 Tensor::from_vec(stacked_data, vec![batch_size, seq_len, hidden_size]),
420 outputs[0].requires_grad(),
421 )
422 }
423
424 pub fn stack_hidden_states<
427 T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
428 >(
429 states: &[Variable<T>],
430 num_layers: usize,
431 ) -> Variable<T> {
432 let state_binding = states[0].data();
433 let state_data = state_binding.read().unwrap();
434 let batch_size = state_data.shape()[0];
435 let hidden_size = state_data.shape()[1];
436
437 let mut stacked_data = Vec::new();
438
439 for state in states {
440 let state_binding = state.data();
441 let state_data = state_binding.read().unwrap();
442 stacked_data.extend_from_slice(state_data.as_slice().unwrap());
443 }
444
445 Variable::new(
446 Tensor::from_vec(stacked_data, vec![num_layers, batch_size, hidden_size]),
447 states[0].requires_grad(),
448 )
449 }
450}