rustorch/nn/
recurrent_common.rs

1//! Common functionality for recurrent neural networks
2//! リカレントニューラルネットワーク用共通機能
3
4use crate::autograd::Variable;
5use crate::tensor::Tensor;
6use num_traits::Float;
7use rand_distr::{Distribution, Normal};
8use std::fmt::Debug;
9
/// Configuration shared by all recurrent cell variants.
#[derive(Debug, Clone)]
pub struct RecurrentConfig {
    /// Input size.
    pub input_size: usize,

    /// Hidden state size.
    pub hidden_size: usize,

    /// Number of gates (RNN: 1, GRU: 3, LSTM: 4).
    pub num_gates: usize,

    /// Whether to use bias.
    pub bias: bool,

    /// Whether the cell is in training mode.
    pub training: bool,
}

impl RecurrentConfig {
    /// Build a configuration with an explicit gate count.
    ///
    /// Every public constructor funnels through here; a new configuration
    /// always starts with `training` set to `true`.
    fn with_gates(input_size: usize, hidden_size: usize, num_gates: usize, bias: bool) -> Self {
        Self {
            input_size,
            hidden_size,
            num_gates,
            bias,
            training: true,
        }
    }

    /// Create a configuration for a plain RNN cell (one gate).
    pub fn rnn(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_gates(input_size, hidden_size, 1, bias)
    }

    /// Create a configuration for a GRU cell (three gates).
    pub fn gru(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_gates(input_size, hidden_size, 3, bias)
    }

