1use crate::error::{Result, TextError};
7use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
8
/// Hyperparameters for [`TransformerTextEncoder`].
#[derive(Debug, Clone)]
pub struct TransformerEncoderConfig {
    /// Number of entries in the token embedding table.
    pub vocab_size: usize,
    /// Model width (embedding / residual stream dimension).
    pub hidden_size: usize,
    /// Number of attention heads; must evenly divide `hidden_size`.
    pub num_heads: usize,
    /// Number of stacked attention + feed-forward layers.
    pub num_layers: usize,
    /// Maximum supported input length (size of the positional-encoding table).
    pub max_seq_len: usize,
    /// Dropout probability. NOTE(review): stored but not applied anywhere in
    /// this file — confirm whether inference-only behavior is intended.
    pub dropout: f32,
    /// Seed for the deterministic LCG used to initialize all weights.
    pub seed: u64,
}
29
30impl Default for TransformerEncoderConfig {
31 fn default() -> Self {
32 Self {
33 vocab_size: 30000,
34 hidden_size: 256,
35 num_heads: 4,
36 num_layers: 2,
37 max_seq_len: 512,
38 dropout: 0.1,
39 seed: 42,
40 }
41 }
42}
43
/// Weights for one pre-LayerNorm multi-head self-attention sub-layer.
struct MhsaLayer {
    /// Query projection, `hidden × hidden`.
    w_q: Array2<f32>,
    /// Key projection, `hidden × hidden`.
    w_k: Array2<f32>,
    /// Value projection, `hidden × hidden`.
    w_v: Array2<f32>,
    /// Output projection applied to the concatenated head contexts.
    w_o: Array2<f32>,
    /// LayerNorm gain applied to the input before attention.
    ln1_scale: Array1<f32>,
    /// LayerNorm bias applied to the input before attention.
    ln1_bias: Array1<f32>,
    /// Number of attention heads.
    n_heads: usize,
    /// Per-head dimension (`hidden / n_heads`).
    d_k: usize,
}
63
/// Weights for one pre-LayerNorm position-wise feed-forward sub-layer.
struct FfnLayer {
    /// Expansion projection, `hidden × ffn_dim`.
    w1: Array2<f32>,
    /// Bias for the expansion projection.
    b1: Array1<f32>,
    /// Contraction projection, `ffn_dim × hidden`.
    w2: Array2<f32>,
    /// Bias for the contraction projection.
    b2: Array1<f32>,
    /// LayerNorm gain applied to the input before the FFN.
    ln2_scale: Array1<f32>,
    /// LayerNorm bias applied to the input before the FFN.
    ln2_bias: Array1<f32>,
}
76
/// Advances the LCG state (Knuth's MMIX multiplier/increment) and returns a
/// pseudo-random `f32` uniform in `[-1.0, 1.0)`.
///
/// Bug fixed: the previous code took the top 31 bits (`>> 33`) but divided by
/// `u32::MAX` (~2^32), so the unit value lay in `[0, 0.5)` and the returned
/// sample was always negative, biasing every initialized weight. We now take
/// the top 24 bits — exactly the `f32` mantissa width — and divide by 2^24,
/// giving an unbiased value in `[0, 1)` before recentring to `[-1, 1)`.
fn next_lcg(seed: &mut u64) -> f32 {
    *seed = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // Top 24 bits map exactly onto representable f32 values in [0, 1).
    let unit = (*seed >> 40) as f32 / (1u64 << 24) as f32;
    (unit - 0.5) * 2.0
}
86
87fn xavier_init(rows: usize, cols: usize, seed: &mut u64) -> Array2<f32> {
88 let scale = (6.0_f32 / (rows + cols) as f32).sqrt();
89 Array2::from_shape_fn((rows, cols), |_| next_lcg(seed) * scale)
90}
91
92fn zeros1(n: usize) -> Array1<f32> {
93 Array1::zeros(n)
94}
95
96fn ones1(n: usize) -> Array1<f32> {
97 Array1::ones(n)
98}
99
100fn softmax_rows(x: &mut Array2<f32>) {
104 let (rows, cols) = x.dim();
105 for i in 0..rows {
106 let max_val = x.row(i).fold(f32::NEG_INFINITY, |a, &b| a.max(b));
107 let mut sum = 0.0_f32;
108 for j in 0..cols {
109 x[[i, j]] = (x[[i, j]] - max_val).exp();
110 sum += x[[i, j]];
111 }
112 if sum > 0.0 {
113 for j in 0..cols {
114 x[[i, j]] /= sum;
115 }
116 }
117 }
118}
119
/// Tanh-based GELU approximation:
/// `0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))`.
#[inline]
fn gelu(x: f32) -> f32 {
    let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    let cubic = x + 0.044715 * x * x * x;
    0.5 * x * (1.0 + (sqrt_2_over_pi * cubic).tanh())
}
126
127fn layer_norm(x: &Array2<f32>, scale: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
129 let eps = 1e-5_f32;
130 let (seq, hidden) = x.dim();
131 let mut out = Array2::zeros((seq, hidden));
132 for i in 0..seq {
133 let row = x.row(i);
134 let mean = row.sum() / hidden as f32;
135 let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / hidden as f32;
136 let inv_std = 1.0 / (var + eps).sqrt();
137 for j in 0..hidden {
138 out[[i, j]] = (x[[i, j]] - mean) * inv_std * scale[j] + bias[j];
139 }
140 }
141 out
142}
143
144impl MhsaLayer {
147 fn new(hidden: usize, n_heads: usize, seed: &mut u64) -> Result<Self> {
148 if !hidden.is_multiple_of(n_heads) {
149 return Err(TextError::InvalidInput(format!(
150 "hidden_size {hidden} must be divisible by num_heads {n_heads}"
151 )));
152 }
153 let d_k = hidden / n_heads;
154 Ok(Self {
155 w_q: xavier_init(hidden, hidden, seed),
156 w_k: xavier_init(hidden, hidden, seed),
157 w_v: xavier_init(hidden, hidden, seed),
158 w_o: xavier_init(hidden, hidden, seed),
159 ln1_scale: ones1(hidden),
160 ln1_bias: zeros1(hidden),
161 n_heads,
162 d_k,
163 })
164 }
165
166 fn forward_with_attn(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>)> {
168 let (seq, hidden) = x.dim();
169
170 let xn = layer_norm(x, &self.ln1_scale, &self.ln1_bias);
172
173 let q = xn.dot(&self.w_q);
175 let k = xn.dot(&self.w_k);
176 let v = xn.dot(&self.w_v);
177
178 let scale = (self.d_k as f32).sqrt();
179
180 let mut out = Array2::zeros((seq, hidden));
182 let mut avg_attn = Array2::zeros((seq, seq));
184
185 for h in 0..self.n_heads {
186 let start = h * self.d_k;
187 let end = start + self.d_k;
188
189 let q_h = q.slice(s![.., start..end]).to_owned(); let k_h = k.slice(s![.., start..end]).to_owned(); let v_h = v.slice(s![.., start..end]).to_owned(); let mut scores = q_h.dot(&k_h.t()) / scale; softmax_rows(&mut scores);
196
197 avg_attn += &scores;
199
200 let ctx = scores.dot(&v_h);
202 out.slice_mut(s![.., start..end]).assign(&ctx);
203 }
204
205 let n_heads_f = self.n_heads as f32;
207 avg_attn.mapv_inplace(|v| v / n_heads_f);
208
209 let proj = out.dot(&self.w_o);
211 let result = x + &proj;
212
213 Ok((result, avg_attn))
214 }
215
216 fn forward_all_heads(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array3<f32>)> {
218 let (seq, hidden) = x.dim();
219
220 let xn = layer_norm(x, &self.ln1_scale, &self.ln1_bias);
221
222 let q = xn.dot(&self.w_q);
223 let k = xn.dot(&self.w_k);
224 let v = xn.dot(&self.w_v);
225
226 let scale = (self.d_k as f32).sqrt();
227
228 let mut out = Array2::zeros((seq, hidden));
229 let mut all_attn = Array3::zeros((self.n_heads, seq, seq));
230
231 for h in 0..self.n_heads {
232 let start = h * self.d_k;
233 let end = start + self.d_k;
234
235 let q_h = q.slice(s![.., start..end]).to_owned();
236 let k_h = k.slice(s![.., start..end]).to_owned();
237 let v_h = v.slice(s![.., start..end]).to_owned();
238
239 let mut scores = q_h.dot(&k_h.t()) / scale;
240 softmax_rows(&mut scores);
241
242 all_attn.slice_mut(s![h, .., ..]).assign(&scores);
243
244 let ctx = scores.dot(&v_h);
245 out.slice_mut(s![.., start..end]).assign(&ctx);
246 }
247
248 let proj = out.dot(&self.w_o);
249 let result = x + &proj;
250
251 Ok((result, all_attn))
252 }
253}
254
255impl FfnLayer {
258 fn new(hidden: usize, seed: &mut u64) -> Self {
259 let ffn_dim = 4 * hidden;
260 Self {
261 w1: xavier_init(hidden, ffn_dim, seed),
262 b1: zeros1(ffn_dim),
263 w2: xavier_init(ffn_dim, hidden, seed),
264 b2: zeros1(hidden),
265 ln2_scale: ones1(hidden),
266 ln2_bias: zeros1(hidden),
267 }
268 }
269
270 fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
271 let xn = layer_norm(x, &self.ln2_scale, &self.ln2_bias);
273
274 let h1 = xn.dot(&self.w1) + &self.b1;
276 let h1 = h1.mapv(gelu);
277
278 let h2 = h1.dot(&self.w2) + &self.b2;
280 x + &h2
281 }
282}
283
284fn sinusoidal_pe(max_seq: usize, hidden: usize) -> Array2<f32> {
287 let mut pe = Array2::zeros((max_seq, hidden));
288 for pos in 0..max_seq {
289 for i in (0..hidden).step_by(2) {
290 let angle = pos as f32 / 10000.0_f32.powf(i as f32 / hidden as f32);
291 pe[[pos, i]] = angle.sin();
292 if i + 1 < hidden {
293 pe[[pos, i + 1]] = angle.cos();
294 }
295 }
296 }
297 pe
298}
299
/// A from-scratch transformer text encoder (pre-LayerNorm blocks, GELU
/// feed-forward, sinusoidal positions) with deterministic seeded
/// initialization.
pub struct TransformerTextEncoder {
    /// Hyperparameters this encoder was built with.
    config: TransformerEncoderConfig,
    /// Token embedding table, `vocab_size × hidden_size`.
    embedding: Array2<f32>,
    /// Precomputed sinusoidal positional encodings, `max_seq_len × hidden_size`.
    position_enc: Array2<f32>,
    /// One attention sub-layer per transformer layer.
    attn_layers: Vec<MhsaLayer>,
    /// One feed-forward sub-layer per transformer layer (parallel to
    /// `attn_layers`).
    ffn_layers: Vec<FfnLayer>,
}
316
317impl TransformerTextEncoder {
318 pub fn new(config: TransformerEncoderConfig) -> Result<Self> {
320 let mut seed = config.seed;
321
322 let scale = (config.hidden_size as f32).sqrt();
323 let embedding = Array2::from_shape_fn((config.vocab_size, config.hidden_size), |_| {
324 next_lcg(&mut seed) / scale
325 });
326
327 let position_enc = sinusoidal_pe(config.max_seq_len, config.hidden_size);
328
329 let mut attn_layers = Vec::with_capacity(config.num_layers);
330 let mut ffn_layers = Vec::with_capacity(config.num_layers);
331 for _ in 0..config.num_layers {
332 attn_layers.push(MhsaLayer::new(
333 config.hidden_size,
334 config.num_heads,
335 &mut seed,
336 )?);
337 ffn_layers.push(FfnLayer::new(config.hidden_size, &mut seed));
338 }
339
340 Ok(Self {
341 config,
342 embedding,
343 position_enc,
344 attn_layers,
345 ffn_layers,
346 })
347 }
348
349 fn embed_tokens(&self, tokens: &[usize]) -> Result<Array2<f32>> {
351 let seq = tokens.len();
352 if seq == 0 {
353 return Err(TextError::InvalidInput("Empty token sequence".to_string()));
354 }
355 if seq > self.config.max_seq_len {
356 return Err(TextError::InvalidInput(format!(
357 "Sequence length {seq} exceeds max_seq_len {}",
358 self.config.max_seq_len
359 )));
360 }
361
362 let hidden = self.config.hidden_size;
363 let mut x = Array2::zeros((seq, hidden));
364 for (i, &tok) in tokens.iter().enumerate() {
365 if tok >= self.config.vocab_size {
366 return Err(TextError::InvalidInput(format!(
367 "Token ID {tok} out of vocab range {}",
368 self.config.vocab_size
369 )));
370 }
371 let emb_row = self.embedding.row(tok);
372 let pe_row = self.position_enc.row(i);
373 for j in 0..hidden {
374 x[[i, j]] = emb_row[j] + pe_row[j];
375 }
376 }
377 Ok(x)
378 }
379
380 pub fn encode_tokens(&self, tokens: &[usize]) -> Result<Array2<f32>> {
382 let mut x = self.embed_tokens(tokens)?;
383 for (attn, ffn) in self.attn_layers.iter().zip(self.ffn_layers.iter()) {
384 let (out, _) = attn.forward_with_attn(&x)?;
385 x = ffn.forward(&out);
386 }
387 Ok(x)
388 }
389
390 pub fn encode_sentence(&self, tokens: &[usize]) -> Result<Array1<f32>> {
393 let ctx = self.encode_tokens(tokens)?;
394 ctx.mean_axis(Axis(0))
395 .ok_or_else(|| TextError::InvalidInput("Cannot mean-pool empty context".to_string()))
396 }
397
398 pub fn encode_with_attention(
403 &self,
404 tokens: &[usize],
405 ) -> Result<(Array2<f32>, Vec<Array3<f32>>)> {
406 let mut x = self.embed_tokens(tokens)?;
407 let mut all_attn = Vec::with_capacity(self.config.num_layers);
408
409 for (attn, ffn) in self.attn_layers.iter().zip(self.ffn_layers.iter()) {
410 let (out, layer_attn) = attn.forward_all_heads(&x)?;
411 x = ffn.forward(&out);
412 all_attn.push(layer_attn);
413 }
414
415 Ok((x, all_attn))
416 }
417
418 pub fn config(&self) -> &TransformerEncoderConfig {
420 &self.config
421 }
422
423 pub fn embedding(&self) -> &Array2<f32> {
425 &self.embedding
426 }
427
428 pub fn embedding_mut(&mut self) -> &mut Array2<f32> {
430 &mut self.embedding
431 }
432}