1use crate::error::{NeuralError, Result};
13use crate::layers::recurrent::rnn::{RNNConfig, RecurrentActivation};
14use crate::layers::rnn_thread_safe::{
15 RecurrentActivation as ThreadSafeRecurrentActivation, ThreadSafeBidirectional, ThreadSafeRNN,
16};
17use crate::layers::{Dense, Dropout, Embedding, EmbeddingConfig, Layer};
18use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
19use scirs2_core::numeric::{Float, NumAssign};
20use scirs2_core::random::SeedableRng;
21type EncoderOutput<F> = (Array<F, IxDyn>, Vec<Array<F, IxDyn>>);
23type AttentionOutput<F> = (Array<F, IxDyn>, Vec<Array<F, IxDyn>>);
25use serde::{Deserialize, Serialize};
26use std::fmt::Debug;
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum RNNCellType {
30 SimpleRNN,
32 LSTM,
34 GRU,
36}
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Seq2SeqConfig {
40 pub input_vocab_size: usize,
42 pub output_vocab_size: usize,
44 pub embedding_dim: usize,
46 pub hidden_dim: usize,
48 pub num_layers: usize,
50 pub encoder_cell_type: RNNCellType,
52 pub decoder_cell_type: RNNCellType,
54 pub bidirectional_encoder: bool,
56 pub use_attention: bool,
58 pub dropout_rate: f64,
60 pub max_seq_len: usize,
62}
63
64impl Default for Seq2SeqConfig {
65 fn default() -> Self {
66 Self {
67 input_vocab_size: 10000,
68 output_vocab_size: 10000,
69 embedding_dim: 256,
70 hidden_dim: 512,
71 num_layers: 2,
72 encoder_cell_type: RNNCellType::LSTM,
73 decoder_cell_type: RNNCellType::LSTM,
74 bidirectional_encoder: true,
75 use_attention: true,
76 dropout_rate: 0.1,
77 max_seq_len: 100,
78 }
79 }
80}
81
82#[derive(Debug, Clone)]
84pub struct Attention<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
85 pub decoder_projection: Dense<F>,
87 pub encoder_projection: Option<Dense<F>>,
89 pub combined_projection: Dense<F>,
91 pub output_projection: Dense<F>,
93 pub attention_type: AttentionType,
95 pub bidirectional_encoder: bool,
97}
98
/// Scoring scheme used by [`Attention`].
///
/// The enum is fieldless, so it derives `Copy` and `Eq` in addition to the
/// original `Debug`, `Clone` and `PartialEq`, matching `RNNCellType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionType {
    /// Scores come from a learned projection of
    /// `tanh(decoder_projection + encoder_projection)`.
    Additive,
    /// Scores are dot products between the projected decoder state and the
    /// encoder outputs.
    Multiplicative,
    /// Scores are dot products between the projected decoder state and the
    /// encoder outputs (currently computed the same way as `Multiplicative`).
    General,
}
109
110impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Attention<F> {
111 pub fn new(
113 decoder_dim: usize,
114 encoder_dim: usize,
115 attention_dim: usize,
116 attention_type: AttentionType,
117 bidirectional_encoder: bool,
118 ) -> Result<Self> {
119 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
121 let decoder_projection = Dense::<F>::new(decoder_dim, attention_dim, None, &mut rng)?;
123 let encoder_projection = if attention_type == AttentionType::Additive {
125 Some(Dense::<F>::new(encoder_dim, attention_dim, None, &mut rng)?)
126 } else {
127 None
128 };
129 let combined_dim = match attention_type {
131 AttentionType::Additive => attention_dim,
132 AttentionType::Multiplicative => 1,
133 AttentionType::General => encoder_dim,
134 };
135
136 let combined_projection = Dense::<F>::new(combined_dim, 1, None, &mut rng)?;
137 let output_projection =
139 Dense::<F>::new(encoder_dim + decoder_dim, decoder_dim, None, &mut rng)?;
140 Ok(Self {
141 decoder_projection,
142 encoder_projection,
143 combined_projection,
144 output_projection,
145 attention_type,
146 bidirectional_encoder,
147 })
148 }
149
150 pub fn forward(
152 &self,
153 decoder_state: &Array<F, IxDyn>,
154 encoder_outputs: &Array<F, IxDyn>,
155 ) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
156 let batch_size = decoder_state.shape()[0];
158 let seq_len = encoder_outputs.shape()[1];
159 let encoder_dim = encoder_outputs.shape()[2];
160 let decoder_projected = self.decoder_projection.forward(decoder_state)?;
162 let attention_scores = match self.attention_type {
164 AttentionType::Additive => {
165 let encoder_projected = if let Some(ref proj) = self.encoder_projection {
167 let flat_encoder = encoder_outputs
169 .to_owned()
170 .into_shape_with_order((batch_size * seq_len, encoder_dim))?;
171 let projected = proj.forward(&flat_encoder.into_dyn())?;
172 let projshape = projected.shape()[1];
173 projected
174 .into_shape_with_order((batch_size, seq_len, projshape))?
175 .into_dyn()
176 } else {
177 return Err(NeuralError::InferenceError(
178 "Encoder projection missing for additive attention".to_string(),
179 ));
180 };
181 let expanded_decoder = decoder_projected.to_owned().into_shape_with_order((
183 batch_size,
184 1,
185 decoder_projected.shape()[1],
186 ))?;
187 let expanded = expanded_decoder
188 .broadcast((batch_size, seq_len, expanded_decoder.shape()[2]))
189 .expect("Operation failed");
190 let combined = &expanded + &encoder_projected;
192 let tanh = combined.mapv(|x| x.tanh());
194 let flat_tanh = tanh
195 .to_owned()
196 .into_shape_with_order((batch_size * seq_len, tanh.shape()[2]))?;
197 let scores = self.combined_projection.forward(&flat_tanh.into_dyn())?;
198 scores
199 .into_shape_with_order((batch_size, seq_len))?
200 .into_dyn()
201 }
202 AttentionType::Multiplicative => {
203 let expanded_decoder = decoder_projected.to_owned().into_shape_with_order((
205 batch_size,
206 1,
207 decoder_projected.shape()[1],
208 ))?;
209 let mut scores = Array::<F, _>::zeros((batch_size, seq_len));
211 for b in 0..batch_size {
212 let decoder_slice = expanded_decoder.slice(scirs2_core::ndarray::s![b, 0, ..]);
213 for s in 0..seq_len {
214 let encoder_slice =
215 encoder_outputs.slice(scirs2_core::ndarray::s![b, s, ..]);
216 let mut dot_product = F::zero();
218 for i in 0..decoder_slice.len() {
219 dot_product += decoder_slice[i] * encoder_slice[i];
220 }
221 scores[[b, s]] = dot_product;
222 }
223 }
224 scores.into_dyn()
225 }
226 AttentionType::General => {
227 let weight_matrix = decoder_projected.to_owned();
229 let mut scores = Array::<F, _>::zeros((batch_size, seq_len));
231 for b in 0..batch_size {
232 let weight = weight_matrix.slice(scirs2_core::ndarray::s![b, ..]);
233 for s in 0..seq_len {
234 let encoder_slice =
235 encoder_outputs.slice(scirs2_core::ndarray::s![b, s, ..]);
236 let mut dot_product = F::zero();
238 for i in 0..weight.len() {
239 dot_product += weight[i] * encoder_slice[i];
240 }
241 scores[[b, s]] = dot_product;
242 }
243 }
244 scores.into_dyn()
245 }
246 };
247 let mut attention_weights = Array::<F, IxDyn>::zeros(attention_scores.raw_dim());
249 for b in 0..batch_size {
251 let mut row = attention_scores
252 .slice(scirs2_core::ndarray::s![b, ..])
253 .to_owned();
254 let max_val = row.fold(F::neg_infinity(), |m, &v| m.max(v));
256 let mut exp_sum = F::zero();
258 for i in 0..seq_len {
259 let exp_val = (row[i] - max_val).exp();
260 row[i] = exp_val;
261 exp_sum += exp_val;
262 }
263
264 if exp_sum > F::zero() {
266 for i in 0..seq_len {
267 row[i] /= exp_sum;
268 }
269 }
270
271 for i in 0..seq_len {
273 attention_weights[[b, i]] = row[i];
274 }
275 }
276 let attention_weights_expanded = attention_weights
278 .to_owned()
279 .into_shape_with_order((batch_size, seq_len, 1))?;
280 let broadcast_weights = attention_weights_expanded
281 .broadcast((batch_size, seq_len, encoder_dim))
282 .expect("Operation failed");
283 let weighted_encoder = encoder_outputs * &broadcast_weights;
285 let context = weighted_encoder.sum_axis(Axis(1));
286 let decoder_state_dyn = decoder_state.to_owned().into_dyn();
288 let decoder_and_context =
289 scirs2_core::ndarray::stack(Axis(1), &[context.view(), decoder_state_dyn.view()])?;
290 let flattened = decoder_and_context
291 .into_shape_with_order((batch_size, context.shape()[1] + decoder_state.shape()[1]))?;
292 let flattened_dyn = flattened.to_owned().into_dyn();
294 let output = self.output_projection.forward(&flattened_dyn)?;
295 Ok((output, attention_weights))
296 }
297}
298
299impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Attention<F> {
300 fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
301 Err(NeuralError::InvalidArchitecture("Attention layer requires separate decoder state and encoder outputs. Use the dedicated forward method.".to_string()))
302 }
303
304 fn backward(
305 &self,
306 _input: &Array<F, IxDyn>,
307 grad_output: &Array<F, IxDyn>,
308 ) -> Result<Array<F, IxDyn>> {
309 let grad_input = Array::<F, IxDyn>::zeros(grad_output.dim());
319 Ok(grad_input)
320 }
321
322 fn update(&mut self, learning_rate: F) -> Result<()> {
323 self.decoder_projection.update(learning_rate)?;
326 if let Some(ref mut proj) = self.encoder_projection {
328 proj.update(learning_rate)?;
329 }
330
331 self.combined_projection.update(learning_rate)?;
333 self.output_projection.update(learning_rate)?;
335 Ok(())
336 }
337
338 fn as_any(&self) -> &dyn std::any::Any {
339 self
340 }
341
342 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
343 self
344 }
345
346 fn params(&self) -> Vec<Array<F, IxDyn>> {
347 let mut params = Vec::new();
348 params.extend(self.decoder_projection.params());
349 if let Some(ref proj) = self.encoder_projection {
350 params.extend(proj.params());
351 }
352 params.extend(self.combined_projection.params());
353 params.extend(self.output_projection.params());
354 params
355 }
356
357 fn set_training(&mut self, training: bool) {
358 self.decoder_projection.set_training(training);
359 if let Some(ref mut proj) = self.encoder_projection {
360 proj.set_training(training);
361 }
362 self.combined_projection.set_training(training);
363 self.output_projection.set_training(training);
364 }
365
366 fn is_training(&self) -> bool {
367 self.decoder_projection.is_training()
368 }
369}
370pub struct Seq2SeqEncoder<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
372 pub embedding: Embedding<F>,
374 pub rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
376 pub dropout: Option<Dropout<F>>,
378 pub bidirectional: bool,
380 pub cell_type: RNNCellType,
382 pub hidden_dim: usize,
384 pub num_layers: usize,
386}
387
388impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for Seq2SeqEncoder<F> {
389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 f.debug_struct("Seq2SeqEncoder")
391 .field("embedding", &self.embedding)
392 .field(
393 "rnn_layers",
394 &format!("Vec<Box<dyn Layer>> (len: {})", self.rnn_layers.len()),
395 )
396 .field("dropout", &self.dropout)
397 .field("bidirectional", &self.bidirectional)
398 .field("cell_type", &self.cell_type)
399 .field("hidden_dim", &self.hidden_dim)
400 .field("num_layers", &self.num_layers)
401 .finish()
402 }
403}
404
405impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2SeqEncoder<F> {
406 pub fn new(
408 vocab_size: usize,
409 embedding_dim: usize,
410 hidden_dim: usize,
411 num_layers: usize,
412 cell_type: RNNCellType,
413 bidirectional: bool,
414 dropout_rate: Option<f64>,
415 ) -> Result<Self> {
416 let embedding_config = EmbeddingConfig {
418 num_embeddings: vocab_size,
419 embedding_dim,
420 padding_idx: None,
421 max_norm: None,
422 norm_type: 2.0,
423 scale_grad_by_freq: false,
424 };
425 let embedding = Embedding::<F>::new(embedding_config)?;
426 let mut rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::with_capacity(num_layers);
428 for i in 0..num_layers {
429 let input_size = if i == 0 {
430 embedding_dim
431 } else if bidirectional && i > 0 {
432 hidden_dim * 2
433 } else {
434 hidden_dim
435 };
436 let rnn: Box<dyn Layer<F> + Send + Sync> = match cell_type {
438 RNNCellType::SimpleRNN => {
439 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
440 let config = RNNConfig {
441 input_size,
442 hidden_size: hidden_dim,
443 activation: RecurrentActivation::Tanh,
444 };
445 let rnn = ThreadSafeRNN::<F>::new(
447 config.input_size,
448 config.hidden_size,
449 ThreadSafeRecurrentActivation::Tanh, &mut rng,
451 )?;
452 if bidirectional {
453 let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
455 Box::new(brnn)
456 } else {
457 Box::new(rnn)
458 }
459 }
460 RNNCellType::LSTM => {
461 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
465 let rnn = ThreadSafeRNN::<F>::new(
466 input_size,
467 hidden_dim,
468 ThreadSafeRecurrentActivation::Tanh,
469 &mut rng,
470 )?;
471 if bidirectional {
472 let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
474 Box::new(brnn)
475 } else {
476 Box::new(rnn)
477 }
478 }
479 RNNCellType::GRU => {
480 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
483 let rnn = ThreadSafeRNN::<F>::new(
484 input_size,
485 hidden_dim,
486 ThreadSafeRecurrentActivation::Tanh,
487 &mut rng,
488 )?;
489 if bidirectional {
490 let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
491 Box::new(brnn)
492 } else {
493 Box::new(rnn)
494 }
495 }
496 };
497
498 rnn_layers.push(rnn);
499 }
500
501 let dropout = if let Some(rate) = dropout_rate {
503 if rate > 0.0 {
504 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
505 Some(Dropout::<F>::new(rate, &mut rng)?)
506 } else {
507 None
508 }
509 } else {
510 None
511 };
512
513 Ok(Self {
514 embedding,
515 rnn_layers,
516 dropout,
517 bidirectional,
518 cell_type,
519 hidden_dim,
520 num_layers,
521 })
522 }
523 pub fn forward(&self, input_seq: &Array<F, IxDyn>) -> Result<EncoderOutput<F>> {
525 let mut x = self.embedding.forward(input_seq)?;
527 if let Some(ref dropout) = self.dropout {
529 x = dropout.forward(&x)?;
530 }
531
532 let mut states = Vec::new();
534 for layer in &self.rnn_layers {
535 let output = layer.forward(&x)?;
537 if self.bidirectional {
539 let sequences = output
541 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
542 .into_shape_with_order((
543 output.shape()[0],
544 output.shape()[2],
545 output.shape()[3],
546 ))?
547 .to_owned(); let state = output
549 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
550 .into_shape_with_order((output.shape()[0], output.shape()[3]))?
551 .to_owned();
552 x = sequences.into_dyn();
553 states.push(state.into_dyn());
554 } else {
555 let sequences = output
557 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
558 .into_shape_with_order((
559 output.shape()[0],
560 output.shape()[2],
561 output.shape()[3],
562 ))?
563 .to_owned();
564 let state = output
565 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
566 .into_shape_with_order((output.shape()[0], output.shape()[3]))?
567 .to_owned();
568 x = sequences.into_dyn();
569 states.push(state.into_dyn());
570 }
571 }
572
573 Ok((x, states))
574 }
575}
576impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2SeqEncoder<F> {
577 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
578 let (output, _) = self.forward(input)?;
581 Ok(output)
582 }
583
584 fn backward(
585 &self,
586 input: &Array<F, IxDyn>,
587 grad_output: &Array<F, IxDyn>,
588 ) -> Result<Array<F, IxDyn>> {
589 let mut grad = grad_output.clone();
596 for layer in self.rnn_layers.iter().rev() {
598 grad = layer.backward(&grad, &grad)?;
599 }
600 let grad_input = self.embedding.backward(input, &grad)?;
604 Ok(grad_input)
605 }
606
607 fn update(&mut self, learning_rate: F) -> Result<()> {
608 self.embedding.update(learning_rate)?;
610 for layer in &mut self.rnn_layers {
612 layer.update(learning_rate)?;
613 }
614 Ok(())
616 }
617
618 fn as_any(&self) -> &dyn std::any::Any {
619 self
620 }
621
622 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
623 self
624 }
625
626 fn params(&self) -> Vec<Array<F, IxDyn>> {
627 let mut params = Vec::new();
628 params.extend(self.embedding.params());
629 for layer in &self.rnn_layers {
630 params.extend(layer.params());
631 }
632 if let Some(ref dropout) = self.dropout {
633 params.extend(dropout.params());
634 }
635 params
636 }
637
638 fn set_training(&mut self, training: bool) {
639 self.embedding.set_training(training);
640 for layer in &mut self.rnn_layers {
641 layer.set_training(training);
642 }
643 if let Some(ref mut dropout) = self.dropout {
644 dropout.set_training(training);
645 }
646 }
647
648 fn is_training(&self) -> bool {
649 self.embedding.is_training()
650 }
651}
652pub struct Seq2SeqDecoder<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
654 pub embedding: Embedding<F>,
656 pub rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
658 pub dropout: Option<Dropout<F>>,
660 pub attention: Option<Attention<F>>,
662 pub output_projection: Dense<F>,
664 pub vocab_size: usize,
666 pub hidden_dim: usize,
668 pub cell_type: RNNCellType,
670}
671
672impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for Seq2SeqDecoder<F> {
673 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
674 f.debug_struct("Seq2SeqDecoder")
675 .field("embedding", &self.embedding)
676 .field(
677 "rnn_layers",
678 &format!("Vec<Box<dyn Layer>> (len: {})", self.rnn_layers.len()),
679 )
680 .field("dropout", &self.dropout)
681 .field("attention", &self.attention)
682 .field("output_projection", &self.output_projection)
683 .field("vocab_size", &self.vocab_size)
684 .field("hidden_dim", &self.hidden_dim)
685 .field("cell_type", &self.cell_type)
686 .finish()
687 }
688}
689
690impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2SeqDecoder<F> {
691 #[allow(clippy::too_many_arguments)]
693 pub fn new(
694 vocab_size: usize,
695 embedding_dim: usize,
696 hidden_dim: usize,
697 num_layers: usize,
698 cell_type: RNNCellType,
699 use_attention: bool,
700 encoder_bidirectional: bool,
701 dropout_rate: Option<f64>,
702 ) -> Result<Self> {
703 let embedding_config = EmbeddingConfig {
705 num_embeddings: vocab_size,
706 embedding_dim,
707 padding_idx: None,
708 max_norm: None,
709 norm_type: 2.0,
710 scale_grad_by_freq: false,
711 };
712 let embedding = Embedding::<F>::new(embedding_config)?;
713
714 let mut rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::with_capacity(num_layers);
716 for i in 0..num_layers {
717 let input_size = if i == 0 { embedding_dim } else { hidden_dim };
718
719 let rnn: Box<dyn Layer<F> + Send + Sync> = match cell_type {
721 RNNCellType::SimpleRNN | RNNCellType::LSTM | RNNCellType::GRU => {
722 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
723 let rnn = ThreadSafeRNN::<F>::new(
724 input_size,
725 hidden_dim,
726 ThreadSafeRecurrentActivation::Tanh,
727 &mut rng,
728 )?;
729 Box::new(rnn)
730 }
731 };
732 rnn_layers.push(rnn);
733 }
734
735 let dropout = if let Some(rate) = dropout_rate {
737 if rate > 0.0 {
738 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
739 Some(Dropout::<F>::new(rate, &mut rng)?)
740 } else {
741 None
742 }
743 } else {
744 None
745 };
746
747 let attention = if use_attention {
749 let encoder_dim = if encoder_bidirectional {
750 hidden_dim * 2
751 } else {
752 hidden_dim
753 };
754 Some(Attention::<F>::new(
755 hidden_dim,
756 encoder_dim,
757 hidden_dim,
758 AttentionType::Additive,
759 encoder_bidirectional,
760 )?)
761 } else {
762 None
763 };
764
765 let mut rng_clone = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
767 let output_projection = Dense::<F>::new(
768 hidden_dim,
769 vocab_size,
770 None, &mut rng_clone,
772 )?;
773
774 Ok(Self {
775 embedding,
776 rnn_layers,
777 dropout,
778 attention,
779 output_projection,
780 vocab_size,
781 hidden_dim,
782 cell_type,
783 })
784 }
785 pub fn forward_step(
787 &self,
788 input_tokens: &Array<F, IxDyn>,
789 prev_states: &[Array<F, IxDyn>],
790 encoder_outputs: Option<&Array<F, IxDyn>>,
791 ) -> Result<AttentionOutput<F>> {
792 let mut x = self.embedding.forward(input_tokens)?;
793
794 let mut states_out = Vec::new();
796 for (i, layer) in self.rnn_layers.iter().enumerate() {
797 let prev_state = if i < prev_states.len() {
798 Some(&prev_states[i])
799 } else {
800 None
801 };
802
803 let output = if let Some(state) = prev_state {
805 let initial_state = state
807 .to_owned()
808 .into_shape_with_order((state.shape()[0], state.shape()[1]))?;
809 let x_dyn = x.to_owned().into_dyn();
810 let initial_state_dyn = initial_state.to_owned().into_dyn();
811 let combined_input = scirs2_core::ndarray::stack(
812 Axis(1),
813 &[x_dyn.view(), initial_state_dyn.view()],
814 )?;
815 layer.forward(&combined_input.to_owned().into_dyn())?
816 } else {
817 layer.forward(&x.to_owned().into_dyn())?
818 };
819
820 let sequences = output
822 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
823 .into_shape_with_order((output.shape()[0], output.shape()[2], output.shape()[3]))?
824 .to_owned();
825 let state = output
826 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
827 .into_shape_with_order((output.shape()[0], output.shape()[3]))?
828 .to_owned();
829 x = sequences.into_dyn();
830 states_out.push(state.into_dyn());
831 }
832 let final_output = if let Some(ref attention) = self.attention {
834 if let Some(encoder_out) = encoder_outputs {
835 let batch_size = x.shape()[0];
837 let hidden_size = x.shape()[2];
838 let last_hidden = x.into_shape_with_order((batch_size, hidden_size))?;
840 let dyn_last_hidden = last_hidden.to_owned().into_dyn();
843 let (attentional_hidden, _) = attention.forward(&dyn_last_hidden, encoder_out)?;
844 self.output_projection.forward(&attentional_hidden)?
846 } else {
847 return Err(NeuralError::InvalidArchitecture(
848 "Attention requires encoder outputs".to_string(),
849 ));
850 }
851 } else {
852 let batch_size = x.shape()[0];
854 let hidden_size = x.shape()[2];
855 let last_hidden = x.into_shape_with_order((batch_size, hidden_size))?;
857 let dyn_last_hidden = last_hidden.to_owned().into_dyn();
859 self.output_projection.forward(&dyn_last_hidden)?
860 };
861
862 Ok((final_output, states_out))
863 }
864 pub fn forward_sequence(
866 &self,
867 input_tokens: &Array<F, IxDyn>,
868 initial_states: &[Array<F, IxDyn>],
869 encoder_outputs: Option<&Array<F, IxDyn>>,
870 ) -> Result<Array<F, IxDyn>> {
871 let batch_size = input_tokens.shape()[0];
872 let seq_len = input_tokens.shape()[1];
873 let mut outputs = Array::<F, _>::zeros((batch_size, seq_len, self.vocab_size));
875 let mut states = initial_states.to_vec();
876 for t in 0..seq_len {
878 let tokens_t = input_tokens
880 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(t..t + 1))
881 .to_owned()
882 .into_dyn();
883
884 let (output_t, new_states) = self.forward_step(&tokens_t, &states, encoder_outputs)?;
886
887 for b in 0..batch_size {
889 for v in 0..self.vocab_size {
890 outputs[[b, t, v]] = output_t[[b, v]];
891 }
892 }
893
894 states = new_states;
896 }
897
898 Ok(outputs.into_dyn())
899 }
900}
901impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2SeqDecoder<F> {
902 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
903 let batch_size = input.shape()[0];
908 let _seq_len = input.shape()[1]; let mut initial_states = Vec::new();
911 for _ in 0..self.rnn_layers.len() {
912 let state = Array::<F, _>::zeros((batch_size, self.hidden_dim)).into_dyn();
913 initial_states.push(state);
914 }
915 self.forward_sequence(input, &initial_states, None)
917 }
918
919 fn backward(
920 &self,
921 input: &Array<F, IxDyn>,
922 grad_output: &Array<F, IxDyn>,
923 ) -> Result<Array<F, IxDyn>> {
924 let mut grad = grad_output.clone();
933
934 grad = self.output_projection.backward(&grad, &grad)?;
937
938 if let Some(ref attention) = self.attention {
940 grad = attention.backward(&grad, &grad)?;
943 }
944
945 for layer in self.rnn_layers.iter().rev() {
947 grad = layer.backward(&grad, &grad)?;
948 }
949
950 if let Some(ref dropout) = self.dropout {
952 if self.is_training() {
953 grad = dropout.backward(&grad, &grad)?;
954 }
955 }
956
957 let grad_input = self.embedding.backward(input, &grad)?;
959 Ok(grad_input)
960 }
961
962 fn update(&mut self, learning_rate: F) -> Result<()> {
963 self.embedding.update(learning_rate)?;
966
967 for layer in &mut self.rnn_layers {
969 layer.update(learning_rate)?;
970 }
971
972 if let Some(ref mut attention) = self.attention {
974 attention.update(learning_rate)?;
975 }
976
977 self.output_projection.update(learning_rate)?;
979 Ok(())
980 }
981
982 fn as_any(&self) -> &dyn std::any::Any {
983 self
984 }
985
986 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
987 self
988 }
989
990 fn params(&self) -> Vec<Array<F, IxDyn>> {
991 let mut params = Vec::new();
992 params.extend(self.embedding.params());
993 for layer in &self.rnn_layers {
994 params.extend(layer.params());
995 }
996 if let Some(ref attention) = self.attention {
997 params.extend(attention.params());
998 }
999 params.extend(self.output_projection.params());
1000 params
1001 }
1002
1003 fn set_training(&mut self, training: bool) {
1004 self.embedding.set_training(training);
1005 for layer in &mut self.rnn_layers {
1006 layer.set_training(training);
1007 }
1008 if let Some(ref mut attention) = self.attention {
1009 attention.set_training(training);
1010 }
1011 self.output_projection.set_training(training);
1012 }
1013
1014 fn is_training(&self) -> bool {
1015 self.embedding.is_training()
1016 }
1017}
1018#[derive(Debug)]
1020pub struct Seq2Seq<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
1021 pub encoder: Seq2SeqEncoder<F>,
1023 pub decoder: Seq2SeqDecoder<F>,
1025 pub config: Seq2SeqConfig,
1027}
1028
1029impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2Seq<F> {
1030 pub fn new(config: Seq2SeqConfig) -> Result<Self> {
1032 let encoder = Seq2SeqEncoder::<F>::new(
1034 config.input_vocab_size,
1035 config.embedding_dim,
1036 config.hidden_dim,
1037 config.num_layers,
1038 config.encoder_cell_type,
1039 config.bidirectional_encoder,
1040 Some(config.dropout_rate),
1041 )?;
1042
1043 let decoder = Seq2SeqDecoder::<F>::new(
1045 config.output_vocab_size,
1046 config.embedding_dim,
1047 config.hidden_dim,
1048 config.num_layers,
1049 config.decoder_cell_type,
1050 config.use_attention,
1051 config.bidirectional_encoder,
1052 Some(config.dropout_rate),
1053 )?;
1054
1055 Ok(Self {
1056 encoder,
1057 decoder,
1058 config,
1059 })
1060 }
1061 pub fn forward_train(
1063 &self,
1064 input_seq: &Array<F, IxDyn>,
1065 target_seq: &Array<F, IxDyn>,
1066 ) -> Result<Array<F, IxDyn>> {
1067 let (encoder_outputs, encoder_states) = self.encoder.forward(input_seq)?;
1069 let decoder_initial_states = if self.config.encoder_cell_type
1071 == self.config.decoder_cell_type
1072 {
1073 encoder_states
1075 } else {
1076 let batch_size = input_seq.shape()[0];
1079 let mut initial_states = Vec::new();
1080 for _ in 0..self.config.num_layers {
1081 let state = Array::<F, _>::zeros((batch_size, self.config.hidden_dim)).into_dyn();
1082 initial_states.push(state);
1083 }
1084 initial_states
1085 };
1086
1087 let decoder_output = self.decoder.forward_sequence(
1089 target_seq,
1090 &decoder_initial_states,
1091 Some(&encoder_outputs),
1092 )?;
1093
1094 Ok(decoder_output)
1095 }
1096 pub fn generate(
1098 &self,
1099 input_seq: &Array<F, IxDyn>,
1100 max_length: Option<usize>,
1101 start_token_id: usize,
1102 end_token_id: Option<usize>,
1103 ) -> Result<Array<F, IxDyn>> {
1104 let (encoder_outputs, encoder_states) = self.encoder.forward(input_seq)?;
1106
1107 let batch_size = input_seq.shape()[0];
1108 let max_len = max_length.unwrap_or(self.config.max_seq_len);
1109
1110 let decoder_states = if self.config.encoder_cell_type == self.config.decoder_cell_type {
1112 encoder_states
1113 } else {
1114 let mut initial_states = Vec::new();
1115 for _ in 0..self.config.num_layers {
1116 let state = Array::<F, _>::zeros((batch_size, self.config.hidden_dim)).into_dyn();
1117 initial_states.push(state);
1118 }
1119 initial_states
1120 };
1121
1122 let mut decoder_input = Array::<F, _>::zeros((batch_size, 1));
1124 for b in 0..batch_size {
1125 decoder_input[[b, 0]] =
1126 F::from(start_token_id as f64).expect("Failed to convert to float");
1127 }
1128 let mut decoder_input = decoder_input.into_dyn();
1129 let mut output_ids = Array::<F, _>::zeros((batch_size, max_len));
1130 let mut states = decoder_states;
1131 let mut completed = vec![false; batch_size];
1133 for t in 0..max_len {
1135 let (output_t, new_states) =
1136 self.decoder
1137 .forward_step(&decoder_input, &states, Some(&encoder_outputs))?;
1138 let mut next_tokens = Array::<F, _>::zeros((batch_size, 1));
1140 for b in 0..batch_size {
1141 if completed[b] {
1142 continue;
1143 }
1144
1145 let mut max_prob = F::neg_infinity();
1147 let mut max_idx = 0;
1148 for v in 0..self.config.output_vocab_size {
1149 if output_t[[b, v]] > max_prob {
1150 max_prob = output_t[[b, v]];
1151 max_idx = v;
1152 }
1153 }
1154
1155 output_ids[[b, t]] = F::from(max_idx as f64).expect("Failed to convert to float");
1157 next_tokens[[b, 0]] = F::from(max_idx as f64).expect("Failed to convert to float");
1158
1159 if let Some(eos_id) = end_token_id {
1161 if max_idx == eos_id {
1162 completed[b] = true;
1163 }
1164 }
1165 }
1166
1167 if completed.iter().all(|&c| c) {
1169 break;
1170 }
1171
1172 decoder_input = next_tokens.into_dyn();
1174 states = new_states;
1175 }
1176
1177 Ok(output_ids.into_dyn())
1178 }
1179 pub fn create_translation_model(
1181 src_vocab_size: usize,
1182 tgt_vocab_size: usize,
1183 hidden_dim: usize,
1184 ) -> Result<Self> {
1185 let config = Seq2SeqConfig {
1186 input_vocab_size: src_vocab_size,
1187 output_vocab_size: tgt_vocab_size,
1188 embedding_dim: hidden_dim,
1189 hidden_dim,
1190 ..Default::default()
1191 };
1192 Self::new(config)
1193 }
1194 pub fn create_small_model(src_vocab_size: usize, tgt_vocab_size: usize) -> Result<Self> {
1196 let config = Seq2SeqConfig {
1197 input_vocab_size: src_vocab_size,
1198 output_vocab_size: tgt_vocab_size,
1199 embedding_dim: 128,
1200 hidden_dim: 256,
1201 num_layers: 1,
1202 encoder_cell_type: RNNCellType::GRU,
1203 decoder_cell_type: RNNCellType::GRU,
1204 bidirectional_encoder: false,
1205 use_attention: false,
1206 dropout_rate: 0.0,
1207 max_seq_len: 50,
1208 };
1209 Self::new(config)
1210 }
1211}
1212impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2Seq<F> {
1213 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
1214 self.generate(
1217 input,
1218 Some(self.config.max_seq_len),
1219 0, None,
1221 )
1222 }
1223
1224 fn backward(
1225 &self,
1226 input: &Array<F, IxDyn>,
1227 grad_output: &Array<F, IxDyn>,
1228 ) -> Result<Array<F, IxDyn>> {
1229 let decoder_grad = self.decoder.backward(input, grad_output)?;
1241
1242 let encoder_grad = self.encoder.backward(input, &decoder_grad)?;
1246 Ok(encoder_grad)
1247 }
1248
1249 fn update(&mut self, learning_rate: F) -> Result<()> {
1250 self.encoder.update(learning_rate)?;
1253 self.decoder.update(learning_rate)?;
1255 Ok(())
1256 }
1257
1258 fn as_any(&self) -> &dyn std::any::Any {
1259 self
1260 }
1261
1262 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1263 self
1264 }
1265
1266 fn params(&self) -> Vec<Array<F, IxDyn>> {
1267 let mut params = Vec::new();
1268 params.extend(self.encoder.params());
1269 params.extend(self.decoder.params());
1270 params
1271 }
1272
1273 fn set_training(&mut self, training: bool) {
1274 self.encoder.set_training(training);
1275 self.decoder.set_training(training);
1276 }
1277
1278 fn is_training(&self) -> bool {
1279 self.encoder.is_training()
1280 }
1281}