scirs2_neural/layers/recurrent/bidirectional.rs

//! Bidirectional wrapper for recurrent layers

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{concatenate, Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Bidirectional RNN wrapper for recurrent layers
///
/// This layer wraps a recurrent layer to enable bidirectional processing.
/// It processes the input sequence in both the forward and backward directions
/// and concatenates the results along the feature dimension.
///
/// # Examples
/// ```
/// use scirs2_neural::layers::{Layer, recurrent::{Bidirectional, RNN, rnn::RecurrentActivation}};
/// use scirs2_core::ndarray::Array3;
///
/// // Create RNN layers for the forward and backward directions
/// let mut rng = scirs2_core::random::rng();
/// let forward_rnn = RNN::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();
/// let backward_rnn = RNN::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();
///
/// // Wrap them in a bidirectional layer
/// let birnn = Bidirectional::new(Box::new(forward_rnn), Some(Box::new(backward_rnn)), None).unwrap();
///
/// // Forward pass with a batch of 2 samples, sequence length 5, and 10 features
/// let batch_size = 2;
/// let seq_len = 5;
/// let input_size = 10;
/// let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
/// let output = birnn.forward(&input).unwrap();
///
/// // Output has dimensions [batch_size, seq_len, hidden_size * 2]
/// assert_eq!(output.shape(), &[batch_size, seq_len, 40]);
/// ```
pub struct Bidirectional<F: Float + Debug + Send + Sync> {
    /// Forward direction layer
    forward_layer: Box<dyn Layer<F> + Send + Sync>,
    /// Backward direction layer (if absent, the forward layer is reused for the reverse pass)
    backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
    /// Name for the layer
    name: Option<String>,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Bidirectional<F> {
    /// Create a new bidirectional wrapper
    ///
    /// # Arguments
    /// * `forward_layer` - The recurrent layer to use in the forward direction
    /// * `backward_layer` - Optional recurrent layer for the backward direction (if `None`, the forward layer will be reused)
    /// * `name` - Optional name for the layer
    ///
    /// # Returns
    /// * A new bidirectional layer
    pub fn new(
        forward_layer: Box<dyn Layer<F> + Send + Sync>,
        backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
        name: Option<&str>,
    ) -> Result<Self> {
        Ok(Self {
            forward_layer,
            backward_layer,
            name: name.map(String::from),
            input_cache: Arc::new(RwLock::new(None)),
        })
    }

    /// Create a new bidirectional wrapper with a single layer.
    /// This constructor is provided for backward compatibility.
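    ///
    /// # Examples
    /// A minimal sketch mirroring the struct-level example above, assuming the same
    /// `RNN` constructor; the single layer is reused for both directions:
    /// ```
    /// use scirs2_neural::layers::{Layer, recurrent::{Bidirectional, RNN, rnn::RecurrentActivation}};
    /// use scirs2_core::ndarray::Array3;
    ///
    /// let mut rng = scirs2_core::random::rng();
    /// let rnn = RNN::new(10, 20, RecurrentActivation::Tanh, &mut rng).unwrap();
    /// let birnn = Bidirectional::new_with_single_layer(Box::new(rnn), None).unwrap();
    ///
    /// // The reversed sequence is run through the same layer, so the output
    /// // still has hidden_size * 2 features.
    /// let input = Array3::<f64>::from_elem((2, 5, 10), 0.1).into_dyn();
    /// let output = birnn.forward(&input).unwrap();
    /// assert_eq!(output.shape(), &[2, 5, 40]);
    /// ```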
    pub fn new_with_single_layer(
        layer: Box<dyn Layer<F> + Send + Sync>,
        name: Option<&str>,
    ) -> Result<Self> {
        Self::new(layer, None, name)
    }

    /// Get the name of the layer
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Bidirectional<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for the backward pass
        *self.input_cache.write().unwrap() = Some(input.clone());

        // Check input dimensions
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, input_size], got {input_shape:?}"
            )));
        }
        let _batch_size = input_shape[0];
        let seq_len = input_shape[1];

        // Forward direction
        let forward_output = self.forward_layer.forward(input)?;

        // If no backward layer is provided, reuse the forward layer for the reverse
        // direction. Trait objects cannot be cloned, so the same layer simply processes
        // the reversed sequence to obtain bidirectional behavior.
        if self.backward_layer.is_none() {
            // Reverse the input along the time dimension
            let mut reversed_slices = Vec::new();
            for t in (0..seq_len).rev() {
                let slice = input.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
                reversed_slices.push(slice);
            }
            let views: Vec<_> = reversed_slices.iter().map(|s| s.view()).collect();
            let reversed_input = concatenate(Axis(1), &views)?.into_dyn();

            // Process through the same forward layer
            let backward_output = self.forward_layer.forward(&reversed_input)?;

            // Reverse the backward output to align with the forward output
            let mut backward_reversed_slices = Vec::new();
            for t in (0..seq_len).rev() {
                let slice = backward_output.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
                backward_reversed_slices.push(slice);
            }
            let backward_views: Vec<_> =
                backward_reversed_slices.iter().map(|s| s.view()).collect();
            let backward_output_aligned = concatenate(Axis(1), &backward_views)?.into_dyn();

            // Concatenate forward and backward outputs along the feature dimension
            let forward_view = forward_output.view();
            let backward_view = backward_output_aligned.view();
            let output = concatenate(Axis(2), &[forward_view, backward_view])?.into_dyn();
            return Ok(output);
        }

        // Process the backward direction
        let backward_layer = self.backward_layer.as_ref().unwrap();

        // Reverse the sequence dimension of the input:
        // create views for each time step and reverse their order
        let mut reversed_slices = Vec::new();
        for t in (0..seq_len).rev() {
            let slice = input.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
            reversed_slices.push(slice);
        }

        // Concatenate the reversed slices along the time dimension
        let views: Vec<_> = reversed_slices.iter().map(|s| s.view()).collect();
        let reversed_input = concatenate(Axis(1), &views)?.into_dyn();

        // Process through the backward layer
        let backward_output = backward_layer.forward(&reversed_input)?;

        // Reverse the backward output to align with the forward output
        let mut backward_reversed_slices = Vec::new();
        for t in (0..seq_len).rev() {
            let slice = backward_output.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
            backward_reversed_slices.push(slice);
        }
        let backward_views: Vec<_> = backward_reversed_slices.iter().map(|s| s.view()).collect();
        let backward_output_aligned = concatenate(Axis(1), &backward_views)?.into_dyn();

        // Concatenate forward and backward outputs along the feature dimension
        let forward_view = forward_output.view();
        let backward_view = backward_output_aligned.view();
        let output = concatenate(Axis(2), &[forward_view, backward_view])?.into_dyn();
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Retrieve the cached input
        let input_ref = self.input_cache.read().unwrap();
        if input_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached input for backward pass. Call forward() first.".to_string(),
            ));
        }
        let cached_input = input_ref.as_ref().unwrap();

        // Check gradient dimensions
        let grad_shape = grad_output.shape();
        if grad_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D gradient [batch_size, seq_len, hidden_size*2], got {grad_shape:?}"
            )));
        }
        let _batch_size = grad_shape[0];
        let seq_len = grad_shape[1];
        let total_hidden = grad_shape[2];

        // If no backward layer is present, gradients for both directions flow through
        // the same forward layer
        if self.backward_layer.is_none() {
            // Split the gradient into forward and backward components
            let hidden_size = total_hidden / 2;
            let grad_forward = grad_output
                .slice(scirs2_core::ndarray::s![.., .., ..hidden_size])
                .to_owned()
                .into_dyn();
            let grad_backward = grad_output
                .slice(scirs2_core::ndarray::s![.., .., hidden_size..])
                .to_owned()
                .into_dyn();

            // Backward pass through the forward layer with the forward gradient
            let grad_input_forward = self.forward_layer.backward(cached_input, &grad_forward)?;

            // For the backward gradient, reverse it first, then compute the backward pass
            let mut backward_grad_slices = Vec::new();
            for t in (0..seq_len).rev() {
                let slice = grad_backward.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
                backward_grad_slices.push(slice);
            }
            let backward_grad_views: Vec<_> =
                backward_grad_slices.iter().map(|s| s.view()).collect();
            let grad_backward_reversed = concatenate(Axis(1), &backward_grad_views)?.into_dyn();

            // Reverse the input for backward processing
            let mut input_slices = Vec::new();
            for t in (0..seq_len).rev() {
                let slice = cached_input.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
                input_slices.push(slice);
            }
            let input_views: Vec<_> = input_slices.iter().map(|s| s.view()).collect();
            let input_reversed = concatenate(Axis(1), &input_views)?.into_dyn();

            // Backward pass through the same forward layer
            let grad_input_backward_reversed = self
                .forward_layer
                .backward(&input_reversed, &grad_backward_reversed)?;

            // Reverse the backward gradient back to the original order
            let mut final_backward_slices = Vec::new();
            for t in (0..seq_len).rev() {
                let slice =
                    grad_input_backward_reversed.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
                final_backward_slices.push(slice);
            }
            let final_backward_views: Vec<_> =
                final_backward_slices.iter().map(|s| s.view()).collect();
            let grad_input_backward = concatenate(Axis(1), &final_backward_views)?.into_dyn();

            // Sum the gradients from the forward and backward paths
            let grad_input = grad_input_forward + grad_input_backward;
            return Ok(grad_input);
        }

        // Get the backward layer
        let backward_layer = self.backward_layer.as_ref().unwrap();

        // Split the gradient into forward and backward components
        let hidden_size = total_hidden / 2;
        let grad_forward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., ..hidden_size])
            .to_owned()
            .into_dyn();
        let grad_backward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., hidden_size..])
            .to_owned()
            .into_dyn();

        // Backward pass through the forward layer
        let grad_input_forward = self.forward_layer.backward(cached_input, &grad_forward)?;

        // For the backward layer, reverse both the gradient and the input
        let mut backward_grad_slices = Vec::new();
        for t in (0..seq_len).rev() {
            let slice = grad_backward.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
            backward_grad_slices.push(slice);
        }
        let backward_grad_views: Vec<_> = backward_grad_slices.iter().map(|s| s.view()).collect();
        let grad_backward_reversed = concatenate(Axis(1), &backward_grad_views)?.into_dyn();

        // Reverse the input for the backward layer
        let mut input_slices = Vec::new();
        for t in (0..seq_len).rev() {
            let slice = cached_input.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
            input_slices.push(slice);
        }
        let input_views: Vec<_> = input_slices.iter().map(|s| s.view()).collect();
        let input_reversed = concatenate(Axis(1), &input_views)?.into_dyn();

        // Backward pass through the backward layer
        let grad_input_backward_reversed =
            backward_layer.backward(&input_reversed, &grad_backward_reversed)?;

        // Reverse the backward gradient back to the original order
        let mut final_backward_slices = Vec::new();
        for t in (0..seq_len).rev() {
            let slice =
                grad_input_backward_reversed.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]);
            final_backward_slices.push(slice);
        }
        let final_backward_views: Vec<_> = final_backward_slices.iter().map(|s| s.view()).collect();
        let grad_input_backward = concatenate(Axis(1), &final_backward_views)?.into_dyn();

        // Sum the gradients from the forward and backward paths
        let grad_input = grad_input_forward + grad_input_backward;
        Ok(grad_input)
    }

    fn update(&mut self, learningrate: F) -> Result<()> {
        // Update forward layer
        self.forward_layer.update(learningrate)?;

        // Update backward layer if present
        if let Some(ref mut backward_layer) = self.backward_layer {
            backward_layer.update(learningrate)?;
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}