scirs2_text/transformer.rs

//! # Transformer Architecture Module
//!
//! This module provides a complete implementation of the Transformer architecture,
//! the foundation of modern language models like BERT, GPT, and T5. It includes
//! the essential components for building state-of-the-art NLP models.
//!
//! ## Overview
//!
//! The Transformer architecture revolutionized natural language processing by
//! introducing the self-attention mechanism. This module implements:
//!
//! - **Multi-Head Attention**: Core attention mechanism with multiple attention heads
//! - **Positional Encoding**: Sinusoidal position representations
//! - **Encoder-Decoder Architecture**: Full transformer with both encoder and decoder stacks
//! - **Layer Normalization**: Applied after each residual connection (post-norm)
//! - **Feed-Forward Networks**: Position-wise fully connected layers
//! - **Token Embeddings**: Learnable word embeddings
//!
//! ## Quick Start
//!
//! ```rust
//! use scirs2_text::transformer::{TransformerModel, TransformerConfig};
//!
//! // Configure the transformer
//! let config = TransformerConfig {
//!     d_model: 512,           // Model dimension
//!     nheads: 8,              // Number of attention heads
//!     d_ff: 2048,             // Feed-forward dimension
//!     n_encoder_layers: 6,    // Number of encoder layers
//!     n_decoder_layers: 6,    // Number of decoder layers
//!     max_seqlen: 512,        // Maximum sequence length
//!     dropout: 0.1,           // Dropout rate
//!     vocab_size: 10000,      // Vocabulary size
//! };
//!
//! // Create the model
//! let vocabulary = (0..10000).map(|i| format!("token_{}", i)).collect();
//! let transformer = TransformerModel::new(config, vocabulary).unwrap();
//!
//! // Example input sequence (string tokens)
//! let src_tokens = vec!["token_1".to_string(), "token_2".to_string(), "token_3".to_string()];
//!
//! // Encode the tokens
//! let output = transformer.encode_tokens(&src_tokens).unwrap();
//! println!("Model output shape: {:?}", output.shape());
//! ```
//!
//! ## Building Individual Components
//!
//! ### Multi-Head Attention
//!
//! ```rust
//! use scirs2_text::transformer::MultiHeadAttention;
//! use scirs2_core::ndarray::Array2;
//!
//! let d_model = 512;
//! let nheads = 8;
//! let attention = MultiHeadAttention::new(d_model, nheads).unwrap();
//!
//! // Create dummy input (seqlen=10, d_model=512)
//! let input = Array2::zeros((10, 512));
//! let output = attention.forward(input.view(), input.view(), input.view(), None).unwrap();
//! ```
//!
//! ### Positional Encoding
//!
//! ```rust
//! use scirs2_text::transformer::PositionalEncoding;
//! use scirs2_core::ndarray::Array2;
//!
//! let d_model = 512;
//! let max_len = 1000;
//! let pos_encoding = PositionalEncoding::new(max_len, d_model);
//!
//! // Apply positional encoding to embeddings
//! let seqlen = 20;
//! let embeddings = Array2::<f64>::zeros((seqlen, d_model));
//! let positional_encodings = pos_encoding.get_encoding(seqlen).unwrap();
//! println!("Embeddings shape: {:?}", embeddings.shape());
//! println!("Positional encodings shape: {:?}", positional_encodings.shape());
//! ```
//!
//! ### Complete Encoder
//!
//! ```rust
//! use scirs2_text::transformer::{TransformerEncoder, TransformerConfig};
//! use scirs2_core::ndarray::Array2;
//!
//! let config = TransformerConfig {
//!     d_model: 256,
//!     nheads: 4,
//!     d_ff: 1024,
//!     n_encoder_layers: 3,
//!     dropout: 0.1,
//!     ..Default::default()
//! };
//!
//! let encoder = TransformerEncoder::new(config).unwrap();
//! let input = Array2::zeros((50, 256)); // (seqlen, d_model)
//! let encoded = encoder.encode(input.view(), None).unwrap();
//! ```
//!
//! ## Advanced Usage
//!
//! ### Custom Attention Patterns
//!
//! ```rust
//! use scirs2_text::transformer::MultiHeadAttention;
//! use scirs2_core::ndarray::Array2;
//!
//! let attention = MultiHeadAttention::new(512, 8).unwrap();
//!
//! // Create attention mask for autoregressive generation
//! let seqlen = 10;
//! let mut mask = Array2::from_elem((seqlen, seqlen), false);
//! for i in 0..seqlen {
//!     for j in (i+1)..seqlen {
//!         mask[[i, j]] = true; // Mask future positions
//!     }
//! }
//!
//! let query = Array2::zeros((seqlen, 512));
//! let key = Array2::zeros((seqlen, 512));
//! let value = Array2::zeros((seqlen, 512));
//! let output = attention.forward(query.view(), key.view(), value.view(), Some(mask.view())).unwrap();
//! ```
//!
//! ### Layer-wise Learning Rate Decay
//!
//! ```rust
//! use scirs2_text::transformer::{TransformerModel, TransformerConfig};
//!
//! # let config = TransformerConfig::default();
//! # let vocabulary: Vec<String> = (0..config.vocab_size).map(|i| format!("token_{}", i)).collect();
//! // Apply different learning rates to different layers
//! let _model = TransformerModel::new(config, vocabulary).unwrap();
//!
//! // Typically, deeper layers get smaller learning rates.
//! let base_lr = 1e-4;
//! // Note: layer parameters would be accessed through training APIs.
//! println!("Base learning rate: {}", base_lr);
//! ```
//!
//! ## Architecture Details
//!
//! ### Attention Mechanism
//!
//! The multi-head attention computes:
//!
//! ```text
//! Attention(Q, K, V) = softmax(QK^T / √d_k)V
//! MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
//! where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
//! ```
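//!
//! As a concrete illustration, here is a minimal single-head sketch of the
//! formula above using plain `ndarray` operations (the tiny matrices `q`,
//! `k`, and `v` are made up for this example):
//!
//! ```rust
//! use scirs2_core::ndarray::Array2;
//!
//! let q = Array2::<f64>::eye(2);
//! let k = Array2::<f64>::eye(2);
//! let v = Array2::from_shape_fn((2, 2), |(i, j)| (2 * i + j + 1) as f64);
//! let d_k = 2.0_f64;
//!
//! // scores = QK^T / sqrt(d_k)
//! let scores = q.dot(&k.t()) / d_k.sqrt();
//!
//! // Row-wise softmax
//! let mut weights = scores.clone();
//! for mut row in weights.rows_mut() {
//!     let max = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
//!     row.mapv_inplace(|s| (s - max).exp());
//!     let sum = row.sum();
//!     row /= sum;
//! }
//!
//! // Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
//! let output = weights.dot(&v);
//! assert_eq!(output.shape(), &[2, 2]);
//! ```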
//!
//! ### Positional Encoding
//!
//! Uses sinusoidal functions to encode position information:
//!
//! ```text
//! PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
//! PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
//! ```
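//!
//! A quick numeric check of these formulas using this module's
//! `PositionalEncoding`: at position 0, even dimensions hold sin(0) = 0 and
//! odd dimensions hold cos(0) = 1:
//!
//! ```rust
//! use scirs2_text::transformer::PositionalEncoding;
//!
//! let pe = PositionalEncoding::new(16, 8); // (max_len, d_model)
//! let enc = pe.get_encoding(1).unwrap();
//! assert!(enc[[0, 0]].abs() < 1e-12);         // sin(0) = 0
//! assert!((enc[[0, 1]] - 1.0).abs() < 1e-12); // cos(0) = 1
//! ```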
//!
//! ### Layer Structure
//!
//! Each encoder/decoder layer follows the pattern:
//!
//! ```text
//! x = LayerNorm(x + SelfAttention(x))
//! x = LayerNorm(x + FeedForward(x))
//! ```
//!
//! ## Performance Optimization
//!
//! 1. **Gradient Checkpointing**: Trade memory for computation in deep models
//! 2. **Mixed Precision**: Use FP16 for faster training with minimal quality loss
//! 3. **Key-Value Caching**: Cache attention keys and values during inference (see the sketch after this list)
//! 4. **Attention Patterns**: Use sparse attention for very long sequences
//! 5. **Model Parallelism**: Split large models across multiple GPUs
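//!
//! To make item 3 concrete, the sketch below shows the bookkeeping pattern
//! behind key-value caching. The `KvCache` type here is hypothetical: the
//! `MultiHeadAttention` in this module recomputes keys and values on every
//! call and does not maintain a cache itself:
//!
//! ```rust
//! use scirs2_core::ndarray::{Array2, Axis};
//!
//! /// Hypothetical cache that grows by one key/value row per decoding step.
//! struct KvCache {
//!     keys: Array2<f64>,   // (cached_len, d_model)
//!     values: Array2<f64>, // (cached_len, d_model)
//! }
//!
//! impl KvCache {
//!     fn new(d_model: usize) -> Self {
//!         Self {
//!             keys: Array2::zeros((0, d_model)),
//!             values: Array2::zeros((0, d_model)),
//!         }
//!     }
//!
//!     /// Append the projected key/value row for the newest token, so earlier
//!     /// tokens never need to be re-projected.
//!     fn append(&mut self, k: Array2<f64>, v: Array2<f64>) {
//!         self.keys.append(Axis(0), k.view()).unwrap();
//!         self.values.append(Axis(0), v.view()).unwrap();
//!     }
//! }
//!
//! let mut cache = KvCache::new(4);
//! cache.append(Array2::ones((1, 4)), Array2::ones((1, 4)));
//! cache.append(Array2::ones((1, 4)), Array2::ones((1, 4)));
//! assert_eq!(cache.keys.shape(), &[2, 4]);
//! ```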
//!
//! ## Common Use Cases
//!
//! - **Machine Translation**: Encoder-decoder for seq2seq tasks
//! - **Language Modeling**: Decoder-only for autoregressive generation
//! - **Text Classification**: Encoder with classification head
//! - **Question Answering**: Encoder with span prediction heads
//! - **Text Summarization**: Encoder-decoder with copy mechanism
//!
//! ## Best Practices
//!
//! 1. **Warmup Learning Rate**: Start with a small LR and gradually increase (see the sketch after this list)
//! 2. **Layer Normalization**: Pre-norm generally works better than post-norm
//! 3. **Residual Connections**: Essential for training deep networks
//! 4. **Attention Dropout**: Apply dropout to attention weights, not just outputs
//! 5. **Weight Initialization**: Use Xavier/Glorot initialization for stability
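//!
//! For item 1, a minimal sketch of the warmup schedule, assuming the inverse
//! square-root decay from the original Transformer paper (the function name
//! and constants are illustrative, not part of this module's API):
//!
//! ```rust
//! /// lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
//! fn noam_lr(d_model: usize, step: usize, warmup: usize) -> f64 {
//!     let step = step.max(1) as f64;
//!     let warmup = warmup as f64;
//!     (d_model as f64).powf(-0.5) * step.powf(-0.5).min(step * warmup.powf(-1.5))
//! }
//!
//! // The rate rises during warmup, then decays.
//! assert!(noam_lr(512, 100, 4000) < noam_lr(512, 4000, 4000));
//! assert!(noam_lr(512, 20000, 4000) < noam_lr(512, 4000, 4000));
//! ```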

use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use scirs2_core::random::Rng;
use std::collections::HashMap;

/// Configuration for transformer models
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Model dimension (embedding size)
    pub d_model: usize,
    /// Number of attention heads
    pub nheads: usize,
    /// Feed-forward network dimension
    pub d_ff: usize,
    /// Number of encoder layers
    pub n_encoder_layers: usize,
    /// Number of decoder layers
    pub n_decoder_layers: usize,
    /// Maximum sequence length
    pub max_seqlen: usize,
    /// Dropout rate
    pub dropout: f64,
    /// Vocabulary size
    pub vocab_size: usize,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        Self {
            d_model: 512,
            nheads: 8,
            d_ff: 2048,
            n_encoder_layers: 6,
            n_decoder_layers: 6,
            max_seqlen: 512,
            dropout: 0.1,
            vocab_size: 10000,
        }
    }
}

/// Position encoding for transformer models
pub struct PositionalEncoding {
    encodings: Array2<f64>,
    max_len: usize,
    #[allow(dead_code)]
    d_model: usize,
}

impl PositionalEncoding {
    /// Create new positional encoding
    pub fn new(max_len: usize, d_model: usize) -> Self {
        let mut encodings = Array2::<f64>::zeros((max_len, d_model));

        for pos in 0..max_len {
            for i in (0..d_model).step_by(2) {
                let angle = pos as f64 / (10000.0_f64).powf(i as f64 / d_model as f64);
                encodings[[pos, i]] = angle.sin();
                if i + 1 < d_model {
                    encodings[[pos, i + 1]] = angle.cos();
                }
            }
        }

        Self {
            encodings,
            max_len,
            d_model,
        }
    }

    /// Get position encoding for given sequence length
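    ///
    /// Returns an error if `seqlen` exceeds the maximum length the encoding
    /// table was built for:
    ///
    /// ```
    /// use scirs2_text::transformer::PositionalEncoding;
    ///
    /// let pe = PositionalEncoding::new(32, 8); // (max_len, d_model)
    /// assert_eq!(pe.get_encoding(10).unwrap().shape(), &[10, 8]);
    /// assert!(pe.get_encoding(64).is_err()); // exceeds max_len
    /// ```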
    pub fn get_encoding(&self, seqlen: usize) -> Result<ArrayView2<f64>> {
        if seqlen > self.max_len {
            return Err(TextError::InvalidInput(format!(
                "Sequence length {} exceeds maximum {}",
                seqlen, self.max_len
            )));
        }
        Ok(self.encodings.slice(s![0..seqlen, ..]))
    }
}

/// Multi-head attention mechanism
pub struct MultiHeadAttention {
    d_model: usize,
    nheads: usize,
    d_k: usize,
    w_q: Array2<f64>,
    w_k: Array2<f64>,
    w_v: Array2<f64>,
    w_o: Array2<f64>,
}

impl MultiHeadAttention {
    /// Create new multi-head attention layer
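    ///
    /// `d_model` must be divisible by `nheads`, since each head works on a
    /// `d_model / nheads` slice of the model dimension:
    ///
    /// ```
    /// use scirs2_text::transformer::MultiHeadAttention;
    ///
    /// assert!(MultiHeadAttention::new(512, 8).is_ok());
    /// assert!(MultiHeadAttention::new(512, 7).is_err()); // 512 % 7 != 0
    /// ```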
    pub fn new(d_model: usize, nheads: usize) -> Result<Self> {
        if !d_model.is_multiple_of(nheads) {
            return Err(TextError::InvalidInput(
                "d_model must be divisible by nheads".to_string(),
            ));
        }

        let d_k = d_model / nheads;

        // Initialize weight matrices with Xavier initialization
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Ok(Self {
            d_model,
            nheads,
            d_k,
            w_q,
            w_k,
            w_v,
            w_o,
        })
    }

    /// Compute scaled dot-product attention
    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        // Compute attention scores: Q * K^T / sqrt(d_k)
        let scores = q.dot(&k.t()) / d_k.sqrt();

        // Apply mask if provided
        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        // Apply softmax
        let attention_weights = self.softmax_2d(&masked_scores)?;

        // Apply attention to values
        Ok(attention_weights.dot(&v))
    }

    /// Apply softmax to 2D array along last axis
    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }

    /// Forward pass through multi-head attention
    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Linear projections
        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        // Reshape for multi-head attention
        let q_heads = self.reshape_for_heads(&q)?;
        let k_heads = self.reshape_for_heads(&k)?;
        let v_heads = self.reshape_for_heads(&v)?;

        // Apply attention for each head
        let mut head_outputs = Vec::new();
        for head in 0..self.nheads {
            let q_head = q_heads.slice(s![head, .., ..]);
            let k_head = k_heads.slice(s![head, .., ..]);
            let v_head = v_heads.slice(s![head, .., ..]);

            let head_output = self.scaled_dot_product_attention(q_head, k_head, v_head, mask)?;
            head_outputs.push(head_output);
        }

        // Concatenate heads
        let concatenated = self.concatenate_heads(&head_outputs)?;

        // Final linear projection
        Ok(concatenated.dot(&self.w_o))
    }

    /// Reshape tensor for multi-head attention
    fn reshape_for_heads(&self, x: &Array2<f64>) -> Result<Array3<f64>> {
        let (seqlen, _) = x.dim();
        let reshaped = x
            .clone()
            .into_shape_with_order((seqlen, self.nheads, self.d_k))
            .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;

        // Transpose to (nheads, seqlen, d_k)
        Ok(reshaped.permuted_axes([1, 0, 2]))
    }

    /// Concatenate attention heads
    fn concatenate_heads(&self, heads: &[Array2<f64>]) -> Result<Array2<f64>> {
        if heads.is_empty() {
            return Err(TextError::InvalidInput("No heads provided".to_string()));
        }

        let seqlen = heads[0].shape()[0];
        let mut result = Array2::zeros((seqlen, self.d_model));

        for (i, head) in heads.iter().enumerate() {
            let start_col = i * self.d_k;
            let end_col = start_col + self.d_k;
            result.slice_mut(s![.., start_col..end_col]).assign(head);
        }

        Ok(result)
    }

    /// Get attention weight matrices for serialization
    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array2<f64>, &Array2<f64>) {
        (&self.w_q, &self.w_k, &self.w_v, &self.w_o)
    }

    /// Set attention weight matrices from loaded weights
    pub fn set_weights(
        &mut self,
        w_q: Array2<f64>,
        w_k: Array2<f64>,
        w_v: Array2<f64>,
        w_o: Array2<f64>,
    ) -> Result<()> {
        if w_q.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_q shape".to_string()));
        }
        if w_k.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_k shape".to_string()));
        }
        if w_v.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_v shape".to_string()));
        }
        if w_o.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_o shape".to_string()));
        }

        self.w_q = w_q;
        self.w_k = w_k;
        self.w_v = w_v;
        self.w_o = w_o;
        Ok(())
    }
}

/// Feed-forward network layer
pub struct FeedForward {
    w1: Array2<f64>,
    w2: Array2<f64>,
    b1: Array1<f64>,
    b2: Array1<f64>,
}

impl FeedForward {
    /// Create new feed-forward layer
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        let scale = (2.0 / d_model as f64).sqrt();

        let w1 = Array2::from_shape_fn((d_model, d_ff), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w2 = Array2::from_shape_fn((d_ff, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let b1 = Array1::zeros(d_ff);
        let b2 = Array1::zeros(d_model);

        Self { w1, w2, b1, b2 }
    }

    /// Forward pass through feed-forward network
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        // First linear transformation + ReLU
        let hidden = x.dot(&self.w1) + &self.b1;
        let activated = hidden.mapv(|x| x.max(0.0)); // ReLU activation

        // Second linear transformation
        activated.dot(&self.w2) + &self.b2
    }

    /// Get feed-forward weight matrices for serialization
    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array1<f64>, &Array1<f64>) {
        (&self.w1, &self.w2, &self.b1, &self.b2)
    }

    /// Set feed-forward weight matrices from loaded weights
    pub fn set_weights(
        &mut self,
        w1: Array2<f64>,
        w2: Array2<f64>,
        b1: Array1<f64>,
        b2: Array1<f64>,
    ) -> Result<()> {
        if w1.shape()[1] != w2.shape()[0] {
            return Err(TextError::InvalidInput(
                "Weight matrix dimensions don't match".to_string(),
            ));
        }
        if b1.len() != w1.shape()[1] {
            return Err(TextError::InvalidInput(
                "Bias b1 size doesn't match w1".to_string(),
            ));
        }
        if b2.len() != w2.shape()[1] {
            return Err(TextError::InvalidInput(
                "Bias b2 size doesn't match w2".to_string(),
            ));
        }

        self.w1 = w1;
        self.w2 = w2;
        self.b1 = b1;
        self.b2 = b2;
        Ok(())
    }
}

/// Layer normalization
pub struct LayerNorm {
    gamma: Array1<f64>,
    beta: Array1<f64>,
    eps: f64,
}

impl LayerNorm {
    /// Create new layer normalization
    pub fn new(d_model: usize, eps: f64) -> Self {
        Self {
            gamma: Array1::ones(d_model),
            beta: Array1::zeros(d_model),
            eps,
        }
    }

    /// Apply layer normalization
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        let mut result = Array2::zeros(x.raw_dim());

        for (i, row) in x.rows().into_iter().enumerate() {
            // Per-row mean and (biased) variance
            let mean = row.sum() / row.len() as f64;
            let var = row.mapv(|v| (v - mean).powi(2)).sum() / row.len() as f64;
            let std = (var + self.eps).sqrt();

            let normalized = row.mapv(|v| (v - mean) / std);
            let scaled = &normalized * &self.gamma + &self.beta;

            result.row_mut(i).assign(&scaled);
        }

        result
    }

    /// Get layer normalization parameters for serialization
    pub fn get_params(&self) -> (&Array1<f64>, &Array1<f64>) {
        (&self.gamma, &self.beta)
    }

    /// Set layer normalization parameters from loaded weights
    pub fn set_params(&mut self, gamma: Array1<f64>, beta: Array1<f64>) -> Result<()> {
        if gamma.len() != beta.len() {
            return Err(TextError::InvalidInput(
                "Gamma and beta must have same length".to_string(),
            ));
        }
        if gamma.len() != self.gamma.len() {
            return Err(TextError::InvalidInput(
                "Parameter size doesn't match layer dimension".to_string(),
            ));
        }

        self.gamma = gamma;
        self.beta = beta;
        Ok(())
    }
}

/// Transformer encoder layer
pub struct TransformerEncoderLayer {
    self_attention: MultiHeadAttention,
    feed_forward: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    #[allow(dead_code)]
    dropout: f64,
}

impl TransformerEncoderLayer {
    /// Create new transformer encoder layer
    pub fn new(config: &TransformerConfig) -> Result<Self> {
        Ok(Self {
            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            feed_forward: FeedForward::new(config.d_model, config.d_ff),
            norm1: LayerNorm::new(config.d_model, 1e-6),
            norm2: LayerNorm::new(config.d_model, 1e-6),
            dropout: config.dropout,
        })
    }

    /// Forward pass through encoder layer
    pub fn forward(
        &self,
        x: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Self-attention, then residual connection and layer norm:
        // x = LayerNorm(x + SelfAttention(x))
        let attn_output = self.self_attention.forward(x, x, x, mask)?;
        let x = self.norm1.forward((x.to_owned() + attn_output).view());

        // Feed-forward, then residual connection and layer norm:
        // x = LayerNorm(x + FeedForward(x))
        let ff_output = self.feed_forward.forward(x.view());
        let output = self.norm2.forward((x + ff_output).view());

        Ok(output)
    }

    /// Get mutable access to layer components for weight loading
    pub fn get_components_mut(
        &mut self,
    ) -> (
        &mut MultiHeadAttention,
        &mut FeedForward,
        &mut LayerNorm,
        &mut LayerNorm,
    ) {
        (
            &mut self.self_attention,
            &mut self.feed_forward,
            &mut self.norm1,
            &mut self.norm2,
        )
    }

    /// Get access to layer components for weight access
    pub fn get_components(&self) -> (&MultiHeadAttention, &FeedForward, &LayerNorm, &LayerNorm) {
        (
            &self.self_attention,
            &self.feed_forward,
            &self.norm1,
            &self.norm2,
        )
    }
}

/// Complete transformer encoder
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    position_encoding: PositionalEncoding,
    config: TransformerConfig,
}

impl TransformerEncoder {
    /// Create new transformer encoder
    pub fn new(config: TransformerConfig) -> Result<Self> {
        let mut layers = Vec::new();
        for _ in 0..config.n_encoder_layers {
            layers.push(TransformerEncoderLayer::new(&config)?);
        }

        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);

        Ok(Self {
            layers,
            position_encoding,
            config,
        })
    }

    /// Encode input sequence
    pub fn encode(
        &self,
        embeddings: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seqlen = embeddings.shape()[0];

        // Add positional encoding
        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
        let mut x = embeddings.to_owned() + pos_enc;

        // Pass through encoder layers
        for layer in &self.layers {
            x = layer.forward(x.view(), mask)?;
        }

        Ok(x)
    }

    /// Get configuration
    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }

    /// Get mutable access to encoder layers for weight loading
    pub fn get_layers_mut(&mut self) -> &mut Vec<TransformerEncoderLayer> {
        &mut self.layers
    }

    /// Get access to encoder layers for weight access
    pub fn get_layers(&self) -> &Vec<TransformerEncoderLayer> {
        &self.layers
    }
}

/// Transformer decoder layer with self-attention, cross-attention, and feed-forward
pub struct TransformerDecoderLayer {
    self_attention: MultiHeadAttention,
    cross_attention: MultiHeadAttention,
    feed_forward: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    #[allow(dead_code)]
    dropout: f64,
}

impl TransformerDecoderLayer {
    /// Create new decoder layer
    pub fn new(config: &TransformerConfig) -> Result<Self> {
        Ok(Self {
            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            cross_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            feed_forward: FeedForward::new(config.d_model, config.d_ff),
            norm1: LayerNorm::new(config.d_model, 1e-6),
            norm2: LayerNorm::new(config.d_model, 1e-6),
            norm3: LayerNorm::new(config.d_model, 1e-6),
            dropout: config.dropout,
        })
    }

    /// Forward pass with encoder output for cross-attention
    pub fn forward(
        &self,
        x: ArrayView2<f64>,
        encoder_output: ArrayView2<f64>,
        self_attn_mask: Option<ArrayView2<bool>>,
        cross_attn_mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Self-attention with residual connection and layer norm
        let self_attn_out = self.self_attention.forward(x, x, x, self_attn_mask)?;
        let x = self.norm1.forward((x.to_owned() + self_attn_out).view());

        // Cross-attention with encoder output
        let cross_attn_out = self.cross_attention.forward(
            x.view(),
            encoder_output,
            encoder_output,
            cross_attn_mask,
        )?;
        let x = self.norm2.forward((x + cross_attn_out).view());

        // Feed-forward with residual connection and layer norm
        let ff_out = self.feed_forward.forward(x.view());
        let output = self.norm3.forward((x + ff_out).view());

        Ok(output)
    }
}

/// Transformer decoder stack
pub struct TransformerDecoder {
    layers: Vec<TransformerDecoderLayer>,
    position_encoding: PositionalEncoding,
    config: TransformerConfig,
}

impl TransformerDecoder {
    /// Create new decoder
    pub fn new(config: TransformerConfig) -> Result<Self> {
        let mut layers = Vec::new();
        for _ in 0..config.n_decoder_layers {
            layers.push(TransformerDecoderLayer::new(&config)?);
        }

        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);

        Ok(Self {
            layers,
            position_encoding,
            config,
        })
    }

    /// Forward pass through decoder
    pub fn forward(
        &self,
        embeddings: ArrayView2<f64>,
        encoder_output: ArrayView2<f64>,
        self_attn_mask: Option<ArrayView2<bool>>,
        cross_attn_mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seqlen = embeddings.shape()[0];

        // Add positional encoding
        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
        let mut x = embeddings.to_owned() + pos_enc;

        // Pass through decoder layers
        for layer in &self.layers {
            x = layer.forward(x.view(), encoder_output, self_attn_mask, cross_attn_mask)?;
        }

        Ok(x)
    }

    /// Get configuration
    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }
}

/// Token embedding layer
pub struct TokenEmbedding {
    embeddings: Array2<f64>,
    vocab_size: usize,
    d_model: usize,
}

impl TokenEmbedding {
    /// Create new token embedding layer
    pub fn new(vocab_size: usize, d_model: usize) -> Self {
        let scale = (1.0 / d_model as f64).sqrt();
        let embeddings = Array2::from_shape_fn((vocab_size, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            embeddings,
            vocab_size,
            d_model,
        }
    }

    /// Get embeddings for token IDs
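    ///
    /// Looks up one embedding row per ID; IDs at or beyond `vocab_size` are
    /// rejected:
    ///
    /// ```
    /// use scirs2_text::transformer::TokenEmbedding;
    ///
    /// let emb = TokenEmbedding::new(100, 16); // (vocab_size, d_model)
    /// let out = emb.forward(&[0, 5, 99]).unwrap();
    /// assert_eq!(out.shape(), &[3, 16]);
    /// assert!(emb.forward(&[100]).is_err()); // out of vocabulary
    /// ```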
    pub fn forward(&self, tokenids: &[usize]) -> Result<Array2<f64>> {
        let mut result = Array2::zeros((tokenids.len(), self.d_model));

        for (i, &token_id) in tokenids.iter().enumerate() {
            if token_id >= self.vocab_size {
                return Err(TextError::InvalidInput(format!(
                    "Token ID {} exceeds vocabulary size {}",
                    token_id, self.vocab_size
                )));
            }
            result.row_mut(i).assign(&self.embeddings.row(token_id));
        }

        Ok(result)
    }

    /// Get access to the embedding matrix for serialization
    pub fn get_embeddings(&self) -> &Array2<f64> {
        &self.embeddings
    }

    /// Set the embedding matrix from loaded weights
    pub fn set_embeddings(&mut self, embeddings: Array2<f64>) -> Result<()> {
        if embeddings.shape()[0] != self.vocab_size || embeddings.shape()[1] != self.d_model {
            return Err(TextError::InvalidInput(format!(
                "Embedding shape {:?} doesn't match expected ({}, {})",
                embeddings.shape(),
                self.vocab_size,
                self.d_model
            )));
        }
        self.embeddings = embeddings;
        Ok(())
    }
}

/// Complete transformer model for text processing
pub struct TransformerModel {
    /// Model configuration
    pub config: TransformerConfig,
    /// Token embedding layer
    pub token_embedding: TokenEmbedding,
    /// Transformer encoder
    pub encoder: TransformerEncoder,
    /// Optional transformer decoder
    pub decoder: Option<TransformerDecoder>,
    vocab_to_id: HashMap<String, usize>,
    id_to_vocab: HashMap<usize, String>,
}

impl TransformerModel {
    /// Create new transformer model
    pub fn new(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
        let vocab_size = vocabulary.len();
        if vocab_size != config.vocab_size {
            return Err(TextError::InvalidInput(format!(
                "Vocabulary size {} doesn't match config {}",
                vocab_size, config.vocab_size
            )));
        }

        let mut vocab_to_id = HashMap::new();
        let mut id_to_vocab = HashMap::new();

        for (id, token) in vocabulary.into_iter().enumerate() {
            vocab_to_id.insert(token.clone(), id);
            id_to_vocab.insert(id, token);
        }

        Ok(Self {
            config: config.clone(),
            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
            encoder: TransformerEncoder::new(config)?,
            decoder: None, // Encoder-only model
            vocab_to_id,
            id_to_vocab,
        })
    }

    /// Encode text tokens to contextual embeddings
    pub fn encode_tokens(&self, tokens: &[String]) -> Result<Array2<f64>> {
        // Convert tokens to IDs
        let tokenids: Result<Vec<usize>> = tokens
            .iter()
            .map(|token| {
                self.vocab_to_id
                    .get(token)
                    .copied()
                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
            })
            .collect();
        let tokenids = tokenids?;

        // Get token embeddings
        let embeddings = self.token_embedding.forward(&tokenids)?;

        // Encode with transformer
        self.encoder.encode(embeddings.view(), None)
    }

    /// Create new encoder-decoder transformer model
    pub fn new_encoder_decoder(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
        let vocab_size = vocabulary.len();
        if vocab_size != config.vocab_size {
            return Err(TextError::InvalidInput(format!(
                "Vocabulary size {} doesn't match config {}",
                vocab_size, config.vocab_size
            )));
        }

        let mut vocab_to_id = HashMap::new();
        let mut id_to_vocab = HashMap::new();

        for (id, token) in vocabulary.into_iter().enumerate() {
            vocab_to_id.insert(token.clone(), id);
            id_to_vocab.insert(id, token);
        }

        Ok(Self {
            config: config.clone(),
            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
            encoder: TransformerEncoder::new(config.clone())?,
            decoder: Some(TransformerDecoder::new(config)?),
            vocab_to_id,
            id_to_vocab,
        })
    }

    /// Perform encoder-decoder forward pass
    pub fn encode_decode(
        &self,
        input_tokens: &[String],
        target_tokens: &[String],
    ) -> Result<Array2<f64>> {
        let decoder = self
            .decoder
            .as_ref()
            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;

        // Encode input
        let encoder_output = self.encode_tokens(input_tokens)?;

        // Convert target tokens to IDs and embeddings
        let target_ids: Result<Vec<usize>> = target_tokens
            .iter()
            .map(|token| {
                self.vocab_to_id
                    .get(token)
                    .copied()
                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
            })
            .collect();
        let target_ids = target_ids?;

        let target_embeddings = self.token_embedding.forward(&target_ids)?;

        // Generate causal mask for decoder self-attention
        let seqlen = target_tokens.len();
        let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
        for i in 0..seqlen {
            for j in (i + 1)..seqlen {
                causal_mask[[i, j]] = true; // Mask future positions
            }
        }

        // Decode
        decoder.forward(
            target_embeddings.view(),
            encoder_output.view(),
            Some(causal_mask.view()),
            None,
        )
    }

    /// Generate text using the decoder (for generation tasks)
    pub fn generate(
        &self,
        input_tokens: &[String],
        max_length: usize,
        start_token: &str,
    ) -> Result<Vec<String>> {
        let decoder = self
            .decoder
            .as_ref()
            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;

        // Encode input
        let encoder_output = self.encode_tokens(input_tokens)?;

        // Start with the start token
        let mut generated_tokens = vec![start_token.to_string()];

        for _ in 0..max_length {
            // Convert current tokens to embeddings
            let current_ids: Result<Vec<usize>> = generated_tokens
                .iter()
                .map(|token| {
                    self.vocab_to_id
                        .get(token)
                        .copied()
                        .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
                })
                .collect();
            let current_ids = current_ids?;

            let current_embeddings = self.token_embedding.forward(&current_ids)?;

            // Generate causal mask
            let seqlen = generated_tokens.len();
            let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
            for i in 0..seqlen {
                for j in (i + 1)..seqlen {
                    causal_mask[[i, j]] = true;
                }
            }

            // Decode
            let decoder_output = decoder.forward(
                current_embeddings.view(),
                encoder_output.view(),
                Some(causal_mask.view()),
                None,
            )?;

            // Get the last timestep output
            let last_output = decoder_output.row(decoder_output.nrows() - 1);

            // Simple greedy selection: pick the index with the highest value.
            // Note this is a simplification; the decoder output has d_model
            // dimensions, and a full model would first project it to
            // vocabulary-sized logits.
            let mut best_token_id = 0;
            let mut best_score = last_output[0];
            for (i, &score) in last_output.iter().enumerate() {
                if score > best_score {
                    best_score = score;
                    best_token_id = i;
                }
            }

            // Convert token ID back to string
            if let Some(token) = self.id_to_vocab.get(&best_token_id) {
                generated_tokens.push(token.clone());

                // Stop if we hit an end token (customize this for your vocabulary)
                if token == "</s>" || token == "<eos>" {
                    break;
                }
            } else {
                break;
            }
        }

        Ok(generated_tokens)
    }

    /// Get vocabulary mapping
    pub fn vocabulary(&self) -> (&HashMap<String, usize>, &HashMap<usize, String>) {
        (&self.vocab_to_id, &self.id_to_vocab)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(10, 4);
        let encoding = pos_enc.get_encoding(5).unwrap();
        assert_eq!(encoding.shape(), &[5, 4]);

        // Test that positions are different
        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        assert!(pos0
            .iter()
            .zip(pos1.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6));
    }

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 2).unwrap();
        let seqlen = 4;
        let d_model = 8;

        let input = Array2::ones((seqlen, d_model));
        let output = mha
            .forward(input.view(), input.view(), input.view(), None)
            .unwrap();

        assert_eq!(output.shape(), &[seqlen, d_model]);
    }

    #[test]
    fn test_transformer_encoder() {
        let config = TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 2,
            ..Default::default()
        };

        let encoder = TransformerEncoder::new(config).unwrap();
        let input = Array2::ones((4, 8));
        let output = encoder.encode(input.view(), None).unwrap();

        assert_eq!(output.shape(), &[4, 8]);
    }
}