use crate::error::{NeuralError, Result};
use crate::layers::recurrent::{GruForwardOutput, GruGateCache};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
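
/// Configuration holding the input and hidden dimensions of a [`GRU`] layer.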
#[derive(Debug, Clone)]
pub struct GRUConfig {
    pub input_size: usize,
    pub hidden_size: usize,
}
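
/// Gated Recurrent Unit (GRU) recurrent layer.
///
/// The forward pass consumes a 3D input of shape `[batch_size, seq_len, input_size]`
/// and returns the hidden state at every time step, shaped
/// `[batch_size, seq_len, hidden_size]`. Weights and biases are grouped by gate:
/// reset (`*_r`), update (`*_z`), and candidate/new (`*_n`).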
pub struct GRU<F: Float + Debug> {
    input_size: usize,
    hidden_size: usize,
    // Reset gate parameters.
    weight_ir: Array<F, IxDyn>,
    weight_hr: Array<F, IxDyn>,
    bias_ir: Array<F, IxDyn>,
    bias_hr: Array<F, IxDyn>,
    // Update gate parameters.
    weight_iz: Array<F, IxDyn>,
    weight_hz: Array<F, IxDyn>,
    bias_iz: Array<F, IxDyn>,
    bias_hz: Array<F, IxDyn>,
    // Candidate (new) gate parameters.
    weight_in: Array<F, IxDyn>,
    weight_hn: Array<F, IxDyn>,
    bias_in: Array<F, IxDyn>,
    bias_hn: Array<F, IxDyn>,
    #[allow(dead_code)]
    gradients: RwLock<Vec<Array<F, IxDyn>>>,
    input_cache: RwLock<Option<Array<F, IxDyn>>>,
    hidden_states_cache: RwLock<Option<Array<F, IxDyn>>>,
    #[allow(dead_code)]
    gate_cache: GruGateCache<F>,
}

impl<F: Float + Debug + ScalarOperand + 'static> GRU<F> {
    /// Creates a new GRU layer with uniformly initialized weights and zero biases.
    pub fn new<R: scirs2_core::random::Rng + scirs2_core::random::RngCore>(
        input_size: usize,
        hidden_size: usize,
        rng: &mut R,
    ) -> Result<Self> {
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "input_size and hidden_size must be positive".to_string(),
            ));
        }
        let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert input scale factor".to_string())
        })?;
        let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert hidden scale factor".to_string())
        })?;

        let mut create_weight_matrix =
            |rows: usize, cols: usize, scale: F| -> Result<Array<F, IxDyn>> {
                let mut weights_vec: Vec<F> = Vec::with_capacity(rows * cols);
                let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
                    NeuralError::InvalidArchitecture(format!(
                        "Failed to create uniform distribution: {e}"
                    ))
                })?;
                for _ in 0..(rows * cols) {
                    let rand_val = uniform.sample(rng);
                    let val = F::from(rand_val).ok_or_else(|| {
                        NeuralError::InvalidArchitecture(
                            "Failed to convert random value".to_string(),
                        )
                    })?;
                    weights_vec.push(val * scale);
                }
                Array::from_shape_vec(IxDyn(&[rows, cols]), weights_vec).map_err(|e| {
                    NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
                })
            };
        let weight_ir = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hr = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_ir: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hr: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let weight_iz = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hz = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_iz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let weight_in = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hn = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_in: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hn: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let gradients = vec![
            Array::zeros(weight_ir.dim()),
            Array::zeros(weight_hr.dim()),
            Array::zeros(bias_ir.dim()),
            Array::zeros(bias_hr.dim()),
            Array::zeros(weight_iz.dim()),
            Array::zeros(weight_hz.dim()),
            Array::zeros(bias_iz.dim()),
            Array::zeros(bias_hz.dim()),
            Array::zeros(weight_in.dim()),
            Array::zeros(weight_hn.dim()),
            Array::zeros(bias_in.dim()),
            Array::zeros(bias_hn.dim()),
        ];
        Ok(Self {
            input_size,
            hidden_size,
            weight_ir,
            weight_hr,
            bias_ir,
            bias_hr,
            weight_iz,
            weight_hz,
            bias_iz,
            bias_hz,
            weight_in,
            weight_hn,
            bias_in,
            bias_hn,
            gradients: RwLock::new(gradients),
            input_cache: RwLock::new(None),
            hidden_states_cache: RwLock::new(None),
            gate_cache: Arc::new(RwLock::new(None)),
        })
    }
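
    /// Computes a single GRU time step for a batch of inputs.
    ///
    /// The loops below implement the standard GRU equations (elementwise `*`):
    ///
    /// ```text
    /// r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
    /// z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
    /// n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
    /// ```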
    fn step(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
    ) -> Result<GruForwardOutput<F>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
        let mut r_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut z_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut n_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        for b in 0..batch_size {
            for i in 0..self.hidden_size {
                // Reset gate: r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
                let mut r_sum = self.bias_ir[i] + self.bias_hr[i];
                for j in 0..self.input_size {
                    r_sum = r_sum + self.weight_ir[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    r_sum = r_sum + self.weight_hr[[i, j]] * h[[b, j]];
                }
                r_gate[[b, i]] = F::one() / (F::one() + (-r_sum).exp());

                // Update gate: z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
                let mut z_sum = self.bias_iz[i] + self.bias_hz[i];
                for j in 0..self.input_size {
                    z_sum = z_sum + self.weight_iz[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    z_sum = z_sum + self.weight_hz[[i, j]] * h[[b, j]];
                }
                z_gate[[b, i]] = F::one() / (F::one() + (-z_sum).exp());

                // Candidate state: n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
                let mut n_sum = self.bias_in[i];
                for j in 0..self.input_size {
                    n_sum = n_sum + self.weight_in[[i, j]] * x[[b, j]];
                }
                let mut hn_sum = self.bias_hn[i];
                for j in 0..self.hidden_size {
                    hn_sum = hn_sum + self.weight_hn[[i, j]] * h[[b, j]];
                }
                n_gate[[b, i]] = (n_sum + r_gate[[b, i]] * hn_sum).tanh();

                // New hidden state: h' = (1 - z) * n + z * h
                new_h[[b, i]] =
                    (F::one() - z_gate[[b, i]]) * n_gate[[b, i]] + z_gate[[b, i]] * h[[b, i]];
            }
        }
        let new_h_dyn = new_h.into_dyn();
        let r_gate_dyn = r_gate.into_dyn();
        let z_gate_dyn = z_gate.into_dyn();
        let n_gate_dyn = n_gate.into_dyn();
        Ok((new_h_dyn, (r_gate_dyn, z_gate_dyn, n_gate_dyn)))
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for GRU<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        *self.input_cache.write().unwrap() = Some(input.clone());
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
        let mut h = Array::zeros((batch_size, self.hidden_size));
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_gates = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            // Slice out the input for time step t: shape [batch_size, features].
            let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
            let x_t_view = x_t.view().into_dyn();
            let h_view = h.view().into_dyn();
            let step_result = self.step(&x_t_view, &h_view)?;
            let new_h = step_result.0;
            let gates = step_result.1;
            h = new_h.into_dimensionality::<Ix2>().unwrap();
            all_gates.push(gates);
            for b in 0..batch_size {
                for i in 0..self.hidden_size {
                    all_hidden_states[[b, t, i]] = h[[b, i]];
                }
            }
        }
        *self.hidden_states_cache.write().unwrap() = Some(all_hidden_states.clone().into_dyn());
        Ok(all_hidden_states.into_dyn())
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = self.input_cache.read().map_err(|_| {
            NeuralError::InferenceError("Failed to acquire read lock on input cache".to_string())
        })?;
        let hidden_states_ref = self.hidden_states_cache.read().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire read lock on hidden states cache".to_string(),
            )
        })?;
        if input_ref.is_none() || hidden_states_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }
        // Backpropagation through time is not implemented yet; return a zero
        // gradient with the same shape as the input.
        let grad_input = Array::zeros(input.dim());
        Ok(grad_input)
    }

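    /// Applies a uniform decrement of `0.001 * learning_rate` to every parameter.
    ///
    /// This is a simple placeholder update rather than a gradient-based step:
    /// the cached gradients are not consulted.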
    fn update(&mut self, learning_rate: F) -> Result<()> {
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learning_rate;
        let update_param = |param: &mut Array<F, IxDyn>| {
            for w in param.iter_mut() {
                *w = *w - lr;
            }
        };
        update_param(&mut self.weight_ir);
        update_param(&mut self.weight_hr);
        update_param(&mut self.bias_ir);
        update_param(&mut self.bias_hr);
        update_param(&mut self.weight_iz);
        update_param(&mut self.weight_hz);
        update_param(&mut self.bias_iz);
        update_param(&mut self.bias_hz);
        update_param(&mut self.weight_in);
        update_param(&mut self.weight_hn);
        update_param(&mut self.bias_in);
        update_param(&mut self.bias_hn);
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for GRU<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![
            self.weight_ir.clone(),
            self.weight_hr.clone(),
            self.bias_ir.clone(),
            self.bias_hr.clone(),
            self.weight_iz.clone(),
            self.weight_hz.clone(),
            self.bias_iz.clone(),
            self.bias_hz.clone(),
            self.weight_in.clone(),
            self.weight_hn.clone(),
            self.bias_in.clone(),
            self.bias_hn.clone(),
        ]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Gradient tracking is not implemented for this layer yet.
        Vec::new()
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 12 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 12 parameters, got {}",
                params.len()
            )));
        }

        let expected_shapes = [
            self.weight_ir.shape(),
            self.weight_hr.shape(),
            self.bias_ir.shape(),
            self.bias_hr.shape(),
            self.weight_iz.shape(),
            self.weight_hz.shape(),
            self.bias_iz.shape(),
            self.bias_hz.shape(),
            self.weight_in.shape(),
            self.weight_hn.shape(),
            self.bias_in.shape(),
            self.bias_hn.shape(),
        ];

        for (i, (param, expected)) in params.iter().zip(expected_shapes.iter()).enumerate() {
            if param.shape() != *expected {
                return Err(NeuralError::InvalidArchitecture(format!(
                    "Parameter {} shape mismatch: expected {:?}, got {:?}",
                    i,
                    expected,
                    param.shape()
                )));
            }
        }

        self.weight_ir = params[0].clone();
        self.weight_hr = params[1].clone();
        self.bias_ir = params[2].clone();
        self.bias_hr = params[3].clone();
        self.weight_iz = params[4].clone();
        self.weight_hz = params[5].clone();
        self.bias_iz = params[6].clone();
        self.bias_hz = params[7].clone();
        self.weight_in = params[8].clone();
        self.weight_hn = params[9].clone();
        self.bias_in = params[10].clone();
        self.bias_hn = params[11].clone();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn test_gru_shape() {
        let mut rng = SmallRng::from_seed([42; 32]);
        let gru = GRU::<f64>::new(10, 20, &mut rng).unwrap();

        let batch_size = 2;
        let seq_len = 5;
        let input_size = 10;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
        let output = gru.forward(&input).unwrap();
        assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
    }
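
    // Sanity check: the initial hidden state is zero and each new hidden value
    // is a convex combination of a tanh activation and the previous state, so
    // every element of the output should stay strictly within (-1, 1).
    #[test]
    fn test_gru_output_bounds() {
        let mut rng = SmallRng::from_seed([7; 32]);
        let gru = GRU::<f64>::new(4, 8, &mut rng).unwrap();

        let input = Array3::<f64>::from_elem((3, 6, 4), 0.5).into_dyn();
        let output = gru.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 6, 8]);
        for &v in output.iter() {
            assert!(v.is_finite());
            assert!(v > -1.0 && v < 1.0);
        }
    }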
}