    /// Create a configuration for an LSTM cell (four gates).
    pub fn lstm(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::with_gates(input_size, hidden_size, 4, bias)
    }
}
72
73/// Common trait for recurrent cells
74/// リカレントセル用共通トレイト
75pub trait RecurrentCell<T: Float + Send + Sync + Debug + 'static> {
76    /// Get input size
77    /// 入力サイズを取得
78    fn input_size(&self) -> usize;
79
80    /// Get hidden size
81    /// 隠れ状態サイズを取得
82    fn hidden_size(&self) -> usize;
83
84    /// Set training mode
85    /// 学習モードを設定
86    fn set_training(&mut self, training: bool);
87
88    /// Check if in training mode
89    /// 学習モードかどうかをチェック
90    fn is_training(&self) -> bool;
91
92    /// Get configuration
93    /// 設定を取得
94    fn config(&self) -> &RecurrentConfig;
95}
96
/// Stateless holder for the operations shared by recurrent cells.
///
/// All functionality is exposed as associated functions; the type itself
/// carries no data.
pub struct RecurrentOps;
100
101impl RecurrentOps {
102    /// Initialize weights using Xavier/Glorot initialization
103    /// Xavier/Glorot初期化で重みを初期化
104    pub fn init_weights<
105        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
106    >(
107        input_size: usize,
108        hidden_size: usize,
109        num_gates: usize,
110    ) -> (Variable<T>, Variable<T>) {
111        let mut rng = rand::thread_rng();
112        let normal = Normal::new(0.0, 0.1).unwrap();
113
114        // Input-to-hidden weights
115        let weight_ih_data: Vec<T> = (0..num_gates * hidden_size * input_size)
116            .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
117            .collect();
118        let weight_ih = Variable::new(
119            Tensor::from_vec(weight_ih_data, vec![num_gates * hidden_size, input_size]),
120            true,
121        );
122
123        // Hidden-to-hidden weights
124        let weight_hh_data: Vec<T> = (0..num_gates * hidden_size * hidden_size)
125            .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
126            .collect();
127        let weight_hh = Variable::new(
128            Tensor::from_vec(weight_hh_data, vec![num_gates * hidden_size, hidden_size]),
129            true,
130        );
131
132        (weight_ih, weight_hh)
133    }
134
135    /// Initialize bias
136    /// バイアスを初期化
137    pub fn init_bias<
138        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
139    >(
140        hidden_size: usize,
141        num_gates: usize,
142    ) -> (Option<Variable<T>>, Option<Variable<T>>) {
143        let mut rng = rand::thread_rng();
144        let normal = Normal::new(0.0, 0.1).unwrap();
145
146        let bias_ih_data: Vec<T> = (0..num_gates * hidden_size)
147            .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
148            .collect();
149        let bias_ih = Some(Variable::new(
150            Tensor::from_vec(bias_ih_data, vec![num_gates * hidden_size]),
151            true,
152        ));
153
154        let bias_hh_data: Vec<T> = (0..num_gates * hidden_size)
155            .map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
156            .collect();
157        let bias_hh = Some(Variable::new(
158            Tensor::from_vec(bias_hh_data, vec![num_gates * hidden_size]),
159            true,
160        ));
161
162        (bias_ih, bias_hh)
163    }
164
165    /// Linear transformation: input @ weight^T + bias
166    /// 線形変換: input @ weight^T + bias
167    pub fn linear_transform<
168        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
169    >(
170        input: &Variable<T>,
171        weight: &Variable<T>,
172        bias: Option<&Variable<T>>,
173    ) -> Variable<T> {
174        let output = Self::matmul_variables(input, &Self::transpose_variable(weight));
175
176        match bias {
177            Some(b) => Self::add_variables(&output, b),
178            None => output,
179        }
180    }
181
182    /// Matrix multiplication for variables
183    /// Variable用の行列乗算
184    pub fn matmul_variables<
185        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
186    >(
187        a: &Variable<T>,
188        b: &Variable<T>,
189    ) -> Variable<T> {
190        // Use Variable's matmul method directly
191        a.matmul(b)
192    }
193
194    /// Addition for variables
195    /// Variable用の加算
196    pub fn add_variables<
197        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
198    >(
199        a: &Variable<T>,
200        b: &Variable<T>,
201    ) -> Variable<T> {
202        // Use Variable's add operator directly
203        a + b
204    }
205
206    /// Multiplication for variables
207    /// Variable用の乗算
208    pub fn multiply_variables<
209        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
210    >(
211        a: &Variable<T>,
212        b: &Variable<T>,
213    ) -> Variable<T> {
214        // Use Variable's multiplication operator directly
215        a * b
216    }
217
218    /// Subtract variable from scalar
219    /// スカラーから変数を減算
220    pub fn subtract_from_scalar<
221        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
222    >(
223        var: &Variable<T>,
224        scalar: T,
225    ) -> Variable<T> {
226        let var_binding = var.data();
227        let var_data = var_binding.read().unwrap();
228        let result_data = var_data.map(|x| scalar - x);
229        Variable::new(result_data, var.requires_grad())
230    }
231
232    /// Transpose for variables
233    /// Variable用の転置
234    pub fn transpose_variable<
235        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
236    >(
237        var: &Variable<T>,
238    ) -> Variable<T> {
239        let var_binding = var.data();
240        let var_data = var_binding.read().unwrap();
241        let transposed_data = var_data.transpose().unwrap();
242        Variable::new(transposed_data, var.requires_grad())
243    }
244
245    /// Sigmoid activation for variables
246    /// Variable用のシグモイド活性化
247    pub fn sigmoid<
248        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
249    >(
250        var: &Variable<T>,
251    ) -> Variable<T> {
252        let var_binding = var.data();
253        let var_data = var_binding.read().unwrap();
254        let sigmoid_data = var_data.map(|x| T::one() / (T::one() + (-x).exp()));
255        Variable::new(sigmoid_data, var.requires_grad())
256    }
257
258    /// Tanh activation for variables
259    /// Variable用のtanh活性化
260    pub fn tanh<
261        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
262    >(
263        var: &Variable<T>,
264    ) -> Variable<T> {
265        let var_binding = var.data();
266        let var_data = var_binding.read().unwrap();
267        let tanh_data = var_data.map(|x| x.tanh());
268        Variable::new(tanh_data, var.requires_grad())
269    }
270
271    /// Slice gates from concatenated tensor
272    /// 連結されたテンソルからゲートをスライス
273    pub fn slice_gates<
274        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
275    >(
276        gates: &Variable<T>,
277        gate_idx: usize,
278        hidden_size: usize,
279    ) -> Variable<T> {
280        let start_idx = gate_idx * hidden_size;
281        let end_idx = (gate_idx + 1) * hidden_size;
282
283        // Simplified slicing - in practice would need proper tensor slicing
284        let gates_binding = gates.data();
285        let gates_data = gates_binding.read().unwrap();
286        let gate_data: Vec<T> = gates_data.as_slice().unwrap()[start_idx..end_idx].to_vec();
287        Variable::new(
288            Tensor::from_vec(gate_data, vec![gates_data.shape()[0], hidden_size]),
289            gates.requires_grad(),
290        )
291    }
292
293    /// Create zero hidden state
294    /// ゼロ隠れ状態を作成
295    pub fn zero_hidden_state<
296        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
297    >(
298        batch_size: usize,
299        hidden_size: usize,
300    ) -> Variable<T> {
301        Variable::new(Tensor::zeros(&[batch_size, hidden_size]), false)
302    }
303}
304
/// Whether a module is being trained or evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingMode {
    /// Training mode.
    Train,
    /// Evaluation mode.
    Eval,
}

