scirs2_neural/layers/recurrent/bidirectional.rs

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{concatenate, Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

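/// Bidirectional wrapper for recurrent layers.
///
/// Runs the wrapped layer over the input sequence in both time directions
/// and concatenates the two hidden-state sequences along the feature axis,
/// so the output feature dimension is twice the wrapped layer's hidden size.
/// If no dedicated backward layer is supplied, the forward layer's weights
/// are shared across both directions.
///
/// # Example
///
/// A minimal sketch; `LSTM` and its constructor arguments are assumed here
/// for illustration and may differ from the recurrent layers actually
/// provided by this crate:
///
/// ```ignore
/// // input: [batch_size, seq_len, 64] -> output: [batch_size, seq_len, 256]
/// let lstm = Box::new(LSTM::new(64, 128)?);
/// let bilstm = Bidirectional::new_with_single_layer(lstm, Some("bilstm"))?;
/// let output = bilstm.forward(&input)?;
/// ```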
pub struct Bidirectional<F: Float + Debug + Send + Sync> {
    forward_layer: Box<dyn Layer<F> + Send + Sync>,
    backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
    name: Option<String>,
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Bidirectional<F> {
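    /// Creates a bidirectional wrapper from a forward layer and an optional
    /// dedicated backward layer. When `backward_layer` is `None`, the
    /// forward layer's weights are shared for the reversed direction.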
    pub fn new(
        forward_layer: Box<dyn Layer<F> + Send + Sync>,
        backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
        name: Option<&str>,
    ) -> Result<Self> {
        Ok(Self {
            forward_layer,
            backward_layer,
            name: name.map(String::from),
            input_cache: Arc::new(RwLock::new(None)),
        })
    }

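    /// Convenience constructor: wraps a single layer whose weights are
    /// shared between the forward and backward directions.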
    pub fn new_with_single_layer(
        layer: Box<dyn Layer<F> + Send + Sync>,
        name: Option<&str>,
    ) -> Result<Self> {
        Self::new(layer, None, name)
    }

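    /// Returns the layer's name, if one was set.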
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Returns a copy of `x` with the time axis (axis 1) reversed. Used to
    /// feed the wrapped layer the sequence in reverse order and to re-align
    /// its output (or gradient) with the original time order.
    fn reverse_time(x: &Array<F, IxDyn>) -> Array<F, IxDyn> {
        x.slice(scirs2_core::ndarray::s![.., ..;-1, ..]).to_owned()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Bidirectional<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input so backward() can replay both directions.
        *self.input_cache.write().unwrap() = Some(input.clone());

        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, input_size], got {input_shape:?}"
            )));
        }

        // Forward direction: process the sequence in its natural order.
        let forward_output = self.forward_layer.forward(input)?;

        // Backward direction: fall back to the forward layer (shared
        // weights) when no dedicated backward layer was provided.
        let backward_layer: &(dyn Layer<F> + Send + Sync) = self
            .backward_layer
            .as_deref()
            .unwrap_or(self.forward_layer.as_ref());

        // Run the reversed sequence through the backward layer, then
        // reverse its output so both directions are aligned in time.
        let reversed_input = Self::reverse_time(input);
        let backward_output = backward_layer.forward(&reversed_input)?;
        let backward_output_aligned = Self::reverse_time(&backward_output);

        // Concatenate the two directions along the feature axis:
        // [batch_size, seq_len, hidden_size * 2].
        let output = concatenate(
            Axis(2),
            &[forward_output.view(), backward_output_aligned.view()],
        )?
        .into_dyn();
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = self.input_cache.read().unwrap();
        let cached_input = input_ref.as_ref().ok_or_else(|| {
            NeuralError::InferenceError(
                "No cached input for backward pass. Call forward() first.".to_string(),
            )
        })?;

        let grad_shape = grad_output.shape();
        if grad_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D gradient [batch_size, seq_len, hidden_size*2], got {grad_shape:?}"
            )));
        }
        let total_hidden = grad_shape[2];

        // Split the incoming gradient into the forward and backward halves
        // of the concatenated feature axis.
        let hidden_size = total_hidden / 2;
        let grad_forward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., ..hidden_size])
            .to_owned()
            .into_dyn();
        let grad_backward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., hidden_size..])
            .to_owned()
            .into_dyn();

        // Forward direction: backpropagate through the forward layer.
        let grad_input_forward = self.forward_layer.backward(cached_input, &grad_forward)?;

        // Backward direction: reverse the gradient and the cached input to
        // match the reversed sequence seen in forward(), backpropagate, and
        // reverse the result back into the original time order.
        let backward_layer: &(dyn Layer<F> + Send + Sync) = self
            .backward_layer
            .as_deref()
            .unwrap_or(self.forward_layer.as_ref());
        let grad_backward_reversed = Self::reverse_time(&grad_backward);
        let input_reversed = Self::reverse_time(cached_input);
        let grad_input_backward_reversed =
            backward_layer.backward(&input_reversed, &grad_backward_reversed)?;
        let grad_input_backward = Self::reverse_time(&grad_input_backward_reversed);

        // Both directions consumed the same input, so their gradients sum.
        Ok(grad_input_forward + grad_input_backward)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.forward_layer.update(learning_rate)?;

        if let Some(ref mut backward_layer) = self.backward_layer {
            backward_layer.update(learning_rate)?;
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
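
#[cfg(test)]
mod tests {
    // A minimal smoke-test sketch. It assumes the `Layer` trait requires no
    // methods beyond those implemented on `Bidirectional` above, and that
    // `scirs2_core::ndarray` re-exports ndarray's `IxDyn` constructor; if
    // either assumption is off, the mock below needs adjusting.
    use super::*;

    /// Hypothetical pass-through layer, used only to exercise the wrapper.
    struct Identity;

    impl Layer<f64> for Identity {
        fn forward(&self, input: &Array<f64, IxDyn>) -> Result<Array<f64, IxDyn>> {
            Ok(input.clone())
        }

        fn backward(
            &self,
            _input: &Array<f64, IxDyn>,
            grad_output: &Array<f64, IxDyn>,
        ) -> Result<Array<f64, IxDyn>> {
            Ok(grad_output.clone())
        }

        fn update(&mut self, _learning_rate: f64) -> Result<()> {
            Ok(())
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
            self
        }
    }

    #[test]
    fn output_feature_dim_is_doubled() {
        let bidir = Bidirectional::new_with_single_layer(Box::new(Identity), None).unwrap();
        // [batch_size = 2, seq_len = 5, input_size = 3]
        let input = Array::<f64, IxDyn>::zeros(IxDyn(&[2, 5, 3]));
        let output = bidir.forward(&input).unwrap();
        // Forward and backward outputs are concatenated on the feature axis.
        assert_eq!(output.shape(), &[2, 5, 6]);
    }
}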