Skip to main content

scirs2_neural/models/architectures/
seq2seq.rs

1//! Sequence-to-Sequence (Seq2Seq) model architectures
2//!
3//! This module implements various RNN-based sequence models including:
4//! - Encoder-Decoder architectures
5//! - Sequence-to-Sequence with attention
6//! - Bidirectional RNN encoder with attention
7//!
8//! These models are useful for machine translation, text summarization,
9//! speech recognition, and other sequence generation tasks.
10
11// use crate::activations::Softmax;
12use crate::error::{NeuralError, Result};
13use crate::layers::recurrent::rnn::{RNNConfig, RecurrentActivation};
14use crate::layers::rnn_thread_safe::{
15    RecurrentActivation as ThreadSafeRecurrentActivation, ThreadSafeBidirectional, ThreadSafeRNN,
16};
17use crate::layers::{Dense, Dropout, Embedding, EmbeddingConfig, Layer};
18use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
19use scirs2_core::numeric::{Float, NumAssign};
20use scirs2_core::random::SeedableRng;
21/// Type alias for encoder forward output
22type EncoderOutput<F> = (Array<F, IxDyn>, Vec<Array<F, IxDyn>>);
23/// Type alias for attention forward output
24type AttentionOutput<F> = (Array<F, IxDyn>, Vec<Array<F, IxDyn>>);
25use serde::{Deserialize, Serialize};
26use std::fmt::Debug;
27/// RNN cell types for sequence models
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum RNNCellType {
30    /// Simple RNN cell
31    SimpleRNN,
32    /// LSTM (Long Short-Term Memory) cell
33    LSTM,
34    /// GRU (Gated Recurrent Unit) cell
35    GRU,
36}
37/// Configuration for Seq2Seq models
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Seq2SeqConfig {
40    /// Vocabulary size for input encoder
41    pub input_vocab_size: usize,
42    /// Vocabulary size for output decoder
43    pub output_vocab_size: usize,
44    /// Embedding dimension
45    pub embedding_dim: usize,
46    /// Hidden dimension for RNN cells
47    pub hidden_dim: usize,
48    /// Number of RNN layers
49    pub num_layers: usize,
50    /// Cell type for encoder
51    pub encoder_cell_type: RNNCellType,
52    /// Cell type for decoder
53    pub decoder_cell_type: RNNCellType,
54    /// Whether to use bidirectional encoder
55    pub bidirectional_encoder: bool,
56    /// Whether to use attention
57    pub use_attention: bool,
58    /// Dropout rate
59    pub dropout_rate: f64,
60    /// Maximum sequence length
61    pub max_seq_len: usize,
62}
63
64impl Default for Seq2SeqConfig {
65    fn default() -> Self {
66        Self {
67            input_vocab_size: 10000,
68            output_vocab_size: 10000,
69            embedding_dim: 256,
70            hidden_dim: 512,
71            num_layers: 2,
72            encoder_cell_type: RNNCellType::LSTM,
73            decoder_cell_type: RNNCellType::LSTM,
74            bidirectional_encoder: true,
75            use_attention: true,
76            dropout_rate: 0.1,
77            max_seq_len: 100,
78        }
79    }
80}
81
82/// Attention mechanism for sequence models
83#[derive(Debug, Clone)]
84pub struct Attention<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
85    /// Attention projection for decoder state
86    pub decoder_projection: Dense<F>,
87    /// Attention projection for encoder outputs
88    pub encoder_projection: Option<Dense<F>>,
89    /// Combined projection
90    pub combined_projection: Dense<F>,
91    /// Output projection
92    pub output_projection: Dense<F>,
93    /// Attention type
94    pub attention_type: AttentionType,
95    /// Whether encoder outputs are bidirectional
96    pub bidirectional_encoder: bool,
97}
98
/// Scoring function used by [`Attention`].
///
/// A fieldless enum, so it derives `Copy` and `Eq` (matching [`RNNCellType`]) and can be
/// passed and compared by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionType {
    /// Additive attention (Bahdanau): `v^T tanh(W_d d + W_e e)`.
    Additive,
    /// Multiplicative attention (Luong): dot product of the projected decoder state
    /// with each encoder output.
    Multiplicative,
    /// General attention: learned projection of the decoder state, then a dot product
    /// with each encoder output.
    General,
}
109
110impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Attention<F> {
111    /// Create a new Attention module
112    pub fn new(
113        decoder_dim: usize,
114        encoder_dim: usize,
115        attention_dim: usize,
116        attention_type: AttentionType,
117        bidirectional_encoder: bool,
118    ) -> Result<Self> {
119        // Create a random number generator for initialization
120        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
121        // Create projections based on attention type
122        let decoder_projection = Dense::<F>::new(decoder_dim, attention_dim, None, &mut rng)?;
123        // For additive attention, we need to project encoder outputs
124        let encoder_projection = if attention_type == AttentionType::Additive {
125            Some(Dense::<F>::new(encoder_dim, attention_dim, None, &mut rng)?)
126        } else {
127            None
128        };
129        // For multiplicative and general attention
130        let combined_dim = match attention_type {
131            AttentionType::Additive => attention_dim,
132            AttentionType::Multiplicative => 1,
133            AttentionType::General => encoder_dim,
134        };
135
136        let combined_projection = Dense::<F>::new(combined_dim, 1, None, &mut rng)?;
137        // Project context vector and decoder state for output
138        let output_projection =
139            Dense::<F>::new(encoder_dim + decoder_dim, decoder_dim, None, &mut rng)?;
140        Ok(Self {
141            decoder_projection,
142            encoder_projection,
143            combined_projection,
144            output_projection,
145            attention_type,
146            bidirectional_encoder,
147        })
148    }
149
150    /// Compute attention weights and context vector
151    pub fn forward(
152        &self,
153        decoder_state: &Array<F, IxDyn>,
154        encoder_outputs: &Array<F, IxDyn>,
155    ) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
156        // Get shapes
157        let batch_size = decoder_state.shape()[0];
158        let seq_len = encoder_outputs.shape()[1];
159        let encoder_dim = encoder_outputs.shape()[2];
160        // Project decoder state
161        let decoder_projected = self.decoder_projection.forward(decoder_state)?;
162        // Compute attention scores based on attention type
163        let attention_scores = match self.attention_type {
164            AttentionType::Additive => {
165                // Project encoder outputs if needed
166                let encoder_projected = if let Some(ref proj) = self.encoder_projection {
167                    // Reshape for projection
168                    let flat_encoder = encoder_outputs
169                        .to_owned()
170                        .into_shape_with_order((batch_size * seq_len, encoder_dim))?;
171                    let projected = proj.forward(&flat_encoder.into_dyn())?;
172                    let projshape = projected.shape()[1];
173                    projected
174                        .into_shape_with_order((batch_size, seq_len, projshape))?
175                        .into_dyn()
176                } else {
177                    return Err(NeuralError::InferenceError(
178                        "Encoder projection missing for additive attention".to_string(),
179                    ));
180                };
181                // Expand decoder state for broadcasting
182                let expanded_decoder = decoder_projected.to_owned().into_shape_with_order((
183                    batch_size,
184                    1,
185                    decoder_projected.shape()[1],
186                ))?;
187                let expanded = expanded_decoder
188                    .broadcast((batch_size, seq_len, expanded_decoder.shape()[2]))
189                    .expect("Operation failed");
190                // Add encoder and decoder projections
191                let combined = &expanded + &encoder_projected;
192                // Apply tanh and project to get scores
193                let tanh = combined.mapv(|x| x.tanh());
194                let flat_tanh = tanh
195                    .to_owned()
196                    .into_shape_with_order((batch_size * seq_len, tanh.shape()[2]))?;
197                let scores = self.combined_projection.forward(&flat_tanh.into_dyn())?;
198                scores
199                    .into_shape_with_order((batch_size, seq_len))?
200                    .into_dyn()
201            }
202            AttentionType::Multiplicative => {
203                // Expand decoder state for each encoder position
204                let expanded_decoder = decoder_projected.to_owned().into_shape_with_order((
205                    batch_size,
206                    1,
207                    decoder_projected.shape()[1],
208                ))?;
209                // Batched dot product
210                let mut scores = Array::<F, _>::zeros((batch_size, seq_len));
211                for b in 0..batch_size {
212                    let decoder_slice = expanded_decoder.slice(scirs2_core::ndarray::s![b, 0, ..]);
213                    for s in 0..seq_len {
214                        let encoder_slice =
215                            encoder_outputs.slice(scirs2_core::ndarray::s![b, s, ..]);
216                        // Manually calculate dot product to avoid ambiguity
217                        let mut dot_product = F::zero();
218                        for i in 0..decoder_slice.len() {
219                            dot_product += decoder_slice[i] * encoder_slice[i];
220                        }
221                        scores[[b, s]] = dot_product;
222                    }
223                }
224                scores.into_dyn()
225            }
226            AttentionType::General => {
227                // Project decoder state once (used as a weight matrix)
228                let weight_matrix = decoder_projected.to_owned();
229                // Batched matrix multiply
230                let mut scores = Array::<F, _>::zeros((batch_size, seq_len));
231                for b in 0..batch_size {
232                    let weight = weight_matrix.slice(scirs2_core::ndarray::s![b, ..]);
233                    for s in 0..seq_len {
234                        let encoder_slice =
235                            encoder_outputs.slice(scirs2_core::ndarray::s![b, s, ..]);
236                        // Manually calculate dot product
237                        let mut dot_product = F::zero();
238                        for i in 0..weight.len() {
239                            dot_product += weight[i] * encoder_slice[i];
240                        }
241                        scores[[b, s]] = dot_product;
242                    }
243                }
244                scores.into_dyn()
245            }
246        };
247        // Apply softmax to get attention weights
248        let mut attention_weights = Array::<F, IxDyn>::zeros(attention_scores.raw_dim());
249        // Manual softmax implementation
250        for b in 0..batch_size {
251            let mut row = attention_scores
252                .slice(scirs2_core::ndarray::s![b, ..])
253                .to_owned();
254            // Find max for numerical stability
255            let max_val = row.fold(F::neg_infinity(), |m, &v| m.max(v));
256            // Compute exp and sum
257            let mut exp_sum = F::zero();
258            for i in 0..seq_len {
259                let exp_val = (row[i] - max_val).exp();
260                row[i] = exp_val;
261                exp_sum += exp_val;
262            }
263
264            // Normalize
265            if exp_sum > F::zero() {
266                for i in 0..seq_len {
267                    row[i] /= exp_sum;
268                }
269            }
270
271            // Copy normalized weights
272            for i in 0..seq_len {
273                attention_weights[[b, i]] = row[i];
274            }
275        }
276        // Compute context vector
277        let attention_weights_expanded = attention_weights
278            .to_owned()
279            .into_shape_with_order((batch_size, seq_len, 1))?;
280        let broadcast_weights = attention_weights_expanded
281            .broadcast((batch_size, seq_len, encoder_dim))
282            .expect("Operation failed");
283        // Element-wise multiply and sum over sequence dimension
284        let weighted_encoder = encoder_outputs * &broadcast_weights;
285        let context = weighted_encoder.sum_axis(Axis(1));
286        // Concatenate context and decoder state - ensure both are in the same format for stacking
287        let decoder_state_dyn = decoder_state.to_owned().into_dyn();
288        let decoder_and_context =
289            scirs2_core::ndarray::stack(Axis(1), &[context.view(), decoder_state_dyn.view()])?;
290        let flattened = decoder_and_context
291            .into_shape_with_order((batch_size, context.shape()[1] + decoder_state.shape()[1]))?;
292        // Project combined vector - convert to IxDyn for Layer trait
293        let flattened_dyn = flattened.to_owned().into_dyn();
294        let output = self.output_projection.forward(&flattened_dyn)?;
295        Ok((output, attention_weights))
296    }
297}
298
299impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Attention<F> {
300    fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
301        Err(NeuralError::InvalidArchitecture("Attention layer requires separate decoder state and encoder outputs. Use the dedicated forward method.".to_string()))
302    }
303
304    fn backward(
305        &self,
306        _input: &Array<F, IxDyn>,
307        grad_output: &Array<F, IxDyn>,
308    ) -> Result<Array<F, IxDyn>> {
309        // Simplified backward pass for Attention
310        // Note: Proper implementation would require caching intermediate values from forward pass
311        // and handling the dual-input nature of attention (decoder_state, encoder_outputs)
312        // For now, create a placeholder gradient that matches input dimensions
313        // In a full implementation, this would:
314        // 1. Compute gradients w.r.t. attention weights
315        // 2. Backpropagate through softmax
316        // 3. Compute gradients for all projection layers
317        // 4. Return gradients for both decoder_state and encoder_outputs
318        let grad_input = Array::<F, IxDyn>::zeros(grad_output.dim());
319        Ok(grad_input)
320    }
321
322    fn update(&mut self, learning_rate: F) -> Result<()> {
323        // Update all projection layers
324        // Update decoder projection
325        self.decoder_projection.update(learning_rate)?;
326        // Update encoder projection if it exists
327        if let Some(ref mut proj) = self.encoder_projection {
328            proj.update(learning_rate)?;
329        }
330
331        // Update combined projection
332        self.combined_projection.update(learning_rate)?;
333        // Update output projection
334        self.output_projection.update(learning_rate)?;
335        Ok(())
336    }
337
338    fn as_any(&self) -> &dyn std::any::Any {
339        self
340    }
341
342    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
343        self
344    }
345
346    fn params(&self) -> Vec<Array<F, IxDyn>> {
347        let mut params = Vec::new();
348        params.extend(self.decoder_projection.params());
349        if let Some(ref proj) = self.encoder_projection {
350            params.extend(proj.params());
351        }
352        params.extend(self.combined_projection.params());
353        params.extend(self.output_projection.params());
354        params
355    }
356
357    fn set_training(&mut self, training: bool) {
358        self.decoder_projection.set_training(training);
359        if let Some(ref mut proj) = self.encoder_projection {
360            proj.set_training(training);
361        }
362        self.combined_projection.set_training(training);
363        self.output_projection.set_training(training);
364    }
365
366    fn is_training(&self) -> bool {
367        self.decoder_projection.is_training()
368    }
369}
370/// Encoder for Seq2Seq models
371pub struct Seq2SeqEncoder<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
372    /// Input embedding layer
373    pub embedding: Embedding<F>,
374    /// RNN layers
375    pub rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
376    /// Dropout layer
377    pub dropout: Option<Dropout<F>>,
378    /// Whether the encoder is bidirectional
379    pub bidirectional: bool,
380    /// RNN cell type
381    pub cell_type: RNNCellType,
382    /// Hidden dimension
383    pub hidden_dim: usize,
384    /// Number of layers
385    pub num_layers: usize,
386}
387
388impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for Seq2SeqEncoder<F> {
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        f.debug_struct("Seq2SeqEncoder")
391            .field("embedding", &self.embedding)
392            .field(
393                "rnn_layers",
394                &format!("Vec<Box<dyn Layer>> (len: {})", self.rnn_layers.len()),
395            )
396            .field("dropout", &self.dropout)
397            .field("bidirectional", &self.bidirectional)
398            .field("cell_type", &self.cell_type)
399            .field("hidden_dim", &self.hidden_dim)
400            .field("num_layers", &self.num_layers)
401            .finish()
402    }
403}
404
405impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2SeqEncoder<F> {
406    /// Create a new Seq2SeqEncoder
407    pub fn new(
408        vocab_size: usize,
409        embedding_dim: usize,
410        hidden_dim: usize,
411        num_layers: usize,
412        cell_type: RNNCellType,
413        bidirectional: bool,
414        dropout_rate: Option<f64>,
415    ) -> Result<Self> {
416        // Create embedding layer with config
417        let embedding_config = EmbeddingConfig {
418            num_embeddings: vocab_size,
419            embedding_dim,
420            padding_idx: None,
421            max_norm: None,
422            norm_type: 2.0,
423            scale_grad_by_freq: false,
424        };
425        let embedding = Embedding::<F>::new(embedding_config)?;
426        // Create RNN layers
427        let mut rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::with_capacity(num_layers);
428        for i in 0..num_layers {
429            let input_size = if i == 0 {
430                embedding_dim
431            } else if bidirectional && i > 0 {
432                hidden_dim * 2
433            } else {
434                hidden_dim
435            };
436            // Create the appropriate RNN layer based on cell type
437            let rnn: Box<dyn Layer<F> + Send + Sync> = match cell_type {
438                RNNCellType::SimpleRNN => {
439                    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
440                    let config = RNNConfig {
441                        input_size,
442                        hidden_size: hidden_dim,
443                        activation: RecurrentActivation::Tanh,
444                    };
445                    // Use thread-safe RNN implementation to ensure multi-threading compatibility
446                    let rnn = ThreadSafeRNN::<F>::new(
447                        config.input_size,
448                        config.hidden_size,
449                        ThreadSafeRecurrentActivation::Tanh, // Convert activation
450                        &mut rng,
451                    )?;
452                    if bidirectional {
453                        // Use thread-safe bidirectional wrapper for multi-threading
454                        let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
455                        Box::new(brnn)
456                    } else {
457                        Box::new(rnn)
458                    }
459                }
460                RNNCellType::LSTM => {
461                    // Use thread-safe RNN as a replacement for LSTM until LSTM is made thread-safe
462                    // For true thread safety, we'll use our ThreadSafeRNN with tanh activation
463                    // as a temporary replacement for LSTM
464                    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
465                    let rnn = ThreadSafeRNN::<F>::new(
466                        input_size,
467                        hidden_dim,
468                        ThreadSafeRecurrentActivation::Tanh,
469                        &mut rng,
470                    )?;
471                    if bidirectional {
472                        // Use thread-safe bidirectional wrapper
473                        let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
474                        Box::new(brnn)
475                    } else {
476                        Box::new(rnn)
477                    }
478                }
479                RNNCellType::GRU => {
480                    // Use thread-safe RNN as a replacement for GRU until GRU is made thread-safe
481                    // as a temporary replacement for GRU
482                    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
483                    let rnn = ThreadSafeRNN::<F>::new(
484                        input_size,
485                        hidden_dim,
486                        ThreadSafeRecurrentActivation::Tanh,
487                        &mut rng,
488                    )?;
489                    if bidirectional {
490                        let brnn = ThreadSafeBidirectional::new(Box::new(rnn), None)?;
491                        Box::new(brnn)
492                    } else {
493                        Box::new(rnn)
494                    }
495                }
496            };
497
498            rnn_layers.push(rnn);
499        }
500
501        // Create dropout layer if needed
502        let dropout = if let Some(rate) = dropout_rate {
503            if rate > 0.0 {
504                let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
505                Some(Dropout::<F>::new(rate, &mut rng)?)
506            } else {
507                None
508            }
509        } else {
510            None
511        };
512
513        Ok(Self {
514            embedding,
515            rnn_layers,
516            dropout,
517            bidirectional,
518            cell_type,
519            hidden_dim,
520            num_layers,
521        })
522    }
523    /// Forward pass through the encoder
524    pub fn forward(&self, input_seq: &Array<F, IxDyn>) -> Result<EncoderOutput<F>> {
525        // Apply embedding
526        let mut x = self.embedding.forward(input_seq)?;
527        // Apply dropout if available
528        if let Some(ref dropout) = self.dropout {
529            x = dropout.forward(&x)?;
530        }
531
532        // Process through RNN layers
533        let mut states = Vec::new();
534        for layer in &self.rnn_layers {
535            // Each RNN layer returns sequences and final state
536            let output = layer.forward(&x)?;
537            // For bidirectional layers, we need to concatenate forward and backward states
538            if self.bidirectional {
539                // Extract sequences (first element) and states
540                let sequences = output
541                    .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
542                    .into_shape_with_order((
543                        output.shape()[0],
544                        output.shape()[2],
545                        output.shape()[3],
546                    ))?
547                    .to_owned(); // Convert to owned array
548                let state = output
549                    .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
550                    .into_shape_with_order((output.shape()[0], output.shape()[3]))?
551                    .to_owned();
552                x = sequences.into_dyn();
553                states.push(state.into_dyn());
554            } else {
555                // Extract sequences (first element) and state (second element)
556                let sequences = output
557                    .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
558                    .into_shape_with_order((
559                        output.shape()[0],
560                        output.shape()[2],
561                        output.shape()[3],
562                    ))?
563                    .to_owned();
564                let state = output
565                    .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
566                    .into_shape_with_order((output.shape()[0], output.shape()[3]))?
567                    .to_owned();
568                x = sequences.into_dyn();
569                states.push(state.into_dyn());
570            }
571        }
572
573        Ok((x, states))
574    }
575}
576impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2SeqEncoder<F> {
577    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
578        // This simplified version only returns the output sequences
579        // For the full state, use the dedicated forward method
580        let (output, _) = self.forward(input)?;
581        Ok(output)
582    }
583
584    fn backward(
585        &self,
586        input: &Array<F, IxDyn>,
587        grad_output: &Array<F, IxDyn>,
588    ) -> Result<Array<F, IxDyn>> {
589        // Seq2SeqEncoder backward: reverse the forward pass
590        // Note: This is a simplified implementation
591        // A complete implementation would need to:
592        // 1. Cache intermediate states from forward pass
593        // 2. Implement proper RNN backpropagation through time
594        // 3. Handle bidirectional gradients correctly
595        let mut grad = grad_output.clone();
596        // Backward through RNN layers in reverse order
597        for layer in self.rnn_layers.iter().rev() {
598            grad = layer.backward(&grad, &grad)?;
599        }
600        // Backward through dropout if training
601        // (Dropout backward is typically identity during inference)
602        // Backward through embedding
603        let grad_input = self.embedding.backward(input, &grad)?;
604        Ok(grad_input)
605    }
606
607    fn update(&mut self, learning_rate: F) -> Result<()> {
608        // Update embedding layer
609        self.embedding.update(learning_rate)?;
610        // Update all RNN layers
611        for layer in &mut self.rnn_layers {
612            layer.update(learning_rate)?;
613        }
614        // Note: Dropout doesn't have learnable parameters to update
615        Ok(())
616    }
617
618    fn as_any(&self) -> &dyn std::any::Any {
619        self
620    }
621
622    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
623        self
624    }
625
626    fn params(&self) -> Vec<Array<F, IxDyn>> {
627        let mut params = Vec::new();
628        params.extend(self.embedding.params());
629        for layer in &self.rnn_layers {
630            params.extend(layer.params());
631        }
632        if let Some(ref dropout) = self.dropout {
633            params.extend(dropout.params());
634        }
635        params
636    }
637
638    fn set_training(&mut self, training: bool) {
639        self.embedding.set_training(training);
640        for layer in &mut self.rnn_layers {
641            layer.set_training(training);
642        }
643        if let Some(ref mut dropout) = self.dropout {
644            dropout.set_training(training);
645        }
646    }
647
648    fn is_training(&self) -> bool {
649        self.embedding.is_training()
650    }
651}
652/// Decoder for Seq2Seq models
653pub struct Seq2SeqDecoder<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
654    /// Output embedding layer
655    pub embedding: Embedding<F>,
656    /// RNN layers
657    pub rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
658    /// Dropout layer
659    pub dropout: Option<Dropout<F>>,
660    /// Attention mechanism (optional)
661    pub attention: Option<Attention<F>>,
662    /// Output projection layer
663    pub output_projection: Dense<F>,
664    /// Output vocabulary size
665    pub vocab_size: usize,
666    /// Hidden dimension
667    pub hidden_dim: usize,
668    /// RNN cell type
669    pub cell_type: RNNCellType,
670}
671
672impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for Seq2SeqDecoder<F> {
673    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
674        f.debug_struct("Seq2SeqDecoder")
675            .field("embedding", &self.embedding)
676            .field(
677                "rnn_layers",
678                &format!("Vec<Box<dyn Layer>> (len: {})", self.rnn_layers.len()),
679            )
680            .field("dropout", &self.dropout)
681            .field("attention", &self.attention)
682            .field("output_projection", &self.output_projection)
683            .field("vocab_size", &self.vocab_size)
684            .field("hidden_dim", &self.hidden_dim)
685            .field("cell_type", &self.cell_type)
686            .finish()
687    }
688}
689
690impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2SeqDecoder<F> {
691    /// Create a new Seq2SeqDecoder
692    #[allow(clippy::too_many_arguments)]
693    pub fn new(
694        vocab_size: usize,
695        embedding_dim: usize,
696        hidden_dim: usize,
697        num_layers: usize,
698        cell_type: RNNCellType,
699        use_attention: bool,
700        encoder_bidirectional: bool,
701        dropout_rate: Option<f64>,
702    ) -> Result<Self> {
703        // Create embedding layer
704        let embedding_config = EmbeddingConfig {
705            num_embeddings: vocab_size,
706            embedding_dim,
707            padding_idx: None,
708            max_norm: None,
709            norm_type: 2.0,
710            scale_grad_by_freq: false,
711        };
712        let embedding = Embedding::<F>::new(embedding_config)?;
713
714        // Create RNN layers
715        let mut rnn_layers: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::with_capacity(num_layers);
716        for i in 0..num_layers {
717            let input_size = if i == 0 { embedding_dim } else { hidden_dim };
718
719            // Create the appropriate RNN layer based on cell type
720            let rnn: Box<dyn Layer<F> + Send + Sync> = match cell_type {
721                RNNCellType::SimpleRNN | RNNCellType::LSTM | RNNCellType::GRU => {
722                    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
723                    let rnn = ThreadSafeRNN::<F>::new(
724                        input_size,
725                        hidden_dim,
726                        ThreadSafeRecurrentActivation::Tanh,
727                        &mut rng,
728                    )?;
729                    Box::new(rnn)
730                }
731            };
732            rnn_layers.push(rnn);
733        }
734
735        // Create dropout layer if needed
736        let dropout = if let Some(rate) = dropout_rate {
737            if rate > 0.0 {
738                let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
739                Some(Dropout::<F>::new(rate, &mut rng)?)
740            } else {
741                None
742            }
743        } else {
744            None
745        };
746
747        // Create attention mechanism if needed
748        let attention = if use_attention {
749            let encoder_dim = if encoder_bidirectional {
750                hidden_dim * 2
751            } else {
752                hidden_dim
753            };
754            Some(Attention::<F>::new(
755                hidden_dim,
756                encoder_dim,
757                hidden_dim,
758                AttentionType::Additive,
759                encoder_bidirectional,
760            )?)
761        } else {
762            None
763        };
764
765        // Create output projection with activation function
766        let mut rng_clone = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
767        let output_projection = Dense::<F>::new(
768            hidden_dim,
769            vocab_size,
770            None, // No custom activation function
771            &mut rng_clone,
772        )?;
773
774        Ok(Self {
775            embedding,
776            rnn_layers,
777            dropout,
778            attention,
779            output_projection,
780            vocab_size,
781            hidden_dim,
782            cell_type,
783        })
784    }
    /// Forward pass through the decoder (single step).
    ///
    /// `input_tokens` is embedded, pushed through every RNN layer in order, optionally combined
    /// with `encoder_outputs` by the attention module, and finally projected by
    /// `output_projection` (built with `vocab_size` outputs in `new`).
    ///
    /// # Arguments
    /// * `input_tokens` - token ids for this step (presumably shape `(batch, 1)` — TODO confirm)
    /// * `prev_states` - per-layer recurrent states; layer `i` is fed `prev_states[i]` when it
    ///   exists and runs without an initial state when the slice is shorter than the layer count
    /// * `encoder_outputs` - encoder outputs; mandatory when the decoder was built with attention
    ///
    /// # Returns
    /// `(output, states)`: the projected output for this step and one state per RNN layer, in
    /// layer order.
    ///
    /// # Errors
    /// Returns `NeuralError::InvalidArchitecture` when attention is configured but
    /// `encoder_outputs` is `None`. Also propagates errors from the embedding, RNN layers,
    /// attention, output projection, and the `stack` / `into_shape_with_order` calls.
    ///
    /// NOTE(review): `self.dropout` is not applied anywhere in this method.
    pub fn forward_step(
        &self,
        input_tokens: &Array<F, IxDyn>,
        prev_states: &[Array<F, IxDyn>],
        encoder_outputs: Option<&Array<F, IxDyn>>,
    ) -> Result<AttentionOutput<F>> {
        let mut x = self.embedding.forward(input_tokens)?;

        // Process through RNN layers with initial states; each layer's sequence output becomes
        // the next layer's input, and its state is collected into `states_out`.
        let mut states_out = Vec::new();
        for (i, layer) in self.rnn_layers.iter().enumerate() {
            // Layers beyond the provided states run without an initial state.
            let prev_state = if i < prev_states.len() {
                Some(&prev_states[i])
            } else {
                None
            };

            // Forward pass with initial state
            let output = if let Some(state) = prev_state {
                // Prepare initial state format expected by the RNN layer (forced to 2-D).
                let initial_state = state
                    .to_owned()
                    .into_shape_with_order((state.shape()[0], state.shape()[1]))?;
                let x_dyn = x.to_owned().into_dyn();
                let initial_state_dyn = initial_state.to_owned().into_dyn();
                // NOTE(review): `stack` requires all inputs to share one shape, so this only
                // succeeds when the embedded `x` and the 2-D state agree — confirm against the
                // ThreadSafeRNN input contract.
                let combined_input = scirs2_core::ndarray::stack(
                    Axis(1),
                    &[x_dyn.view(), initial_state_dyn.view()],
                )?;
                layer.forward(&combined_input.to_owned().into_dyn())?
            } else {
                layer.forward(&x.to_owned().into_dyn())?
            };

            // Extract sequences (index 0 on axis 1) and state (index 1 on axis 1).
            // NOTE(review): assumes the layer returns a rank-4 array packing [sequences, state]
            // along axis 1; the state reshape also needs axis 2 to have length 1 — TODO confirm.
            let sequences = output
                .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
                .into_shape_with_order((output.shape()[0], output.shape()[2], output.shape()[3]))?
                .to_owned();
            let state = output
                .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(1..2))
                .into_shape_with_order((output.shape()[0], output.shape()[3]))?
                .to_owned();
            x = sequences.into_dyn();
            states_out.push(state.into_dyn());
        }
        // Apply attention if available
        let final_output = if let Some(ref attention) = self.attention {
            if let Some(encoder_out) = encoder_outputs {
                // Get the last RNN layer's output
                let batch_size = x.shape()[0];
                let hidden_size = x.shape()[2];
                // Reshape to (batch_size, hidden_size); only succeeds when `x` holds exactly
                // batch_size * hidden_size elements, i.e. a single time step.
                let last_hidden = x.into_shape_with_order((batch_size, hidden_size))?;
                // Apply attention
                // Convert to IxDyn for compatibility with Layer trait
                let dyn_last_hidden = last_hidden.to_owned().into_dyn();
                // The second tuple element returned by attention is discarded here.
                let (attentional_hidden, _) = attention.forward(&dyn_last_hidden, encoder_out)?;
                // Project to vocabulary size
                self.output_projection.forward(&attentional_hidden)?
            } else {
                return Err(NeuralError::InvalidArchitecture(
                    "Attention requires encoder outputs".to_string(),
                ));
            }
        } else {
            // Without attention, just project the last hidden state
            let batch_size = x.shape()[0];
            let hidden_size = x.shape()[2];
            // Reshape to (batch_size, hidden_size); same single-time-step requirement as above.
            let last_hidden = x.into_shape_with_order((batch_size, hidden_size))?;
            // Project to vocabulary size - convert to IxDyn first
            let dyn_last_hidden = last_hidden.to_owned().into_dyn();
            self.output_projection.forward(&dyn_last_hidden)?
        };

        Ok((final_output, states_out))
    }
