1use crate::error::{NeuralError, Result};
31use scirs2_core::ndarray::{s, Array, Array2, Array3, Axis, IxDyn, Zip};
32use scirs2_core::numeric::{Float, NumAssign};
33use scirs2_core::random::{Rng, RngExt};
34use std::f64::consts::PI;
35use std::fmt::Debug;
36
37pub trait PositionalEncoding<F: Float + Debug + NumAssign> {
39 fn encode(&self, seq_len: usize) -> Array2<F>;
47
48 fn apply(&self, input: &Array3<F>) -> Result<Array3<F>>;
56
57 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
65 if input.ndim() != 3 {
67 return Err(NeuralError::InvalidArchitecture(format!(
68 "Expected 3D input, got {}D",
69 input.ndim()
70 )));
71 }
72
73 let shape = input.shape();
74 let input_3d = input
75 .view()
76 .into_dimensionality::<scirs2_core::ndarray::Ix3>()
77 .map_err(|e| {
78 NeuralError::InvalidArchitecture(format!("Failed to convert to 3D: {}", e))
79 })?;
80
81 let output_3d = self.apply(&input_3d.to_owned())?;
82 Ok(output_3d.into_dyn())
83 }
84
85 fn update(&mut self, _learning_rate: F) -> Result<()> {
93 Ok(())
95 }
96
97 fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
99 where
100 F: Send + Sync + 'static;
101
102 fn d_model(&self) -> usize;
104
105 fn max_len(&self) -> usize;
107}
108
/// Fixed (non-trainable) sinusoidal positional encoding.
///
/// The full `(max_len, d_model)` table is computed once at construction and
/// sliced on demand.
#[derive(Debug, Clone)]
pub struct SinusoidalPositionalEncoding<F: Float + Debug + NumAssign> {
    /// Feature dimension; the constructor requires it to be even.
    d_model: usize,
    /// Number of positions for which encodings are precomputed.
    max_len: usize,
    /// Precomputed table with one row per position.
    encodings: Array2<F>,
    /// Value stored by `with_dropout`; it is not read anywhere in this file.
    dropout: Option<F>,
}
129
130impl<F: Float + Debug + NumAssign> SinusoidalPositionalEncoding<F> {
131 pub fn new(d_model: usize, max_len: usize) -> Self {
137 assert!(
138 d_model.is_multiple_of(2),
139 "d_model must be even for sinusoidal PE"
140 );
141
142 let encodings = Self::compute_encodings(d_model, max_len);
143
144 Self {
145 d_model,
146 max_len,
147 encodings,
148 dropout: None,
149 }
150 }
151
152 pub fn with_dropout(d_model: usize, max_len: usize, dropout: F) -> Self {
154 let mut pe = Self::new(d_model, max_len);
155 pe.dropout = Some(dropout);
156 pe
157 }
158
159 fn compute_encodings(d_model: usize, max_len: usize) -> Array2<F> {
161 let mut encodings = Array2::zeros((max_len, d_model));
162
163 for pos in 0..max_len {
164 for i in 0..(d_model / 2) {
165 let exponent = (2 * i) as f64 / d_model as f64;
167 let div_term = (10000.0_f64).powf(exponent);
168 let angle = pos as f64 / div_term;
169
170 let sin_val = F::from(angle.sin()).unwrap_or(F::zero());
172 let cos_val = F::from(angle.cos()).unwrap_or(F::zero());
173
174 encodings[[pos, 2 * i]] = sin_val;
175 encodings[[pos, 2 * i + 1]] = cos_val;
176 }
177 }
178
179 encodings
180 }
181
182 pub fn params(&self) -> Vec<&Array<F, IxDyn>> {
184 Vec::new()
185 }
186
187 pub fn set_training(&mut self, _training: bool) {
189 }
191}
192
193impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for SinusoidalPositionalEncoding<F> {
194 fn encode(&self, seq_len: usize) -> Array2<F> {
195 assert!(
196 seq_len <= self.max_len,
197 "seq_len {} exceeds max_len {}",
198 seq_len,
199 self.max_len
200 );
201 self.encodings.slice(s![..seq_len, ..]).to_owned()
202 }
203
204 fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
205 let seq_len = input.shape()[1];
206 if seq_len > self.max_len {
207 return Err(NeuralError::InvalidArchitecture(format!(
208 "Sequence length {} exceeds max_len {}",
209 seq_len, self.max_len
210 )));
211 }
212
213 let encoding = self.encode(seq_len);
214 let mut output = input.clone();
215
216 for mut batch in output.axis_iter_mut(Axis(0)) {
218 Zip::from(&mut batch)
219 .and(&encoding)
220 .for_each(|b, &e| *b += e);
221 }
222
223 Ok(output)
224 }
225
226 fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
227 where
228 F: Send + Sync + 'static,
229 {
230 Box::new(self.clone())
231 }
232
233 fn d_model(&self) -> usize {
234 self.d_model
235 }
236
237 fn max_len(&self) -> usize {
238 self.max_len
239 }
240}
241
/// Positional encoding backed by a mutable embedding table with one row per
/// position.
#[derive(Debug, Clone)]
pub struct LearnedPositionalEncoding<F: Float + Debug + NumAssign> {
    /// Number of columns of `embeddings` at construction time.
    d_model: usize,
    /// Number of rows of `embeddings` at construction time.
    max_len: usize,
    /// Embedding table of shape `(max_len, d_model)`.
    embeddings: Array2<F>,
}
254
255impl<F: Float + Debug + NumAssign> LearnedPositionalEncoding<F> {
256 pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
263 let std = (2.0 / (max_len + d_model) as f64).sqrt();
265
266 let mut embeddings = Array2::zeros((max_len, d_model));
267 for elem in embeddings.iter_mut() {
268 let u1: f64 = rng.random_range(0.0001..1.0);
270 let u2: f64 = rng.random_range(0.0..1.0);
271 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
272 *elem = F::from(z * std).unwrap_or(F::zero());
273 }
274
275 Self {
276 d_model,
277 max_len,
278 embeddings,
279 }
280 }
281
282 pub fn from_embeddings(embeddings: Array2<F>) -> Self {
284 let shape = embeddings.shape();
285 Self {
286 d_model: shape[1],
287 max_len: shape[0],
288 embeddings,
289 }
290 }
291
292 pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
294 &mut self.embeddings
295 }
296
297 pub fn embeddings(&self) -> &Array2<F> {
299 &self.embeddings
300 }
301}
302
303impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for LearnedPositionalEncoding<F> {
304 fn encode(&self, seq_len: usize) -> Array2<F> {
305 assert!(
306 seq_len <= self.max_len,
307 "seq_len {} exceeds max_len {}",
308 seq_len,
309 self.max_len
310 );
311 self.embeddings.slice(s![..seq_len, ..]).to_owned()
312 }
313
314 fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
315 let seq_len = input.shape()[1];
316 if seq_len > self.max_len {
317 return Err(NeuralError::InvalidArchitecture(format!(
318 "Sequence length {} exceeds max_len {}",
319 seq_len, self.max_len
320 )));
321 }
322
323 let encoding = self.encode(seq_len);
324 let mut output = input.clone();
325
326 for mut batch in output.axis_iter_mut(Axis(0)) {
327 Zip::from(&mut batch)
328 .and(&encoding)
329 .for_each(|b, &e| *b += e);
330 }
331
332 Ok(output)
333 }
334
335 fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
336 where
337 F: Send + Sync + 'static,
338 {
339 Box::new(self.clone())
340 }
341
342 fn d_model(&self) -> usize {
343 self.d_model
344 }
345
346 fn max_len(&self) -> usize {
347 self.max_len
348 }
349}
350
/// Rotary positional encoding (RoPE) with precomputed sine and cosine tables.
#[derive(Debug, Clone)]
pub struct RotaryPositionalEncoding<F: Float + Debug + NumAssign> {
    /// Feature dimension; the constructor requires it to be even.
    d_model: usize,
    /// Number of positions covered by the cached tables.
    max_len: usize,
    /// Base of the frequency schedule passed to the constructor.
    base: f64,
    /// Sine table of shape `(max_len, d_model / 2)`.
    sin_cached: Array2<F>,
    /// Cosine table of shape `(max_len, d_model / 2)`.
    cos_cached: Array2<F>,
}
370
371impl<F: Float + Debug + NumAssign> RotaryPositionalEncoding<F> {
372 pub fn new(d_model: usize, max_len: usize, base: f64) -> Self {
379 assert!(d_model.is_multiple_of(2), "d_model must be even for RoPE");
380
381 let (sin_cached, cos_cached) = Self::compute_rope_cache(d_model, max_len, base);
382
383 Self {
384 d_model,
385 max_len,
386 base,
387 sin_cached,
388 cos_cached,
389 }
390 }
391
392 pub fn default_base(d_model: usize, max_len: usize) -> Self {
394 Self::new(d_model, max_len, 10000.0)
395 }
396
397 fn compute_rope_cache(d_model: usize, max_len: usize, base: f64) -> (Array2<F>, Array2<F>) {
399 let half_dim = d_model / 2;
400 let mut sin_cached = Array2::zeros((max_len, half_dim));
401 let mut cos_cached = Array2::zeros((max_len, half_dim));
402
403 for pos in 0..max_len {
405 for i in 0..half_dim {
406 let freq = 1.0 / base.powf((2 * i) as f64 / d_model as f64);
407 let angle = pos as f64 * freq;
408
409 sin_cached[[pos, i]] = F::from(angle.sin()).unwrap_or(F::zero());
410 cos_cached[[pos, i]] = F::from(angle.cos()).unwrap_or(F::zero());
411 }
412 }
413
414 (sin_cached, cos_cached)
415 }
416
417 pub fn rotate(&self, x: &Array3<F>, offset: usize) -> Result<Array3<F>> {
426 let seq_len = x.shape()[1];
427 if seq_len + offset > self.max_len {
428 return Err(NeuralError::InvalidArchitecture(format!(
429 "Position {} exceeds max_len {}",
430 seq_len + offset,
431 self.max_len
432 )));
433 }
434
435 let batch_size = x.shape()[0];
436 let half_dim = self.d_model / 2;
437
438 let mut output = Array3::zeros(x.raw_dim());
439
440 for b in 0..batch_size {
441 for pos in 0..seq_len {
442 let abs_pos = pos + offset;
443 for i in 0..half_dim {
444 let x1 = x[[b, pos, 2 * i]];
445 let x2 = x[[b, pos, 2 * i + 1]];
446
447 let cos = self.cos_cached[[abs_pos, i]];
448 let sin = self.sin_cached[[abs_pos, i]];
449
450 output[[b, pos, 2 * i]] = x1 * cos - x2 * sin;
452 output[[b, pos, 2 * i + 1]] = x1 * sin + x2 * cos;
453 }
454 }
455 }
456
457 Ok(output)
458 }
459
460 pub fn sin_cache(&self) -> &Array2<F> {
462 &self.sin_cached
463 }
464
465 pub fn cos_cache(&self) -> &Array2<F> {
467 &self.cos_cached
468 }
469}
470
471impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RotaryPositionalEncoding<F> {
472 fn encode(&self, seq_len: usize) -> Array2<F> {
473 let half_dim = self.d_model / 2;
476 let mut encoding = Array2::zeros((seq_len, self.d_model));
477
478 for pos in 0..seq_len {
479 for i in 0..half_dim {
480 encoding[[pos, 2 * i]] = self.sin_cached[[pos, i]];
481 encoding[[pos, 2 * i + 1]] = self.cos_cached[[pos, i]];
482 }
483 }
484
485 encoding
486 }
487
488 fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
489 self.rotate(input, 0)
491 }
492
493 fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
494 where
495 F: Send + Sync + 'static,
496 {
497 Box::new(self.clone())
498 }
499
500 fn d_model(&self) -> usize {
501 self.d_model
502 }
503
504 fn max_len(&self) -> usize {
505 self.max_len
506 }
507}
508
/// Positional encoding indexed by relative offset between two positions.
#[derive(Debug, Clone)]
pub struct RelativePositionalEncoding<F: Float + Debug + NumAssign> {
    /// Feature dimension of each embedding row.
    d_model: usize,
    /// Maximum sequence length; offsets range over `-(max_len - 1)..=(max_len - 1)`.
    max_len: usize,
    /// Table with `2 * max_len - 1` rows; row `offset + max_len - 1` holds the
    /// embedding for that relative offset.
    rel_embeddings: Array2<F>,
}
522
523impl<F: Float + Debug + NumAssign> RelativePositionalEncoding<F> {
524 pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
531 let num_positions = 2 * max_len - 1;
532 let std = (1.0 / d_model as f64).sqrt();
533
534 let mut rel_embeddings = Array2::zeros((num_positions, d_model));
535 for elem in rel_embeddings.iter_mut() {
536 let u1: f64 = rng.random_range(0.0001..1.0);
537 let u2: f64 = rng.random_range(0.0..1.0);
538 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
539 *elem = F::from(z * std).unwrap_or(F::zero());
540 }
541
542 Self {
543 d_model,
544 max_len,
545 rel_embeddings,
546 }
547 }
548
549 pub fn get_relative_embedding(&self, rel_pos: i64) -> Option<Array<F, IxDyn>> {
554 let max_rel = self.max_len as i64 - 1;
555 if rel_pos < -max_rel || rel_pos > max_rel {
556 return None;
557 }
558
559 let idx = (rel_pos + max_rel) as usize;
560 Some(self.rel_embeddings.slice(s![idx, ..]).to_owned().into_dyn())
561 }
562
563 pub fn get_attention_bias(&self, query_len: usize, key_len: usize) -> Result<Array3<F>> {
572 if query_len > self.max_len || key_len > self.max_len {
573 return Err(NeuralError::InvalidArchitecture(format!(
574 "Sequence length exceeds max_len {}",
575 self.max_len
576 )));
577 }
578
579 let mut bias = Array3::zeros((query_len, key_len, self.d_model));
580 let max_rel = self.max_len as i64 - 1;
581
582 for q in 0..query_len {
583 for k in 0..key_len {
584 let rel_pos = k as i64 - q as i64;
585 let idx = (rel_pos + max_rel) as usize;
586
587 for d in 0..self.d_model {
588 bias[[q, k, d]] = self.rel_embeddings[[idx, d]];
589 }
590 }
591 }
592
593 Ok(bias)
594 }
595
596 pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
598 &mut self.rel_embeddings
599 }
600}
601
602impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RelativePositionalEncoding<F> {
603 fn encode(&self, seq_len: usize) -> Array2<F> {
604 let start = self.max_len - 1;
606 self.rel_embeddings
607 .slice(s![start..(start + seq_len), ..])
608 .to_owned()
609 }
610
611 fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
612 let seq_len = input.shape()[1];
615 if seq_len > self.max_len {
616 return Err(NeuralError::InvalidArchitecture(format!(
617 "Sequence length {} exceeds max_len {}",
618 seq_len, self.max_len
619 )));
620 }
621
622 let encoding = self.encode(seq_len);
623 let mut output = input.clone();
624
625 for mut batch in output.axis_iter_mut(Axis(0)) {
626 Zip::from(&mut batch)
627 .and(&encoding)
628 .for_each(|b, &e| *b += e);
629 }
630
631 Ok(output)
632 }
633
634 fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
635 where
636 F: Send + Sync + 'static,
637 {
638 Box::new(self.clone())
639 }
640
641 fn d_model(&self) -> usize {
642 self.d_model
643 }
644
645 fn max_len(&self) -> usize {
646 self.max_len
647 }
648}
649
/// Selects which encoder `PositionalEncodingFactory::create` builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
    /// Builds a `SinusoidalPositionalEncoding`.
    Sinusoidal,
    /// Builds a `LearnedPositionalEncoding`.
    Learned,
    /// Builds a `RotaryPositionalEncoding` with its default base.
    Rotary,
    /// Builds a `RelativePositionalEncoding`.
    Relative,
}
662
663pub struct PositionalEncodingFactory;
665
666impl PositionalEncodingFactory {
667 pub fn create<F, R>(
669 pe_type: PositionalEncodingType,
670 d_model: usize,
671 max_len: usize,
672 rng: &mut R,
673 ) -> Box<dyn PositionalEncoding<F> + Send + Sync>
674 where
675 F: Float + Debug + NumAssign + Send + Sync + 'static,
676 R: Rng,
677 {
678 match pe_type {
679 PositionalEncodingType::Sinusoidal => {
680 Box::new(SinusoidalPositionalEncoding::new(d_model, max_len))
681 }
682 PositionalEncodingType::Learned => {
683 Box::new(LearnedPositionalEncoding::new(d_model, max_len, rng))
684 }
685 PositionalEncodingType::Rotary => {
686 Box::new(RotaryPositionalEncoding::default_base(d_model, max_len))
687 }
688 PositionalEncodingType::Relative => {
689 Box::new(RelativePositionalEncoding::new(d_model, max_len, rng))
690 }
691 }
692 }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::SeedableRng;

    /// Deterministic RNG shared by the tests that need random initialisation.
    fn seeded_rng() -> scirs2_core::random::rngs::StdRng {
        scirs2_core::random::rngs::StdRng::seed_from_u64(42)
    }

    /// `encode` returns one row per requested position.
    #[test]
    fn test_sinusoidal_encoding_shape() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(64, 100);

        for len in [10usize, 50] {
            let encoding = pe.encode(len);
            assert_eq!(encoding.shape(), &[len, 64]);
        }
    }

    /// Position 0 is `(sin 0, cos 0)` and position 1 differs from it.
    #[test]
    fn test_sinusoidal_encoding_values() {
        let pe = SinusoidalPositionalEncoding::<f64>::new(4, 10);
        let encoding = pe.encode(3);

        let first_sin = encoding[[0, 0]];
        let first_cos = encoding[[0, 1]];
        assert!((first_sin - 0.0).abs() < 1e-6);
        assert!((first_cos - 1.0).abs() < 1e-6);
        assert!((first_sin - encoding[[1, 0]]).abs() > 1e-10);
    }

    /// Applying to an all-zero input yields the encoding itself in every batch.
    #[test]
    fn test_sinusoidal_apply() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(8, 20);

        let input = Array3::zeros((2, 10, 8));
        let output = pe.apply(&input).expect("Operation failed");

        assert_eq!(output.shape(), input.shape());

        let encoding = pe.encode(10);
        for ((_, s, d), &value) in output.indexed_iter() {
            assert!((value - encoding[[s, d]]).abs() < 1e-6);
        }
    }

    /// Random initialisation produces a non-trivial table of the right shape.
    #[test]
    fn test_learned_encoding() {
        let mut rng = seeded_rng();
        let pe = LearnedPositionalEncoding::<f32>::new(32, 50, &mut rng);

        let encoding = pe.encode(10);
        assert_eq!(encoding.shape(), &[10, 32]);

        let sum = encoding.iter().fold(0.0_f32, |acc, x| acc + x.abs());
        assert!(sum > 0.1);
    }

    /// Dimensions are taken from the supplied table.
    #[test]
    fn test_learned_from_embeddings() {
        let table = Array2::ones((20, 16));
        let pe = LearnedPositionalEncoding::<f32>::from_embeddings(table);

        assert_eq!(pe.d_model(), 16);
        assert_eq!(pe.max_len(), 20);
    }

    /// RoPE `encode` returns a `(seq_len, d_model)` matrix.
    #[test]
    fn test_rope_encoding() {
        let pe = RotaryPositionalEncoding::<f32>::default_base(64, 100);

        assert_eq!(pe.encode(10).shape(), &[10, 64]);
    }

    /// Rotation preserves shape and changes values at non-zero positions.
    #[test]
    fn test_rope_rotate() {
        let pe = RotaryPositionalEncoding::<f64>::default_base(8, 20);

        let input = Array3::ones((1, 5, 8));
        let rotated = pe.rotate(&input, 0).expect("Operation failed");

        assert_eq!(rotated.shape(), input.shape());

        let different = (1..5).any(|pos| {
            (0..8).any(|i| (rotated[[0, pos, i]] - input[[0, pos, i]]).abs() > 1e-6)
        });
        assert!(
            different,
            "RoPE should modify input values at non-zero positions"
        );
    }

    /// A non-zero offset selects different cached angles.
    #[test]
    fn test_rope_with_offset() {
        let pe = RotaryPositionalEncoding::<f32>::default_base(8, 100);

        let input = Array3::ones((1, 10, 8));

        let rotated_0 = pe.rotate(&input, 0).expect("Operation failed");
        let rotated_5 = pe.rotate(&input, 5).expect("Operation failed");

        let different = (0..10).any(|s| {
            (0..8).any(|d| (rotated_0[[0, s, d]] - rotated_5[[0, s, d]]).abs() > 1e-6)
        });
        assert!(different, "Different offsets should give different results");
    }

    /// In-range offsets resolve to an embedding; out-of-range ones do not.
    #[test]
    fn test_relative_encoding() {
        let mut rng = seeded_rng();
        let pe = RelativePositionalEncoding::<f32>::new(16, 30, &mut rng);

        for rel in [0_i64, 5, -5] {
            assert!(pe.get_relative_embedding(rel).is_some());
        }

        assert!(pe.get_relative_embedding(100).is_none());
    }

    /// The diagonal of the bias tensor is the embedding for offset 0.
    #[test]
    fn test_relative_attention_bias() {
        let mut rng = seeded_rng();
        let pe = RelativePositionalEncoding::<f32>::new(8, 20, &mut rng);

        let bias = pe.get_attention_bias(10, 10).expect("Operation failed");
        assert_eq!(bias.shape(), &[10, 10, 8]);

        let rel_0 = pe.get_relative_embedding(0).expect("Operation failed");
        for i in 0..10 {
            for d in 0..8 {
                let delta = bias[[i, i, d]] - rel_0[[d]];
                assert!(delta.abs() < 1e-6);
            }
        }
    }

    /// The factory builds every variant with the requested `d_model`.
    #[test]
    fn test_factory() {
        let mut rng = seeded_rng();

        let kinds = [
            PositionalEncodingType::Sinusoidal,
            PositionalEncodingType::Learned,
            PositionalEncodingType::Rotary,
            PositionalEncodingType::Relative,
        ];
        for kind in kinds {
            let encoder = PositionalEncodingFactory::create::<f32, _>(kind, 32, 100, &mut rng);
            assert_eq!(encoder.d_model(), 32);
        }
    }

    /// Consecutive positions never share an identical encoding row.
    #[test]
    fn test_sinusoidal_properties() {
        let pe = SinusoidalPositionalEncoding::<f64>::new(64, 1000);
        let encoding = pe.encode(100);

        for i in 0..99 {
            let same =
                (0..64).all(|d| (encoding[[i, d]] - encoding[[i + 1, d]]).abs() <= 1e-10);
            assert!(!same, "Adjacent positions should be different");
        }
    }

    /// Inputs longer than `max_len` are rejected with an error.
    #[test]
    fn test_max_len_error() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(16, 10);

        let input = Array3::zeros((1, 20, 16));
        let result = pe.apply(&input);

        assert!(result.is_err());
    }
}