use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use scirs2_core::random::Rng;
use statrs::statistics::Statistics;
use std::collections::HashMap;
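/// Configuration for the transformer: model width, number of attention heads,
/// feed-forward width, encoder/decoder depth, maximum sequence length,
/// dropout rate, and vocabulary size.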
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    pub d_model: usize,
    pub nheads: usize,
    pub d_ff: usize,
    pub n_encoder_layers: usize,
    pub n_decoder_layers: usize,
    pub max_seqlen: usize,
    pub dropout: f64,
    pub vocab_size: usize,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        Self {
            d_model: 512,
            nheads: 8,
            d_ff: 2048,
            n_encoder_layers: 6,
            n_decoder_layers: 6,
            max_seqlen: 512,
            dropout: 0.1,
            vocab_size: 10000,
        }
    }
}
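/// Sinusoidal positional encodings: sine values on even dimensions and cosine
/// values on odd dimensions, precomputed for up to `max_len` positions.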
pub struct PositionalEncoding {
    encodings: Array2<f64>,
    max_len: usize,
    #[allow(dead_code)]
    d_model: usize,
}

impl PositionalEncoding {
    pub fn new(max_len: usize, d_model: usize) -> Self {
        let mut encodings = Array2::<f64>::zeros((max_len, d_model));

        for pos in 0..max_len {
            for i in (0..d_model).step_by(2) {
                let angle = pos as f64 / (10000.0_f64).powf(i as f64 / d_model as f64);
                encodings[[pos, i]] = angle.sin();
                if i + 1 < d_model {
                    encodings[[pos, i + 1]] = angle.cos();
                }
            }
        }

        Self {
            encodings,
            max_len,
            d_model,
        }
    }

    pub fn get_encoding(&self, seqlen: usize) -> Result<ArrayView2<f64>> {
        if seqlen > self.max_len {
            return Err(TextError::InvalidInput(format!(
                "Sequence length {} exceeds maximum {}",
                seqlen, self.max_len
            )));
        }
        Ok(self.encodings.slice(s![0..seqlen, ..]))
    }
}
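/// Multi-head scaled dot-product attention with learned projections w_q, w_k,
/// w_v and an output projection w_o.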
pub struct MultiHeadAttention {
    d_model: usize,
    nheads: usize,
    d_k: usize,
    w_q: Array2<f64>,
    w_k: Array2<f64>,
    w_v: Array2<f64>,
    w_o: Array2<f64>,
}

impl MultiHeadAttention {
    pub fn new(d_model: usize, nheads: usize) -> Result<Self> {
        if !d_model.is_multiple_of(nheads) {
            return Err(TextError::InvalidInput(
                "d_model must be divisible by nheads".to_string(),
            ));
        }

        let d_k = d_model / nheads;

        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Ok(Self {
            d_model,
            nheads,
            d_k,
            w_q,
            w_k,
            w_v,
            w_o,
        })
    }
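    /// Computes softmax(Q * K^T / sqrt(d_k)) * V; masked positions are set to
    /// negative infinity before the softmax so they receive zero attention weight.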
    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        let scores = q.dot(&k.t()) / d_k.sqrt();

        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        let attention_weights = self.softmax_2d(&masked_scores)?;

        Ok(attention_weights.dot(&v))
    }

    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|v| (v - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }
    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        let q_heads = self.reshape_for_heads(&q)?;
        let k_heads = self.reshape_for_heads(&k)?;
        let v_heads = self.reshape_for_heads(&v)?;

        let mut head_outputs = Vec::new();
        for head in 0..self.nheads {
            let q_head = q_heads.slice(s![head, .., ..]);
            let k_head = k_heads.slice(s![head, .., ..]);
            let v_head = v_heads.slice(s![head, .., ..]);

            let head_output = self.scaled_dot_product_attention(q_head, k_head, v_head, mask)?;
            head_outputs.push(head_output);
        }

        let concatenated = self.concatenate_heads(&head_outputs)?;

        Ok(concatenated.dot(&self.w_o))
    }

    fn reshape_for_heads(&self, x: &Array2<f64>) -> Result<Array3<f64>> {
        let (seqlen, _) = x.dim();
        let reshaped = x
            .clone()
            .into_shape_with_order((seqlen, self.nheads, self.d_k))
            .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;

        Ok(reshaped.permuted_axes([1, 0, 2]))
    }

    fn concatenate_heads(&self, heads: &[Array2<f64>]) -> Result<Array2<f64>> {
        if heads.is_empty() {
            return Err(TextError::InvalidInput("No heads provided".to_string()));
        }

        let seqlen = heads[0].shape()[0];
        let mut result = Array2::zeros((seqlen, self.d_model));

        for (i, head) in heads.iter().enumerate() {
            let start_col = i * self.d_k;
            let end_col = start_col + self.d_k;
            result.slice_mut(s![.., start_col..end_col]).assign(head);
        }

        Ok(result)
    }

    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array2<f64>, &Array2<f64>) {
        (&self.w_q, &self.w_k, &self.w_v, &self.w_o)
    }

    pub fn set_weights(
        &mut self,
        w_q: Array2<f64>,
        w_k: Array2<f64>,
        w_v: Array2<f64>,
        w_o: Array2<f64>,
    ) -> Result<()> {
        if w_q.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_q shape".to_string()));
        }
        if w_k.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_k shape".to_string()));
        }
        if w_v.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_v shape".to_string()));
        }
        if w_o.shape() != [self.d_model, self.d_model] {
            return Err(TextError::InvalidInput("Invalid w_o shape".to_string()));
        }

        self.w_q = w_q;
        self.w_k = w_k;
        self.w_v = w_v;
        self.w_o = w_o;
        Ok(())
    }
}
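/// Position-wise feed-forward network: FFN(x) = max(0, x*w1 + b1)*w2 + b2.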
pub struct FeedForward {
    w1: Array2<f64>,
    w2: Array2<f64>,
    b1: Array1<f64>,
    b2: Array1<f64>,
}

impl FeedForward {
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        let scale = (2.0 / d_model as f64).sqrt();

        let w1 = Array2::from_shape_fn((d_model, d_ff), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w2 = Array2::from_shape_fn((d_ff, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let b1 = Array1::zeros(d_ff);
        let b2 = Array1::zeros(d_model);

        Self { w1, w2, b1, b2 }
    }
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        let hidden = x.dot(&self.w1) + &self.b1;
        // ReLU activation followed by the second linear projection.
        let activated = hidden.mapv(|v| v.max(0.0));
        activated.dot(&self.w2) + &self.b2
    }
    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array1<f64>, &Array1<f64>) {
        (&self.w1, &self.w2, &self.b1, &self.b2)
    }

    pub fn set_weights(
        &mut self,
        w1: Array2<f64>,
        w2: Array2<f64>,
        b1: Array1<f64>,
        b2: Array1<f64>,
    ) -> Result<()> {
        if w1.shape()[1] != w2.shape()[0] {
            return Err(TextError::InvalidInput(
                "Weight matrix dimensions don't match".to_string(),
            ));
        }
        if b1.len() != w1.shape()[1] {
            return Err(TextError::InvalidInput(
                "Bias b1 size doesn't match w1".to_string(),
            ));
        }
        if b2.len() != w2.shape()[1] {
            return Err(TextError::InvalidInput(
                "Bias b2 size doesn't match w2".to_string(),
            ));
        }

        self.w1 = w1;
        self.w2 = w2;
        self.b1 = b1;
        self.b2 = b2;
        Ok(())
    }
}
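/// Layer normalization with learnable scale (`gamma`) and shift (`beta`) parameters.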
pub struct LayerNorm {
    gamma: Array1<f64>,
    beta: Array1<f64>,
    eps: f64,
}

impl LayerNorm {
    pub fn new(d_model: usize, eps: f64) -> Self {
        Self {
            gamma: Array1::ones(d_model),
            beta: Array1::zeros(d_model),
            eps,
        }
    }
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        let mut result = Array2::zeros(x.raw_dim());

        for (i, row) in x.rows().into_iter().enumerate() {
            // Normalize each row to zero mean and unit variance, then scale and shift.
            let mean = row.sum() / row.len() as f64;
            let var = row.mapv(|v| (v - mean).powi(2)).sum() / row.len() as f64;
            let std = (var + self.eps).sqrt();

            let normalized = row.mapv(|v| (v - mean) / std);
            let scaled = &normalized * &self.gamma + &self.beta;

            result.row_mut(i).assign(&scaled);
        }

        result
    }
    pub fn get_params(&self) -> (&Array1<f64>, &Array1<f64>) {
        (&self.gamma, &self.beta)
    }

    pub fn set_params(&mut self, gamma: Array1<f64>, beta: Array1<f64>) -> Result<()> {
        if gamma.len() != beta.len() {
            return Err(TextError::InvalidInput(
                "Gamma and beta must have same length".to_string(),
            ));
        }
        if gamma.len() != self.gamma.len() {
            return Err(TextError::InvalidInput(
                "Parameter size doesn't match layer dimension".to_string(),
            ));
        }

        self.gamma = gamma;
        self.beta = beta;
        Ok(())
    }
}
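/// A single encoder layer: multi-head self-attention followed by a position-wise
/// feed-forward network, each combined with a residual connection and layer norm.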
pub struct TransformerEncoderLayer {
    self_attention: MultiHeadAttention,
    feed_forward: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    #[allow(dead_code)]
    dropout: f64,
}

impl TransformerEncoderLayer {
    pub fn new(config: &TransformerConfig) -> Result<Self> {
        Ok(Self {
            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            feed_forward: FeedForward::new(config.d_model, config.d_ff),
            norm1: LayerNorm::new(config.d_model, 1e-6),
            norm2: LayerNorm::new(config.d_model, 1e-6),
            dropout: config.dropout,
        })
    }
    pub fn forward(
        &self,
        x: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Self-attention sub-layer: residual connection followed by layer norm,
        // matching the scheme used by the decoder layers.
        let attn_output = self.self_attention.forward(x, x, x, mask)?;
        let x = self.norm1.forward((x.to_owned() + attn_output).view());

        // Feed-forward sub-layer with residual connection and layer norm.
        let ff_output = self.feed_forward.forward(x.view());
        let output = self.norm2.forward((x + ff_output).view());

        Ok(output)
    }
    pub fn get_components_mut(
        &mut self,
    ) -> (
        &mut MultiHeadAttention,
        &mut FeedForward,
        &mut LayerNorm,
        &mut LayerNorm,
    ) {
        (
            &mut self.self_attention,
            &mut self.feed_forward,
            &mut self.norm1,
            &mut self.norm2,
        )
    }

    pub fn get_components(&self) -> (&MultiHeadAttention, &FeedForward, &LayerNorm, &LayerNorm) {
        (
            &self.self_attention,
            &self.feed_forward,
            &self.norm1,
            &self.norm2,
        )
    }
}
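/// A stack of encoder layers; sinusoidal positional encodings are added to the
/// input embeddings before the first layer.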
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    position_encoding: PositionalEncoding,
    config: TransformerConfig,
}

impl TransformerEncoder {
    pub fn new(config: TransformerConfig) -> Result<Self> {
        let mut layers = Vec::new();
        for _ in 0..config.n_encoder_layers {
            layers.push(TransformerEncoderLayer::new(&config)?);
        }

        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);

        Ok(Self {
            layers,
            position_encoding,
            config,
        })
    }

    pub fn encode(
        &self,
        embeddings: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seqlen = embeddings.shape()[0];

        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
        let mut x = embeddings.to_owned() + pos_enc;

        for layer in &self.layers {
            x = layer.forward(x.view(), mask)?;
        }

        Ok(x)
    }

    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }

    pub fn get_layers_mut(&mut self) -> &mut Vec<TransformerEncoderLayer> {
        &mut self.layers
    }

    pub fn get_layers(&self) -> &Vec<TransformerEncoderLayer> {
        &self.layers
    }
}
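/// A single decoder layer: self-attention (typically causally masked),
/// cross-attention over the encoder output, and a position-wise feed-forward
/// network, each followed by a residual connection and layer norm.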
pub struct TransformerDecoderLayer {
    self_attention: MultiHeadAttention,
    cross_attention: MultiHeadAttention,
    feed_forward: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    #[allow(dead_code)]
    dropout: f64,
}

impl TransformerDecoderLayer {
    pub fn new(config: &TransformerConfig) -> Result<Self> {
        Ok(Self {
            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            cross_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
            feed_forward: FeedForward::new(config.d_model, config.d_ff),
            norm1: LayerNorm::new(config.d_model, 1e-6),
            norm2: LayerNorm::new(config.d_model, 1e-6),
            norm3: LayerNorm::new(config.d_model, 1e-6),
            dropout: config.dropout,
        })
    }

    pub fn forward(
        &self,
        x: ArrayView2<f64>,
        encoder_output: ArrayView2<f64>,
        self_attn_mask: Option<ArrayView2<bool>>,
        cross_attn_mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let self_attn_out = self.self_attention.forward(x, x, x, self_attn_mask)?;
        let x = self.norm1.forward((x.to_owned() + self_attn_out).view());

        let cross_attn_out = self.cross_attention.forward(
            x.view(),
            encoder_output,
            encoder_output,
            cross_attn_mask,
        )?;
        let x = self.norm2.forward((x + cross_attn_out).view());

        let ff_out = self.feed_forward.forward(x.view());
        let output = self.norm3.forward((x + ff_out).view());

        Ok(output)
    }
}
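/// A stack of decoder layers; sinusoidal positional encodings are added to the
/// target embeddings before the first layer.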
pub struct TransformerDecoder {
    layers: Vec<TransformerDecoderLayer>,
    position_encoding: PositionalEncoding,
    config: TransformerConfig,
}

impl TransformerDecoder {
    pub fn new(config: TransformerConfig) -> Result<Self> {
        let mut layers = Vec::new();
        for _ in 0..config.n_decoder_layers {
            layers.push(TransformerDecoderLayer::new(&config)?);
        }

        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);

        Ok(Self {
            layers,
            position_encoding,
            config,
        })
    }

    pub fn forward(
        &self,
        embeddings: ArrayView2<f64>,
        encoder_output: ArrayView2<f64>,
        self_attn_mask: Option<ArrayView2<bool>>,
        cross_attn_mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seqlen = embeddings.shape()[0];

        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
        let mut x = embeddings.to_owned() + pos_enc;

        for layer in &self.layers {
            x = layer.forward(x.view(), encoder_output, self_attn_mask, cross_attn_mask)?;
        }

        Ok(x)
    }

    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }
}
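/// A learned embedding table mapping token IDs to d_model-dimensional vectors,
/// initialized with small uniform random values.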
pub struct TokenEmbedding {
    embeddings: Array2<f64>,
    vocab_size: usize,
    d_model: usize,
}

impl TokenEmbedding {
    pub fn new(vocab_size: usize, d_model: usize) -> Self {
        let scale = (1.0 / d_model as f64).sqrt();
        let embeddings = Array2::from_shape_fn((vocab_size, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            embeddings,
            vocab_size,
            d_model,
        }
    }

    pub fn forward(&self, tokenids: &[usize]) -> Result<Array2<f64>> {
        let mut result = Array2::zeros((tokenids.len(), self.d_model));

        for (i, &token_id) in tokenids.iter().enumerate() {
            if token_id >= self.vocab_size {
                return Err(TextError::InvalidInput(format!(
                    "Token ID {} exceeds vocabulary size {}",
                    token_id, self.vocab_size
                )));
            }
            result.row_mut(i).assign(&self.embeddings.row(token_id));
        }

        Ok(result)
    }

    pub fn get_embeddings(&self) -> &Array2<f64> {
        &self.embeddings
    }

    pub fn set_embeddings(&mut self, embeddings: Array2<f64>) -> Result<()> {
        if embeddings.shape()[0] != self.vocab_size || embeddings.shape()[1] != self.d_model {
            return Err(TextError::InvalidInput(format!(
                "Embedding shape {:?} doesn't match expected ({}, {})",
                embeddings.shape(),
                self.vocab_size,
                self.d_model
            )));
        }
        self.embeddings = embeddings;
        Ok(())
    }
}
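/// A transformer model combining token embeddings, an encoder stack, an optional
/// decoder stack, and vocabulary lookup tables used for encoding and generation.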
pub struct TransformerModel {
    pub config: TransformerConfig,
    pub token_embedding: TokenEmbedding,
    pub encoder: TransformerEncoder,
    pub decoder: Option<TransformerDecoder>,
    vocab_to_id: HashMap<String, usize>,
    id_to_vocab: HashMap<usize, String>,
}

impl TransformerModel {
    pub fn new(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
        let vocab_size = vocabulary.len();
        if vocab_size != config.vocab_size {
            return Err(TextError::InvalidInput(format!(
                "Vocabulary size {} doesn't match config {}",
                vocab_size, config.vocab_size
            )));
        }

        let mut vocab_to_id = HashMap::new();
        let mut id_to_vocab = HashMap::new();

        for (id, token) in vocabulary.into_iter().enumerate() {
            vocab_to_id.insert(token.clone(), id);
            id_to_vocab.insert(id, token);
        }

        Ok(Self {
            config: config.clone(),
            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
            encoder: TransformerEncoder::new(config)?,
            decoder: None,
            vocab_to_id,
            id_to_vocab,
        })
    }

    pub fn encode_tokens(&self, tokens: &[String]) -> Result<Array2<f64>> {
        let tokenids: Result<Vec<usize>> = tokens
            .iter()
            .map(|token| {
                self.vocab_to_id
                    .get(token)
                    .cloned()
                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
            })
            .collect();
        let tokenids = tokenids?;

        let embeddings = self.token_embedding.forward(&tokenids)?;

        self.encoder.encode(embeddings.view(), None)
    }

    pub fn new_encoder_decoder(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
        let vocab_size = vocabulary.len();
        if vocab_size != config.vocab_size {
            return Err(TextError::InvalidInput(format!(
                "Vocabulary size {} doesn't match config {}",
                vocab_size, config.vocab_size
            )));
        }

        let mut vocab_to_id = HashMap::new();
        let mut id_to_vocab = HashMap::new();

        for (id, token) in vocabulary.into_iter().enumerate() {
            vocab_to_id.insert(token.clone(), id);
            id_to_vocab.insert(id, token);
        }

        Ok(Self {
            config: config.clone(),
            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
            encoder: TransformerEncoder::new(config.clone())?,
            decoder: Some(TransformerDecoder::new(config)?),
            vocab_to_id,
            id_to_vocab,
        })
    }
    pub fn encode_decode(
        &self,
        input_tokens: &[String],
        target_tokens: &[String],
    ) -> Result<Array2<f64>> {
        let decoder = self
            .decoder
            .as_ref()
            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;

        let encoder_output = self.encode_tokens(input_tokens)?;

        let target_ids: Result<Vec<usize>> = target_tokens
            .iter()
            .map(|token| {
                self.vocab_to_id
                    .get(token)
                    .copied()
                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
            })
            .collect();
        let target_ids = target_ids?;

        let target_embeddings = self.token_embedding.forward(&target_ids)?;

        // Causal mask: position i may only attend to positions j <= i.
        let seqlen = target_tokens.len();
        let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
        for i in 0..seqlen {
            for j in (i + 1)..seqlen {
                causal_mask[[i, j]] = true;
            }
        }

        decoder.forward(
            target_embeddings.view(),
            encoder_output.view(),
            Some(causal_mask.view()),
            None,
        )
    }
    pub fn generate(
        &self,
        input_tokens: &[String],
        max_length: usize,
        start_token: &str,
    ) -> Result<Vec<String>> {
        let decoder = self
            .decoder
            .as_ref()
            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;

        let encoder_output = self.encode_tokens(input_tokens)?;

        let mut generated_tokens = vec![start_token.to_string()];

        for _ in 0..max_length {
            let current_ids: Result<Vec<usize>> = generated_tokens
                .iter()
                .map(|token| {
                    self.vocab_to_id
                        .get(token)
                        .copied()
                        .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
                })
                .collect();
            let current_ids = current_ids?;

            let current_embeddings = self.token_embedding.forward(&current_ids)?;

            // Causal mask over the tokens generated so far.
            let seqlen = generated_tokens.len();
            let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
            for i in 0..seqlen {
                for j in (i + 1)..seqlen {
                    causal_mask[[i, j]] = true;
                }
            }

            let decoder_output = decoder.forward(
                current_embeddings.view(),
                encoder_output.view(),
                Some(causal_mask.view()),
                None,
            )?;

            let last_output = decoder_output.row(decoder_output.nrows() - 1);

            // Greedy selection: index of the largest value in the final decoder row.
            let mut best_token_id = 0;
            let mut best_score = last_output[0];
            for (i, &score) in last_output.iter().enumerate() {
                if score > best_score {
                    best_score = score;
                    best_token_id = i;
                }
            }

            if let Some(token) = self.id_to_vocab.get(&best_token_id) {
                generated_tokens.push(token.clone());

                if token == "</s>" || token == "<eos>" {
                    break;
                }
            } else {
                break;
            }
        }

        Ok(generated_tokens)
    }
    pub fn vocabulary(&self) -> (&HashMap<String, usize>, &HashMap<usize, String>) {
        (&self.vocab_to_id, &self.id_to_vocab)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(10, 4);
        let encoding = pos_enc.get_encoding(5).unwrap();
        assert_eq!(encoding.shape(), &[5, 4]);

        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        assert!(pos0
            .iter()
            .zip(pos1.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6));
    }

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 2).unwrap();
        let seqlen = 4;
        let d_model = 8;

        let input = Array2::ones((seqlen, d_model));
        let output = mha
            .forward(input.view(), input.view(), input.view(), None)
            .unwrap();

        assert_eq!(output.shape(), &[seqlen, d_model]);
    }
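    // Additional sanity check: with default parameters (gamma = 1, beta = 0),
    // LayerNorm should produce rows with approximately zero mean.
    #[test]
    fn test_layer_norm_zero_mean() {
        let norm = LayerNorm::new(4, 1e-6);
        let input = Array2::from_shape_fn((2, 4), |(i, j)| (i * 4 + j) as f64);
        let output = norm.forward(input.view());

        for row in output.rows() {
            let mean: f64 = row.sum() / row.len() as f64;
            assert!(mean.abs() < 1e-9);
        }
    }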
    #[test]
    fn test_transformer_encoder() {
        let config = TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 2,
            ..Default::default()
        };

        let encoder = TransformerEncoder::new(config).unwrap();
        let input = Array2::ones((4, 8));
        let output = encoder.encode(input.view(), None).unwrap();

        assert_eq!(output.shape(), &[4, 8]);
    }
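    // Additional sanity check: builds an encoder-only TransformerModel over a
    // tiny vocabulary and verifies the shape of the encoded output.
    #[test]
    fn test_transformer_model_encode_tokens() {
        let vocab: Vec<String> = vec!["<s>", "hello", "world", "</s>"]
            .into_iter()
            .map(String::from)
            .collect();
        let config = TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 1,
            n_decoder_layers: 1,
            max_seqlen: 16,
            vocab_size: vocab.len(),
            ..Default::default()
        };

        let model = TransformerModel::new(config, vocab).unwrap();
        let tokens = vec!["hello".to_string(), "world".to_string()];
        let encoded = model.encode_tokens(&tokens).unwrap();

        assert_eq!(encoded.shape(), &[2, 8]);
    }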
}