864    /// Forward pass for decoding a complete sequence
865    pub fn forward_sequence(
866        &self,
867        input_tokens: &Array<F, IxDyn>,
868        initial_states: &[Array<F, IxDyn>],
869        encoder_outputs: Option<&Array<F, IxDyn>>,
870    ) -> Result<Array<F, IxDyn>> {
871        let batch_size = input_tokens.shape()[0];
872        let seq_len = input_tokens.shape()[1];
873        // Prepare output buffer
874        let mut outputs = Array::<F, _>::zeros((batch_size, seq_len, self.vocab_size));
875        let mut states = initial_states.to_vec();
876        // Process each time step
877        for t in 0..seq_len {
878            // Extract tokens for this time step
879            let tokens_t = input_tokens
880                .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(t..t + 1))
881                .to_owned()
882                .into_dyn();
883
884            // Decode one step
885            let (output_t, new_states) = self.forward_step(&tokens_t, &states, encoder_outputs)?;
886
887            // Store output
888            for b in 0..batch_size {
889                for v in 0..self.vocab_size {
890                    outputs[[b, t, v]] = output_t[[b, v]];
891                }
892            }
893
894            // Update states
895            states = new_states;
896        }
897
898        Ok(outputs.into_dyn())
899    }
900}
901impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2SeqDecoder<F> {
902    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
903        // This simplified version assumes:
904        // 1. Input contains decoder inputs only
905        // 2. Initial states are zero-initialized
906        // 3. No encoder outputs/attention
907        let batch_size = input.shape()[0];
908        let _seq_len = input.shape()[1]; // Not directly used but kept for clarity
909                                         // Initialize empty states
910        let mut initial_states = Vec::new();
911        for _ in 0..self.rnn_layers.len() {
912            let state = Array::<F, _>::zeros((batch_size, self.hidden_dim)).into_dyn();
913            initial_states.push(state);
914        }
915        // Forward sequence without encoder outputs
916        self.forward_sequence(input, &initial_states, None)
917    }
918
919    fn backward(
920        &self,
921        input: &Array<F, IxDyn>,
922        grad_output: &Array<F, IxDyn>,
923    ) -> Result<Array<F, IxDyn>> {
924        // Seq2SeqDecoder backward: reverse the forward pass
925        // Chain gradients through all subcomponents in reverse order
926        // Note: This is a simplified implementation for the Layer trait
927        // A complete implementation would need to:
928        // 1. Cache all intermediate values from forward pass
929        // 2. Properly backpropagate through the sequence generation process
930        // 3. Handle attention gradients correctly
931        // 4. Properly handle state propagation between timesteps
932        let mut grad = grad_output.clone();
933
934        // Backward through output projection
935        // Note: In a full implementation, we'd need the cached intermediate values
936        grad = self.output_projection.backward(&grad, &grad)?;
937
938        // Backward through attention if present
939        if let Some(ref attention) = self.attention {
940            // Note: Attention backward requires both decoder state and encoder outputs
941            // This is a simplified placeholder - full implementation needs proper caching
942            grad = attention.backward(&grad, &grad)?;
943        }
944
945        // Backward through RNN layers
946        for layer in self.rnn_layers.iter().rev() {
947            grad = layer.backward(&grad, &grad)?;
948        }
949
950        // Backward through dropout if training
951        if let Some(ref dropout) = self.dropout {
952            if self.is_training() {
953                grad = dropout.backward(&grad, &grad)?;
954            }
955        }
956
957        // Backward through embedding
958        let grad_input = self.embedding.backward(input, &grad)?;
959        Ok(grad_input)
960    }
961
962    fn update(&mut self, learning_rate: F) -> Result<()> {
963        // Update all learnable parameters in the decoder
964        // Update embedding
965        self.embedding.update(learning_rate)?;
966
967        // Update RNN layers
968        for layer in &mut self.rnn_layers {
969            layer.update(learning_rate)?;
970        }
971
972        // Update attention mechanism if present
973        if let Some(ref mut attention) = self.attention {
974            attention.update(learning_rate)?;
975        }
976
977        // Update output projection layer
978        self.output_projection.update(learning_rate)?;
979        Ok(())
980    }
981
982    fn as_any(&self) -> &dyn std::any::Any {
983        self
984    }
985
986    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
987        self
988    }
989
990    fn params(&self) -> Vec<Array<F, IxDyn>> {
991        let mut params = Vec::new();
992        params.extend(self.embedding.params());
993        for layer in &self.rnn_layers {
994            params.extend(layer.params());
995        }
996        if let Some(ref attention) = self.attention {
997            params.extend(attention.params());
998        }
999        params.extend(self.output_projection.params());
1000        params
1001    }
1002
1003    fn set_training(&mut self, training: bool) {
1004        self.embedding.set_training(training);
1005        for layer in &mut self.rnn_layers {
1006            layer.set_training(training);
1007        }
1008        if let Some(ref mut attention) = self.attention {
1009            attention.set_training(training);
1010        }
1011        self.output_projection.set_training(training);
1012    }
1013
1014    fn is_training(&self) -> bool {
1015        self.embedding.is_training()
1016    }
1017}
1018/// Sequence-to-Sequence (Seq2Seq) model
1019#[derive(Debug)]
1020pub struct Seq2Seq<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
1021    /// Encoder component
1022    pub encoder: Seq2SeqEncoder<F>,
1023    /// Decoder component
1024    pub decoder: Seq2SeqDecoder<F>,
1025    /// Model configuration
1026    pub config: Seq2SeqConfig,
1027}
1028
1029impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Seq2Seq<F> {
1030    /// Create a new Seq2Seq model
1031    pub fn new(config: Seq2SeqConfig) -> Result<Self> {
1032        // Create encoder
1033        let encoder = Seq2SeqEncoder::<F>::new(
1034            config.input_vocab_size,
1035            config.embedding_dim,
1036            config.hidden_dim,
1037            config.num_layers,
1038            config.encoder_cell_type,
1039            config.bidirectional_encoder,
1040            Some(config.dropout_rate),
1041        )?;
1042
1043        // Create decoder
1044        let decoder = Seq2SeqDecoder::<F>::new(
1045            config.output_vocab_size,
1046            config.embedding_dim,
1047            config.hidden_dim,
1048            config.num_layers,
1049            config.decoder_cell_type,
1050            config.use_attention,
1051            config.bidirectional_encoder,
1052            Some(config.dropout_rate),
1053        )?;
1054
1055        Ok(Self {
1056            encoder,
1057            decoder,
1058            config,
1059        })
1060    }
1061    /// Forward pass for training (teacher forcing)
1062    pub fn forward_train(
1063        &self,
1064        input_seq: &Array<F, IxDyn>,
1065        target_seq: &Array<F, IxDyn>,
1066    ) -> Result<Array<F, IxDyn>> {
1067        // Encode input sequence
1068        let (encoder_outputs, encoder_states) = self.encoder.forward(input_seq)?;
1069        // Prepare decoder initial states (use last encoder states)
1070        let decoder_initial_states = if self.config.encoder_cell_type
1071            == self.config.decoder_cell_type
1072        {
1073            // If cell types match, we can directly use encoder states
1074            encoder_states
1075        } else {
1076            // If cell types don't match, we need to project encoder states
1077            // For simplicity, we'll just zero-initialize decoder states
1078            let batch_size = input_seq.shape()[0];
1079            let mut initial_states = Vec::new();
1080            for _ in 0..self.config.num_layers {
1081                let state = Array::<F, _>::zeros((batch_size, self.config.hidden_dim)).into_dyn();
1082                initial_states.push(state);
1083            }
1084            initial_states
1085        };
1086
1087        // Decode target sequence with teacher forcing
1088        let decoder_output = self.decoder.forward_sequence(
1089            target_seq,
1090            &decoder_initial_states,
1091            Some(&encoder_outputs),
1092        )?;
1093
1094        Ok(decoder_output)
1095    }
1096    /// Generate sequences for inference
1097    pub fn generate(
1098        &self,
1099        input_seq: &Array<F, IxDyn>,
1100        max_length: Option<usize>,
1101        start_token_id: usize,
1102        end_token_id: Option<usize>,
1103    ) -> Result<Array<F, IxDyn>> {
1104        // Encode input sequence
1105        let (encoder_outputs, encoder_states) = self.encoder.forward(input_seq)?;
1106
1107        let batch_size = input_seq.shape()[0];
1108        let max_len = max_length.unwrap_or(self.config.max_seq_len);
1109
1110        // Prepare decoder initial states
1111        let decoder_states = if self.config.encoder_cell_type == self.config.decoder_cell_type {
1112            encoder_states
1113        } else {
1114            let mut initial_states = Vec::new();
1115            for _ in 0..self.config.num_layers {
1116                let state = Array::<F, _>::zeros((batch_size, self.config.hidden_dim)).into_dyn();
1117                initial_states.push(state);
1118            }
1119            initial_states
1120        };
1121
1122        // Initialize first decoder input with start tokens
1123        let mut decoder_input = Array::<F, _>::zeros((batch_size, 1));
1124        for b in 0..batch_size {
1125            decoder_input[[b, 0]] =
1126                F::from(start_token_id as f64).expect("Failed to convert to float");
1127        }
1128        let mut decoder_input = decoder_input.into_dyn();
1129        let mut output_ids = Array::<F, _>::zeros((batch_size, max_len));
1130        let mut states = decoder_states;
1131        // Keep track of completed sequences
1132        let mut completed = vec![false; batch_size];
1133        // Generate sequence
1134        for t in 0..max_len {
1135            let (output_t, new_states) =
1136                self.decoder
1137                    .forward_step(&decoder_input, &states, Some(&encoder_outputs))?;
1138            // Get most probable token
1139            let mut next_tokens = Array::<F, _>::zeros((batch_size, 1));
1140            for b in 0..batch_size {
1141                if completed[b] {
1142                    continue;
1143                }
1144
1145                // Find max probability token
1146                let mut max_prob = F::neg_infinity();
1147                let mut max_idx = 0;
1148                for v in 0..self.config.output_vocab_size {
1149                    if output_t[[b, v]] > max_prob {
1150                        max_prob = output_t[[b, v]];
1151                        max_idx = v;
1152                    }
1153                }
1154
1155                // Store predicted token
1156                output_ids[[b, t]] = F::from(max_idx as f64).expect("Failed to convert to float");
1157                next_tokens[[b, 0]] = F::from(max_idx as f64).expect("Failed to convert to float");
1158
1159                // Check if sequence is completed
1160                if let Some(eos_id) = end_token_id {
1161                    if max_idx == eos_id {
1162                        completed[b] = true;
1163                    }
1164                }
1165            }
1166
1167            // Early stopping if all sequences are completed
1168            if completed.iter().all(|&c| c) {
1169                break;
1170            }
1171
1172            // Update decoder input for next step
1173            decoder_input = next_tokens.into_dyn();
1174            states = new_states;
1175        }
1176
1177        Ok(output_ids.into_dyn())
1178    }
1179    /// Create a basic Seq2Seq model for machine translation
1180    pub fn create_translation_model(
1181        src_vocab_size: usize,
1182        tgt_vocab_size: usize,
1183        hidden_dim: usize,
1184    ) -> Result<Self> {
1185        let config = Seq2SeqConfig {
1186            input_vocab_size: src_vocab_size,
1187            output_vocab_size: tgt_vocab_size,
1188            embedding_dim: hidden_dim,
1189            hidden_dim,
1190            ..Default::default()
1191        };
1192        Self::new(config)
1193    }
1194    /// Create a small and fast Seq2Seq model
1195    pub fn create_small_model(src_vocab_size: usize, tgt_vocab_size: usize) -> Result<Self> {
1196        let config = Seq2SeqConfig {
1197            input_vocab_size: src_vocab_size,
1198            output_vocab_size: tgt_vocab_size,
1199            embedding_dim: 128,
1200            hidden_dim: 256,
1201            num_layers: 1,
1202            encoder_cell_type: RNNCellType::GRU,
1203            decoder_cell_type: RNNCellType::GRU,
1204            bidirectional_encoder: false,
1205            use_attention: false,
1206            dropout_rate: 0.0,
1207            max_seq_len: 50,
1208        };
1209        Self::new(config)
1210    }
1211}
1212impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for Seq2Seq<F> {
1213    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
1214        // This simplified forward assumes the input is only the encoder inputs
1215        // and automatically generates decoder outputs without teacher forcing
1216        self.generate(
1217            input,
1218            Some(self.config.max_seq_len),
1219            0, // Assuming 0 is the start token
1220            None,
1221        )
1222    }
1223
1224    fn backward(
1225        &self,
1226        input: &Array<F, IxDyn>,
1227        grad_output: &Array<F, IxDyn>,
1228    ) -> Result<Array<F, IxDyn>> {
1229        // Seq2Seq backward: chain gradients through encoder and decoder
1230        // Note: This simplified implementation assumes a basic forward pass
1231        // A complete implementation would need to:
1232        // 1. Cache all intermediate values (encoder outputs, decoder states, attention weights)
1233        // 2. Properly backpropagate through the sequence generation process
1234        // 3. Handle teacher forcing vs. autoregressive generation modes
1235        // 4. Implement proper gradient flow between encoder and decoder
1236        // For the simplified Layer trait implementation, we approximate the gradient flow
1237        // In practice, Seq2Seq training requires more sophisticated gradient computation
1238
1239        // Approximate decoder gradients (in reality, this depends on the specific forward mode)
1240        let decoder_grad = self.decoder.backward(input, grad_output)?;
1241
1242        // Approximate encoder gradients
1243        // Note: In teacher forcing mode, encoder gradients come through decoder attention
1244        // In generation mode, encoder gradients flow through all generated timesteps
1245        let encoder_grad = self.encoder.backward(input, &decoder_grad)?;
1246        Ok(encoder_grad)
1247    }
1248
1249    fn update(&mut self, learning_rate: F) -> Result<()> {
1250        // Update all learnable parameters in the Seq2Seq model
1251        // Update encoder parameters (embeddings, RNN layers, dropout)
1252        self.encoder.update(learning_rate)?;
1253        // Update decoder parameters (embeddings, RNN layers, attention, output projection, dropout)
1254        self.decoder.update(learning_rate)?;
1255        Ok(())
1256    }
1257
1258    fn as_any(&self) -> &dyn std::any::Any {
1259        self
1260    }
1261
1262    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1263        self
1264    }
1265
1266    fn params(&self) -> Vec<Array<F, IxDyn>> {
1267        let mut params = Vec::new();
1268        params.extend(self.encoder.params());
1269        params.extend(self.decoder.params());
1270        params
1271    }
1272
1273    fn set_training(&mut self, training: bool) {
1274        self.encoder.set_training(training);
1275        self.decoder.set_training(training);
1276    }
1277
1278    fn is_training(&self) -> bool {
1279        self.encoder.is_training()
1280    }
1281}