scirs2_text/neural_nlp/transformer_encoder.rs

//! Pure-ndarray multi-head self-attention transformer encoder for text.
//!
//! Implements a minimal BERT-style encoder: embedding table + sinusoidal
//! position encoding + N × (pre-norm MHA + pre-norm FFN) layers using `f32`.
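//!
//! # Example
//!
//! A minimal usage sketch (the import path assumes this module is exported as
//! `scirs2_text::neural_nlp::transformer_encoder`; adjust to the crate's
//! actual re-exports):
//!
//! ```ignore
//! use scirs2_text::neural_nlp::transformer_encoder::{
//!     TransformerEncoderConfig, TransformerTextEncoder,
//! };
//!
//! let encoder = TransformerTextEncoder::new(TransformerEncoderConfig::default()).unwrap();
//! // Token IDs must be < vocab_size; sequence length must be <= max_seq_len.
//! let contextual = encoder.encode_tokens(&[101, 2009, 2003, 102]).unwrap();
//! assert_eq!(contextual.dim(), (4, 256)); // [seq_len, hidden_size]
//! let sentence = encoder.encode_sentence(&[101, 2009, 2003, 102]).unwrap();
//! assert_eq!(sentence.len(), 256);
//! ```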

use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};

// ─── Configuration ────────────────────────────────────────────────────────────

/// Configuration for [`TransformerTextEncoder`].
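///
/// # Example
///
/// A sketch of overriding a few fields (the type path is an assumption; see
/// the module docs):
///
/// ```ignore
/// let config = TransformerEncoderConfig {
///     hidden_size: 128,
///     num_heads: 8, // hidden_size must be divisible by num_heads
///     ..Default::default()
/// };
/// ```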
#[derive(Debug, Clone)]
pub struct TransformerEncoderConfig {
    /// Vocabulary size (number of distinct token IDs).
    pub vocab_size: usize,
    /// Dimensionality of token + position embeddings.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of encoder layers.
    pub num_layers: usize,
    /// Maximum sequence length supported.
    pub max_seq_len: usize,
    /// Dropout probability (reserved for training; this inference-only
    /// implementation never applies it).
    pub dropout: f32,
    /// PRNG seed for weight initialisation.
    pub seed: u64,
}

impl Default for TransformerEncoderConfig {
    fn default() -> Self {
        Self {
            vocab_size: 30000,
            hidden_size: 256,
            num_heads: 4,
            num_layers: 2,
            max_seq_len: 512,
            dropout: 0.1,
            seed: 42,
        }
    }
}

// ─── Attention Layer ──────────────────────────────────────────────────────────

/// Single multi-head self-attention layer (f32).
struct MhsaLayer {
    /// Q projection: [hidden, hidden]
    w_q: Array2<f32>,
    /// K projection: [hidden, hidden]
    w_k: Array2<f32>,
    /// V projection: [hidden, hidden]
    w_v: Array2<f32>,
    /// Output projection: [hidden, hidden]
    w_o: Array2<f32>,
    /// Pre-attention LayerNorm scale
    ln1_scale: Array1<f32>,
    /// Pre-attention LayerNorm bias
    ln1_bias: Array1<f32>,
    /// Number of attention heads
    n_heads: usize,
    /// Per-head dimension (hidden / n_heads)
    d_k: usize,
}

/// Feed-forward sub-layer (two linear + GELU).
struct FfnLayer {
    /// W1: [hidden, 4*hidden]
    w1: Array2<f32>,
    b1: Array1<f32>,
    /// W2: [4*hidden, hidden]
    w2: Array2<f32>,
    b2: Array1<f32>,
    /// Pre-FFN LayerNorm scale
    ln2_scale: Array1<f32>,
    /// Pre-FFN LayerNorm bias
    ln2_bias: Array1<f32>,
}

// ─── LCG-based weight initialisation ─────────────────────────────────────────

/// One LCG step (Knuth's MMIX constants); returns a pseudo-uniform f32 in [-1, 1).
fn next_lcg(seed: &mut u64) -> f32 {
    *seed = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // Top 31 bits of the state, normalised by 2^31 into [0, 1).
    let bits = (*seed >> 33) as f32 / (1u64 << 31) as f32;
    (bits - 0.5) * 2.0 // uniform in [-1, 1)
}

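/// Glorot/Xavier uniform initialisation: entries drawn from U(-b, b)
/// with bound b = √(6 / (fan_in + fan_out)).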
fn xavier_init(rows: usize, cols: usize, seed: &mut u64) -> Array2<f32> {
    let scale = (6.0_f32 / (rows + cols) as f32).sqrt();
    Array2::from_shape_fn((rows, cols), |_| next_lcg(seed) * scale)
}

fn zeros1(n: usize) -> Array1<f32> {
    Array1::zeros(n)
}

fn ones1(n: usize) -> Array1<f32> {
    Array1::ones(n)
}

// ─── Math helpers ─────────────────────────────────────────────────────────────

/// Row-wise softmax in place (max-subtraction for numerical stability).
fn softmax_rows(x: &mut Array2<f32>) {
    let (rows, cols) = x.dim();
    for i in 0..rows {
        // Subtract the row max before exponentiating to avoid overflow.
        let max_val = x.row(i).fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mut sum = 0.0_f32;
        for j in 0..cols {
            x[[i, j]] = (x[[i, j]] - max_val).exp();
            sum += x[[i, j]];
        }
        if sum > 0.0 {
            for j in 0..cols {
                x[[i, j]] /= sum;
            }
        }
    }
}

/// GELU approximation: 0.5x(1+tanh(√(2/π)(x+0.044715x³))).
#[inline]
fn gelu(x: f32) -> f32 {
    let inner = (2.0_f32 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}

/// Layer normalisation: (x - μ) / √(σ² + ε) * scale + bias, computed per row.
fn layer_norm(x: &Array2<f32>, scale: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
    let eps = 1e-5_f32;
    let (seq, hidden) = x.dim();
    let mut out = Array2::zeros((seq, hidden));
    for i in 0..seq {
        let row = x.row(i);
        let mean = row.sum() / hidden as f32;
        let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / hidden as f32;
        let inv_std = 1.0 / (var + eps).sqrt();
        for j in 0..hidden {
            out[[i, j]] = (x[[i, j]] - mean) * inv_std * scale[j] + bias[j];
        }
    }
    out
}

// ─── MhsaLayer impl ──────────────────────────────────────────────────────────

impl MhsaLayer {
    fn new(hidden: usize, n_heads: usize, seed: &mut u64) -> Result<Self> {
        if !hidden.is_multiple_of(n_heads) {
            return Err(TextError::InvalidInput(format!(
                "hidden_size {hidden} must be divisible by num_heads {n_heads}"
            )));
        }
        let d_k = hidden / n_heads;
        Ok(Self {
            w_q: xavier_init(hidden, hidden, seed),
            w_k: xavier_init(hidden, hidden, seed),
            w_v: xavier_init(hidden, hidden, seed),
            w_o: xavier_init(hidden, hidden, seed),
            ln1_scale: ones1(hidden),
            ln1_bias: zeros1(hidden),
            n_heads,
            d_k,
        })
    }

    /// Forward pass; returns (output [seq, hidden], attention averaged over heads [seq, seq]).
    fn forward_with_attn(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>)> {
        let (seq, hidden) = x.dim();

        // Pre-norm
        let xn = layer_norm(x, &self.ln1_scale, &self.ln1_bias);

        // Q, K, V projections: [seq, hidden]
        let q = xn.dot(&self.w_q);
        let k = xn.dot(&self.w_k);
        let v = xn.dot(&self.w_v);

        let scale = (self.d_k as f32).sqrt();

        // Compute attention per head, accumulate output
        let mut out = Array2::zeros((seq, hidden));
        // Averaged attention weights [seq, seq]
        let mut avg_attn = Array2::zeros((seq, seq));

        for h in 0..self.n_heads {
            let start = h * self.d_k;
            let end = start + self.d_k;

            let q_h = q.slice(s![.., start..end]).to_owned(); // [seq, d_k]
            let k_h = k.slice(s![.., start..end]).to_owned(); // [seq, d_k]
            let v_h = v.slice(s![.., start..end]).to_owned(); // [seq, d_k]

            // Attention scores [seq, seq]
            let mut scores = q_h.dot(&k_h.t()) / scale;
            softmax_rows(&mut scores);

            // Accumulate this head's weights into the running average
            avg_attn += &scores;

            // Context [seq, d_k]
            let ctx = scores.dot(&v_h);
            out.slice_mut(s![.., start..end]).assign(&ctx);
        }

        // Average across heads
        let n_heads_f = self.n_heads as f32;
        avg_attn.mapv_inplace(|v| v / n_heads_f);

        // Output projection + residual
        let proj = out.dot(&self.w_o);
        let result = x + &proj;

        Ok((result, avg_attn))
    }

    /// Forward pass retaining every head's map; returns (output [seq, hidden], attention [n_heads, seq, seq]).
    fn forward_all_heads(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array3<f32>)> {
        let (seq, hidden) = x.dim();

        let xn = layer_norm(x, &self.ln1_scale, &self.ln1_bias);

        let q = xn.dot(&self.w_q);
        let k = xn.dot(&self.w_k);
        let v = xn.dot(&self.w_v);

        let scale = (self.d_k as f32).sqrt();

        let mut out = Array2::zeros((seq, hidden));
        let mut all_attn = Array3::zeros((self.n_heads, seq, seq));

        for h in 0..self.n_heads {
            let start = h * self.d_k;
            let end = start + self.d_k;

            let q_h = q.slice(s![.., start..end]).to_owned();
            let k_h = k.slice(s![.., start..end]).to_owned();
            let v_h = v.slice(s![.., start..end]).to_owned();

            let mut scores = q_h.dot(&k_h.t()) / scale;
            softmax_rows(&mut scores);

            all_attn.slice_mut(s![h, .., ..]).assign(&scores);

            let ctx = scores.dot(&v_h);
            out.slice_mut(s![.., start..end]).assign(&ctx);
        }

        let proj = out.dot(&self.w_o);
        let result = x + &proj;

        Ok((result, all_attn))
    }
}

// ─── FfnLayer impl ───────────────────────────────────────────────────────────

impl FfnLayer {
    fn new(hidden: usize, seed: &mut u64) -> Self {
        let ffn_dim = 4 * hidden;
        Self {
            w1: xavier_init(hidden, ffn_dim, seed),
            b1: zeros1(ffn_dim),
            w2: xavier_init(ffn_dim, hidden, seed),
            b2: zeros1(hidden),
            ln2_scale: ones1(hidden),
            ln2_bias: zeros1(hidden),
        }
    }

    fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        // Pre-norm
        let xn = layer_norm(x, &self.ln2_scale, &self.ln2_bias);

        // W1 + bias + GELU
        let h1 = xn.dot(&self.w1) + &self.b1;
        let h1 = h1.mapv(gelu);

        // W2 + bias + residual
        let h2 = h1.dot(&self.w2) + &self.b2;
        x + &h2
    }
}

// ─── Sinusoidal positional encoding ──────────────────────────────────────────

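/// Standard sinusoidal position encoding (Vaswani et al., 2017):
/// PE(pos, 2k) = sin(pos / 10000^(2k/d)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d)).
/// Since `i` below steps over even indices, `i / hidden` equals 2k/d.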
fn sinusoidal_pe(max_seq: usize, hidden: usize) -> Array2<f32> {
    let mut pe = Array2::zeros((max_seq, hidden));
    for pos in 0..max_seq {
        for i in (0..hidden).step_by(2) {
            let angle = pos as f32 / 10000.0_f32.powf(i as f32 / hidden as f32);
            pe[[pos, i]] = angle.sin();
            if i + 1 < hidden {
                pe[[pos, i + 1]] = angle.cos();
            }
        }
    }
    pe
}

// ─── TransformerTextEncoder ───────────────────────────────────────────────────

/// Transformer-based text encoder that maps token-ID sequences to contextual embeddings.
///
/// Uses pure-ndarray f32 multi-head self-attention with sinusoidal position encoding.
pub struct TransformerTextEncoder {
    config: TransformerEncoderConfig,
    /// Token embedding table [vocab_size, hidden]
    embedding: Array2<f32>,
    /// Positional encoding table [max_seq_len, hidden]
    position_enc: Array2<f32>,
    /// Attention sub-layers (one per encoder layer)
    attn_layers: Vec<MhsaLayer>,
    /// FFN sub-layers (one per encoder layer)
    ffn_layers: Vec<FfnLayer>,
}

impl TransformerTextEncoder {
    /// Create a new encoder from the given config.
    pub fn new(config: TransformerEncoderConfig) -> Result<Self> {
        let mut seed = config.seed;

        let scale = (config.hidden_size as f32).sqrt();
        let embedding = Array2::from_shape_fn((config.vocab_size, config.hidden_size), |_| {
            next_lcg(&mut seed) / scale
        });

        let position_enc = sinusoidal_pe(config.max_seq_len, config.hidden_size);

        let mut attn_layers = Vec::with_capacity(config.num_layers);
        let mut ffn_layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            attn_layers.push(MhsaLayer::new(
                config.hidden_size,
                config.num_heads,
                &mut seed,
            )?);
            ffn_layers.push(FfnLayer::new(config.hidden_size, &mut seed));
        }

        Ok(Self {
            config,
            embedding,
            position_enc,
            attn_layers,
            ffn_layers,
        })
    }

    /// Look up embeddings + add positional encoding for the given token IDs.
    fn embed_tokens(&self, tokens: &[usize]) -> Result<Array2<f32>> {
        let seq = tokens.len();
        if seq == 0 {
            return Err(TextError::InvalidInput("Empty token sequence".to_string()));
        }
        if seq > self.config.max_seq_len {
            return Err(TextError::InvalidInput(format!(
                "Sequence length {seq} exceeds max_seq_len {}",
                self.config.max_seq_len
            )));
        }

        let hidden = self.config.hidden_size;
        let mut x = Array2::zeros((seq, hidden));
        for (i, &tok) in tokens.iter().enumerate() {
            if tok >= self.config.vocab_size {
                return Err(TextError::InvalidInput(format!(
                    "Token ID {tok} out of vocab range {}",
                    self.config.vocab_size
                )));
            }
            let emb_row = self.embedding.row(tok);
            let pe_row = self.position_enc.row(i);
            for j in 0..hidden {
                x[[i, j]] = emb_row[j] + pe_row[j];
            }
        }
        Ok(x)
    }

    /// Encode token IDs to contextual embeddings `[seq_len, hidden_size]`.
    pub fn encode_tokens(&self, tokens: &[usize]) -> Result<Array2<f32>> {
        let mut x = self.embed_tokens(tokens)?;
        for (attn, ffn) in self.attn_layers.iter().zip(self.ffn_layers.iter()) {
            let (out, _) = attn.forward_with_attn(&x)?;
            x = ffn.forward(&out);
        }
        Ok(x)
    }

    /// Pool contextual embeddings to a single sentence embedding `[hidden_size]`.
    /// Uses mean pooling across all token positions.
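    ///
    /// # Example
    ///
    /// A sketch comparing two pooled embeddings (token IDs and the `encoder`
    /// binding are illustrative):
    ///
    /// ```ignore
    /// let a = encoder.encode_sentence(&[12, 7, 431])?;
    /// let b = encoder.encode_sentence(&[12, 7, 98])?;
    /// // Cosine similarity between the two mean-pooled vectors.
    /// let cos = a.dot(&b) / (a.dot(&a).sqrt() * b.dot(&b).sqrt());
    /// ```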
    pub fn encode_sentence(&self, tokens: &[usize]) -> Result<Array1<f32>> {
        let ctx = self.encode_tokens(tokens)?;
        ctx.mean_axis(Axis(0))
            .ok_or_else(|| TextError::InvalidInput("Cannot mean-pool empty context".to_string()))
    }

    /// Encode tokens and expose per-layer per-head attention weights.
    ///
    /// Returns `(embeddings [seq, hidden], attention_weights)` where
    /// `attention_weights[layer]` has shape `[n_heads, seq, seq]`.
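    ///
    /// # Example
    ///
    /// A sketch inspecting the maps (illustrative token IDs; each row of an
    /// attention map sums to ~1 after softmax):
    ///
    /// ```ignore
    /// let (emb, attn) = encoder.encode_with_attention(&[5, 9, 2])?;
    /// assert_eq!(attn.len(), encoder.config().num_layers);
    /// assert_eq!(attn[0].dim(), (encoder.config().num_heads, 3, 3));
    /// ```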
    pub fn encode_with_attention(
        &self,
        tokens: &[usize],
    ) -> Result<(Array2<f32>, Vec<Array3<f32>>)> {
        let mut x = self.embed_tokens(tokens)?;
        let mut all_attn = Vec::with_capacity(self.config.num_layers);

        for (attn, ffn) in self.attn_layers.iter().zip(self.ffn_layers.iter()) {
            let (out, layer_attn) = attn.forward_all_heads(&x)?;
            x = ffn.forward(&out);
            all_attn.push(layer_attn);
        }

        Ok((x, all_attn))
    }

    /// Access the encoder configuration.
    pub fn config(&self) -> &TransformerEncoderConfig {
        &self.config
    }

    /// Access the embedding table (read-only).
    pub fn embedding(&self) -> &Array2<f32> {
        &self.embedding
    }

    /// Mutably access the embedding table (for fine-tuning).
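    ///
    /// # Example
    ///
    /// A sketch of a single SGD-style update on one embedding row (the
    /// gradient is a stand-in; this module provides no autograd):
    ///
    /// ```ignore
    /// let lr = 0.01_f32;
    /// let grad = Array1::<f32>::ones(encoder.config().hidden_size); // placeholder gradient
    /// let mut emb = encoder.embedding_mut().row_mut(42);
    /// emb.zip_mut_with(&grad, |w, &g| *w -= lr * g);
    /// ```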
    pub fn embedding_mut(&mut self) -> &mut Array2<f32> {
        &mut self.embedding
    }
}
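
// ─── Tests ────────────────────────────────────────────────────────────────────
//
// A few sanity checks sketched against the helpers above (added as
// illustrative tests; not part of the original file's test suite).

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_rows_sum_to_one() {
        let mut x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, 0.0, 1.0]).unwrap();
        softmax_rows(&mut x);
        for row in x.rows() {
            assert!((row.sum() - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn layer_norm_centres_each_row() {
        let x = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let out = layer_norm(&x, &ones1(4), &zeros1(4));
        let mean = out.row(0).sum() / 4.0;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn encode_tokens_shape_matches_config() {
        let config = TransformerEncoderConfig {
            vocab_size: 100,
            hidden_size: 32,
            num_heads: 4,
            num_layers: 2,
            max_seq_len: 16,
            ..Default::default()
        };
        let enc = TransformerTextEncoder::new(config).unwrap();
        let out = enc.encode_tokens(&[1, 2, 3]).unwrap();
        assert_eq!(out.dim(), (3, 32));
    }

    #[test]
    fn attention_maps_have_expected_shape() {
        let enc = TransformerTextEncoder::new(TransformerEncoderConfig {
            vocab_size: 50,
            hidden_size: 16,
            num_heads: 2,
            num_layers: 1,
            max_seq_len: 8,
            ..Default::default()
        })
        .unwrap();
        let (_, attn) = enc.encode_with_attention(&[4, 5]).unwrap();
        assert_eq!(attn.len(), 1);
        assert_eq!(attn[0].dim(), (2, 2, 2));
    }
}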