impl From<bool> for TrainingMode {
    /// Map `true` to [`TrainingMode::Train`] and `false` to [`TrainingMode::Eval`].
    fn from(training: bool) -> Self {
        match training {
            true => TrainingMode::Train,
            false => TrainingMode::Eval,
        }
    }
}

impl From<TrainingMode> for bool {
    /// Return `true` only for [`TrainingMode::Train`].
    fn from(mode: TrainingMode) -> Self {
        // Exhaustive match so that adding a variant forces a decision here.
        match mode {
            TrainingMode::Train => true,
            TrainingMode::Eval => false,
        }
    }
}
332
333/// Common parameter collection for recurrent cells
334/// リカレントセル用共通パラメータ収集
335pub fn collect_recurrent_parameters<
336    T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
337>(
338    weight_ih: &Variable<T>,
339    weight_hh: &Variable<T>,
340    bias_ih: &Option<Variable<T>>,
341    bias_hh: &Option<Variable<T>>,
342) -> Vec<Variable<T>> {
343    let mut params = vec![weight_ih.clone(), weight_hh.clone()];
344
345    if let Some(ref bias) = bias_ih {
346        params.push(bias.clone());
347    }
348
349    if let Some(ref bias) = bias_hh {
350        params.push(bias.clone());
351    }
352
353    params
354}
355
/// Stateless holder for forward-pass helpers used by multi-layer recurrent networks.
///
/// All functionality is exposed as associated functions; the type itself
/// carries no data.
pub struct MultiLayerUtils;
359
360impl MultiLayerUtils {
361    /// Get input for a specific timestep
362    /// 特定のタイムステップの入力を取得
363    pub fn get_timestep_input<
364        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
365    >(
366        input: &Variable<T>,
367        timestep: usize,
368    ) -> Variable<T> {
369        // Simplified implementation - would need proper tensor slicing
370        let input_binding = input.data();
371        let input_data = input_binding.read().unwrap();
372        let batch_size = input_data.shape()[0];
373        let feature_size = input_data.shape()[2];
374
375        // Extract data for this timestep
376        let timestep_data: Vec<T> = (0..batch_size * feature_size)
377            .map(|i| {
378                let batch_idx = i / feature_size;
379                let feat_idx = i % feature_size;
380                input_data.as_slice().unwrap()[batch_idx * input_data.shape()[1] * feature_size
381                    + timestep * feature_size
382                    + feat_idx]
383            })
384            .collect();
385
386        Variable::new(
387            Tensor::from_vec(timestep_data, vec![batch_size, feature_size]),
388            input.requires_grad(),
389        )
390    }
391
392    /// Stack outputs along sequence dimension
393    /// シーケンス次元に沿って出力をスタック
394    pub fn stack_outputs<
395        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
396    >(
397        outputs: &[Variable<T>],
398    ) -> Variable<T> {
399        let output_binding = outputs[0].data();
400        let output_data = output_binding.read().unwrap();
401        let batch_size = output_data.shape()[0];
402        let hidden_size = output_data.shape()[1];
403        let seq_len = outputs.len();
404
405        let mut stacked_data = Vec::new();
406
407        for batch_idx in 0..batch_size {
408            for t in 0..seq_len {
409                let output_binding = outputs[t].data();
410                let output_data = output_binding.read().unwrap();
411                let output_slice = output_data.as_slice().unwrap();
412                let start_idx = batch_idx * hidden_size;
413                let end_idx = start_idx + hidden_size;
414                stacked_data.extend_from_slice(&output_slice[start_idx..end_idx]);
415            }
416        }
417
418        Variable::new(
419            Tensor::from_vec(stacked_data, vec![batch_size, seq_len, hidden_size]),
420            outputs[0].requires_grad(),
421        )
422    }
423
424    /// Stack hidden states by layer
425    /// レイヤーごとに隠れ状態をスタック
426    pub fn stack_hidden_states<
427        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
428    >(
429        states: &[Variable<T>],
430        num_layers: usize,
431    ) -> Variable<T> {
432        let state_binding = states[0].data();
433        let state_data = state_binding.read().unwrap();
434        let batch_size = state_data.shape()[0];
435        let hidden_size = state_data.shape()[1];
436
437        let mut stacked_data = Vec::new();
438
439        for state in states {
440            let state_binding = state.data();
441            let state_data = state_binding.read().unwrap();
442            stacked_data.extend_from_slice(state_data.as_slice().unwrap());
443        }
444
445        Variable::new(
446            Tensor::from_vec(stacked_data, vec![num_layers, batch_size, hidden_size]),
447            states[0].requires_grad(),
448        )
449    }
450}