// scirs2_neural/utils/positional_encoding.rs
1//! Positional Encoding for Transformer Models
2//!
3//! This module provides various positional encoding strategies used in transformer
4//! architectures. Positional encodings inject sequence order information into
5//! the model since self-attention is permutation invariant.
6//!
7//! # Available Encodings
8//!
9//! - **Sinusoidal**: Classic fixed positional encoding from "Attention Is All You Need"
10//! - **Learned**: Trainable position embeddings (like BERT)
11//! - **Relative**: Position-relative encodings for better length generalization
12//! - **Rotary (RoPE)**: Rotary position embeddings used in modern LLMs
13//!
14//! # Examples
15//!
16//! ```rust
17//! use scirs2_neural::utils::positional_encoding::{
18//!     SinusoidalPositionalEncoding, PositionalEncoding
19//! };
20//! use scirs2_core::ndarray::Array2;
21//!
22//! // Create sinusoidal encoding for d_model=64, max_len=100
23//! let pe = SinusoidalPositionalEncoding::<f32>::new(64, 100);
24//!
25//! // Get encoding for sequence of length 10
26//! let encoding = pe.encode(10);
27//! assert_eq!(encoding.shape(), &[10, 64]);
28//! ```
29
30use crate::error::{NeuralError, Result};
31use scirs2_core::ndarray::{s, Array, Array2, Array3, Axis, IxDyn, Zip};
32use scirs2_core::numeric::{Float, NumAssign};
33use scirs2_core::random::{Rng, RngExt};
34use std::f64::consts::PI;
35use std::fmt::Debug;
36
37/// Trait for positional encoding implementations
38pub trait PositionalEncoding<F: Float + Debug + NumAssign> {
39    /// Encode positions for a sequence of given length
40    ///
41    /// # Arguments
42    /// * `seq_len` - Length of the sequence
43    ///
44    /// # Returns
45    /// Position encodings of shape [seq_len, d_model]
46    fn encode(&self, seq_len: usize) -> Array2<F>;
47
48    /// Apply positional encoding to an input tensor
49    ///
50    /// # Arguments
51    /// * `input` - Input tensor of shape [batch_size, seq_len, d_model]
52    ///
53    /// # Returns
54    /// Tensor with positional encoding added
55    fn apply(&self, input: &Array3<F>) -> Result<Array3<F>>;
56
57    /// Forward pass with dynamic array
58    ///
59    /// # Arguments
60    /// * `input` - Input tensor of dynamic shape [batch_size, seq_len, d_model]
61    ///
62    /// # Returns
63    /// Tensor with positional encoding added
64    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
65        // Convert IxDyn to Array3
66        if input.ndim() != 3 {
67            return Err(NeuralError::InvalidArchitecture(format!(
68                "Expected 3D input, got {}D",
69                input.ndim()
70            )));
71        }
72
73        let shape = input.shape();
74        let input_3d = input
75            .view()
76            .into_dimensionality::<scirs2_core::ndarray::Ix3>()
77            .map_err(|e| {
78                NeuralError::InvalidArchitecture(format!("Failed to convert to 3D: {}", e))
79            })?;
80
81        let output_3d = self.apply(&input_3d.to_owned())?;
82        Ok(output_3d.into_dyn())
83    }
84
85    /// Update trainable parameters (for learned encodings)
86    ///
87    /// # Arguments
88    /// * `_learning_rate` - Learning rate for parameter updates
89    ///
90    /// # Returns
91    /// Result indicating success or failure
92    fn update(&mut self, _learning_rate: F) -> Result<()> {
93        // Default implementation: no-op for non-trainable encodings
94        Ok(())
95    }
96
97    /// Clone the positional encoding into a boxed trait object
98    fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
99    where
100        F: Send + Sync + 'static;
101
102    /// Get the model dimension
103    fn d_model(&self) -> usize;
104
105    /// Get the maximum sequence length supported
106    fn max_len(&self) -> usize;
107}
108
/// Sinusoidal Positional Encoding
///
/// Implements the classic positional encoding from "Attention Is All You Need":
/// PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
/// PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
///
/// This encoding has several desirable properties:
/// - Deterministic and parameter-free
/// - Each position gets a unique encoding
/// - Can extrapolate to longer sequences than seen during training
/// - Allows the model to attend to relative positions
#[derive(Debug, Clone)]
pub struct SinusoidalPositionalEncoding<F: Float + Debug + NumAssign> {
    // Model (feature) dimension; asserted even in `new`.
    d_model: usize,
    // Largest sequence length covered by the precomputed table.
    max_len: usize,
    /// Pre-computed positional encodings
    encodings: Array2<F>,
    /// Dropout rate (optional)
    // NOTE(review): `with_dropout` stores this rate, but `apply` never reads
    // it in this file — confirm whether dropout is applied by a caller.
    dropout: Option<F>,
}
129
130impl<F: Float + Debug + NumAssign> SinusoidalPositionalEncoding<F> {
131    /// Create a new sinusoidal positional encoding
132    ///
133    /// # Arguments
134    /// * `d_model` - Model dimension (must be even)
135    /// * `max_len` - Maximum sequence length
136    pub fn new(d_model: usize, max_len: usize) -> Self {
137        assert!(
138            d_model.is_multiple_of(2),
139            "d_model must be even for sinusoidal PE"
140        );
141
142        let encodings = Self::compute_encodings(d_model, max_len);
143
144        Self {
145            d_model,
146            max_len,
147            encodings,
148            dropout: None,
149        }
150    }
151
152    /// Create with dropout
153    pub fn with_dropout(d_model: usize, max_len: usize, dropout: F) -> Self {
154        let mut pe = Self::new(d_model, max_len);
155        pe.dropout = Some(dropout);
156        pe
157    }
158
159    /// Compute the sinusoidal encodings
160    fn compute_encodings(d_model: usize, max_len: usize) -> Array2<F> {
161        let mut encodings = Array2::zeros((max_len, d_model));
162
163        for pos in 0..max_len {
164            for i in 0..(d_model / 2) {
165                // Compute the divisor: 10000^(2i/d_model)
166                let exponent = (2 * i) as f64 / d_model as f64;
167                let div_term = (10000.0_f64).powf(exponent);
168                let angle = pos as f64 / div_term;
169
170                // sin for even indices, cos for odd indices
171                let sin_val = F::from(angle.sin()).unwrap_or(F::zero());
172                let cos_val = F::from(angle.cos()).unwrap_or(F::zero());
173
174                encodings[[pos, 2 * i]] = sin_val;
175                encodings[[pos, 2 * i + 1]] = cos_val;
176            }
177        }
178
179        encodings
180    }
181
182    /// Get parameters (sinusoidal encoding has no trainable parameters)
183    pub fn params(&self) -> Vec<&Array<F, IxDyn>> {
184        Vec::new()
185    }
186
187    /// Set training mode (no-op for sinusoidal encoding as it has no trainable parameters)
188    pub fn set_training(&mut self, _training: bool) {
189        // No-op: sinusoidal encoding is not trainable
190    }
191}
192
193impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for SinusoidalPositionalEncoding<F> {
194    fn encode(&self, seq_len: usize) -> Array2<F> {
195        assert!(
196            seq_len <= self.max_len,
197            "seq_len {} exceeds max_len {}",
198            seq_len,
199            self.max_len
200        );
201        self.encodings.slice(s![..seq_len, ..]).to_owned()
202    }
203
204    fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
205        let seq_len = input.shape()[1];
206        if seq_len > self.max_len {
207            return Err(NeuralError::InvalidArchitecture(format!(
208                "Sequence length {} exceeds max_len {}",
209                seq_len, self.max_len
210            )));
211        }
212
213        let encoding = self.encode(seq_len);
214        let mut output = input.clone();
215
216        // Add positional encoding to each batch
217        for mut batch in output.axis_iter_mut(Axis(0)) {
218            Zip::from(&mut batch)
219                .and(&encoding)
220                .for_each(|b, &e| *b += e);
221        }
222
223        Ok(output)
224    }
225
226    fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
227    where
228        F: Send + Sync + 'static,
229    {
230        Box::new(self.clone())
231    }
232
233    fn d_model(&self) -> usize {
234        self.d_model
235    }
236
237    fn max_len(&self) -> usize {
238        self.max_len
239    }
240}
241
/// Learned Positional Encoding
///
/// Uses trainable embeddings for each position, similar to BERT.
/// More flexible than sinusoidal but requires training data and
/// doesn't extrapolate to longer sequences.
#[derive(Debug, Clone)]
pub struct LearnedPositionalEncoding<F: Float + Debug + NumAssign> {
    // Model (feature) dimension.
    d_model: usize,
    // Number of position rows in the embedding table.
    max_len: usize,
    /// Learnable position embeddings
    embeddings: Array2<F>,
}
254
255impl<F: Float + Debug + NumAssign> LearnedPositionalEncoding<F> {
256    /// Create a new learned positional encoding with random initialization
257    ///
258    /// # Arguments
259    /// * `d_model` - Model dimension
260    /// * `max_len` - Maximum sequence length
261    /// * `rng` - Random number generator
262    pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
263        // Xavier/Glorot initialization
264        let std = (2.0 / (max_len + d_model) as f64).sqrt();
265
266        let mut embeddings = Array2::zeros((max_len, d_model));
267        for elem in embeddings.iter_mut() {
268            // Box-Muller transform for normal distribution
269            let u1: f64 = rng.random_range(0.0001..1.0);
270            let u2: f64 = rng.random_range(0.0..1.0);
271            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
272            *elem = F::from(z * std).unwrap_or(F::zero());
273        }
274
275        Self {
276            d_model,
277            max_len,
278            embeddings,
279        }
280    }
281
282    /// Create from existing embeddings
283    pub fn from_embeddings(embeddings: Array2<F>) -> Self {
284        let shape = embeddings.shape();
285        Self {
286            d_model: shape[1],
287            max_len: shape[0],
288            embeddings,
289        }
290    }
291
292    /// Get mutable reference to embeddings for training
293    pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
294        &mut self.embeddings
295    }
296
297    /// Get reference to embeddings
298    pub fn embeddings(&self) -> &Array2<F> {
299        &self.embeddings
300    }
301}
302
303impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for LearnedPositionalEncoding<F> {
304    fn encode(&self, seq_len: usize) -> Array2<F> {
305        assert!(
306            seq_len <= self.max_len,
307            "seq_len {} exceeds max_len {}",
308            seq_len,
309            self.max_len
310        );
311        self.embeddings.slice(s![..seq_len, ..]).to_owned()
312    }
313
314    fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
315        let seq_len = input.shape()[1];
316        if seq_len > self.max_len {
317            return Err(NeuralError::InvalidArchitecture(format!(
318                "Sequence length {} exceeds max_len {}",
319                seq_len, self.max_len
320            )));
321        }
322
323        let encoding = self.encode(seq_len);
324        let mut output = input.clone();
325
326        for mut batch in output.axis_iter_mut(Axis(0)) {
327            Zip::from(&mut batch)
328                .and(&encoding)
329                .for_each(|b, &e| *b += e);
330        }
331
332        Ok(output)
333    }
334
335    fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
336    where
337        F: Send + Sync + 'static,
338    {
339        Box::new(self.clone())
340    }
341
342    fn d_model(&self) -> usize {
343        self.d_model
344    }
345
346    fn max_len(&self) -> usize {
347        self.max_len
348    }
349}
350
/// Rotary Positional Encoding (RoPE)
///
/// Implements rotary position embeddings as described in "RoFormer: Enhanced Transformer
/// with Rotary Position Embedding". RoPE encodes position information by rotating
/// the query and key vectors, which has several advantages:
///
/// - Relative position information is naturally encoded
/// - Better length extrapolation than absolute position encodings
/// - Used in modern LLMs like LLaMA, GPT-NeoX, etc.
#[derive(Debug, Clone)]
pub struct RotaryPositionalEncoding<F: Float + Debug + NumAssign> {
    // Model (feature) dimension; asserted even in `new`.
    d_model: usize,
    // Largest absolute position covered by the caches.
    max_len: usize,
    // Base of the geometric frequency progression (10000.0 via `default_base`).
    base: f64,
    /// Pre-computed sin values: [max_len, d_model/2]
    sin_cached: Array2<F>,
    /// Pre-computed cos values: [max_len, d_model/2]
    cos_cached: Array2<F>,
}
370
371impl<F: Float + Debug + NumAssign> RotaryPositionalEncoding<F> {
372    /// Create a new RoPE encoding
373    ///
374    /// # Arguments
375    /// * `d_model` - Model dimension (must be even)
376    /// * `max_len` - Maximum sequence length
377    /// * `base` - Base for frequency computation (default: 10000.0)
378    pub fn new(d_model: usize, max_len: usize, base: f64) -> Self {
379        assert!(d_model.is_multiple_of(2), "d_model must be even for RoPE");
380
381        let (sin_cached, cos_cached) = Self::compute_rope_cache(d_model, max_len, base);
382
383        Self {
384            d_model,
385            max_len,
386            base,
387            sin_cached,
388            cos_cached,
389        }
390    }
391
392    /// Create with default base (10000.0)
393    pub fn default_base(d_model: usize, max_len: usize) -> Self {
394        Self::new(d_model, max_len, 10000.0)
395    }
396
397    /// Compute the RoPE sin/cos cache
398    fn compute_rope_cache(d_model: usize, max_len: usize, base: f64) -> (Array2<F>, Array2<F>) {
399        let half_dim = d_model / 2;
400        let mut sin_cached = Array2::zeros((max_len, half_dim));
401        let mut cos_cached = Array2::zeros((max_len, half_dim));
402
403        // Compute inverse frequencies
404        for pos in 0..max_len {
405            for i in 0..half_dim {
406                let freq = 1.0 / base.powf((2 * i) as f64 / d_model as f64);
407                let angle = pos as f64 * freq;
408
409                sin_cached[[pos, i]] = F::from(angle.sin()).unwrap_or(F::zero());
410                cos_cached[[pos, i]] = F::from(angle.cos()).unwrap_or(F::zero());
411            }
412        }
413
414        (sin_cached, cos_cached)
415    }
416
417    /// Apply rotary embedding to query or key tensor
418    ///
419    /// # Arguments
420    /// * `x` - Input tensor of shape [batch, seq_len, d_model]
421    /// * `offset` - Position offset (for KV cache during inference)
422    ///
423    /// # Returns
424    /// Rotated tensor with same shape
425    pub fn rotate(&self, x: &Array3<F>, offset: usize) -> Result<Array3<F>> {
426        let seq_len = x.shape()[1];
427        if seq_len + offset > self.max_len {
428            return Err(NeuralError::InvalidArchitecture(format!(
429                "Position {} exceeds max_len {}",
430                seq_len + offset,
431                self.max_len
432            )));
433        }
434
435        let batch_size = x.shape()[0];
436        let half_dim = self.d_model / 2;
437
438        let mut output = Array3::zeros(x.raw_dim());
439
440        for b in 0..batch_size {
441            for pos in 0..seq_len {
442                let abs_pos = pos + offset;
443                for i in 0..half_dim {
444                    let x1 = x[[b, pos, 2 * i]];
445                    let x2 = x[[b, pos, 2 * i + 1]];
446
447                    let cos = self.cos_cached[[abs_pos, i]];
448                    let sin = self.sin_cached[[abs_pos, i]];
449
450                    // Apply rotation: [cos, -sin; sin, cos] * [x1; x2]
451                    output[[b, pos, 2 * i]] = x1 * cos - x2 * sin;
452                    output[[b, pos, 2 * i + 1]] = x1 * sin + x2 * cos;
453                }
454            }
455        }
456
457        Ok(output)
458    }
459
460    /// Get sin cache
461    pub fn sin_cache(&self) -> &Array2<F> {
462        &self.sin_cached
463    }
464
465    /// Get cos cache
466    pub fn cos_cache(&self) -> &Array2<F> {
467        &self.cos_cached
468    }
469}
470
471impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RotaryPositionalEncoding<F> {
472    fn encode(&self, seq_len: usize) -> Array2<F> {
473        // Return combined sin/cos for compatibility
474        // In practice, use rotate() method directly
475        let half_dim = self.d_model / 2;
476        let mut encoding = Array2::zeros((seq_len, self.d_model));
477
478        for pos in 0..seq_len {
479            for i in 0..half_dim {
480                encoding[[pos, 2 * i]] = self.sin_cached[[pos, i]];
481                encoding[[pos, 2 * i + 1]] = self.cos_cached[[pos, i]];
482            }
483        }
484
485        encoding
486    }
487
488    fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
489        // For RoPE, apply is the same as rotate with offset 0
490        self.rotate(input, 0)
491    }
492
493    fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
494    where
495        F: Send + Sync + 'static,
496    {
497        Box::new(self.clone())
498    }
499
500    fn d_model(&self) -> usize {
501        self.d_model
502    }
503
504    fn max_len(&self) -> usize {
505        self.max_len
506    }
507}
508
/// Relative Positional Encoding
///
/// Implements relative position encodings that represent the distance between
/// positions rather than absolute positions. This allows better generalization
/// to longer sequences.
#[derive(Debug, Clone)]
pub struct RelativePositionalEncoding<F: Float + Debug + NumAssign> {
    // Model (feature) dimension.
    d_model: usize,
    // Maximum sequence length; relative offsets span [-(max_len-1), max_len-1].
    max_len: usize,
    /// Relative position embeddings: [2*max_len-1, d_model]
    /// Index 0 = position -(max_len-1), index max_len-1 = position 0
    rel_embeddings: Array2<F>,
}
522
523impl<F: Float + Debug + NumAssign> RelativePositionalEncoding<F> {
524    /// Create a new relative positional encoding
525    ///
526    /// # Arguments
527    /// * `d_model` - Model dimension
528    /// * `max_len` - Maximum sequence length
529    /// * `rng` - Random number generator
530    pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
531        let num_positions = 2 * max_len - 1;
532        let std = (1.0 / d_model as f64).sqrt();
533
534        let mut rel_embeddings = Array2::zeros((num_positions, d_model));
535        for elem in rel_embeddings.iter_mut() {
536            let u1: f64 = rng.random_range(0.0001..1.0);
537            let u2: f64 = rng.random_range(0.0..1.0);
538            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
539            *elem = F::from(z * std).unwrap_or(F::zero());
540        }
541
542        Self {
543            d_model,
544            max_len,
545            rel_embeddings,
546        }
547    }
548
549    /// Get embedding for a relative position
550    ///
551    /// # Arguments
552    /// * `rel_pos` - Relative position (can be negative)
553    pub fn get_relative_embedding(&self, rel_pos: i64) -> Option<Array<F, IxDyn>> {
554        let max_rel = self.max_len as i64 - 1;
555        if rel_pos < -max_rel || rel_pos > max_rel {
556            return None;
557        }
558
559        let idx = (rel_pos + max_rel) as usize;
560        Some(self.rel_embeddings.slice(s![idx, ..]).to_owned().into_dyn())
561    }
562
563    /// Get relative position bias matrix for attention
564    ///
565    /// # Arguments
566    /// * `query_len` - Length of query sequence
567    /// * `key_len` - Length of key sequence
568    ///
569    /// # Returns
570    /// Relative position bias of shape [query_len, key_len, d_model]
571    pub fn get_attention_bias(&self, query_len: usize, key_len: usize) -> Result<Array3<F>> {
572        if query_len > self.max_len || key_len > self.max_len {
573            return Err(NeuralError::InvalidArchitecture(format!(
574                "Sequence length exceeds max_len {}",
575                self.max_len
576            )));
577        }
578
579        let mut bias = Array3::zeros((query_len, key_len, self.d_model));
580        let max_rel = self.max_len as i64 - 1;
581
582        for q in 0..query_len {
583            for k in 0..key_len {
584                let rel_pos = k as i64 - q as i64;
585                let idx = (rel_pos + max_rel) as usize;
586
587                for d in 0..self.d_model {
588                    bias[[q, k, d]] = self.rel_embeddings[[idx, d]];
589                }
590            }
591        }
592
593        Ok(bias)
594    }
595
596    /// Get mutable reference to embeddings
597    pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
598        &mut self.rel_embeddings
599    }
600}
601
602impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RelativePositionalEncoding<F> {
603    fn encode(&self, seq_len: usize) -> Array2<F> {
604        // For relative PE, return the central positions (around 0 relative position)
605        let start = self.max_len - 1;
606        self.rel_embeddings
607            .slice(s![start..(start + seq_len), ..])
608            .to_owned()
609    }
610
611    fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
612        // For relative PE, typically used differently in attention
613        // This provides a simple fallback that adds the center embeddings
614        let seq_len = input.shape()[1];
615        if seq_len > self.max_len {
616            return Err(NeuralError::InvalidArchitecture(format!(
617                "Sequence length {} exceeds max_len {}",
618                seq_len, self.max_len
619            )));
620        }
621
622        let encoding = self.encode(seq_len);
623        let mut output = input.clone();
624
625        for mut batch in output.axis_iter_mut(Axis(0)) {
626            Zip::from(&mut batch)
627                .and(&encoding)
628                .for_each(|b, &e| *b += e);
629        }
630
631        Ok(output)
632    }
633
634    fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
635    where
636        F: Send + Sync + 'static,
637    {
638        Box::new(self.clone())
639    }
640
641    fn d_model(&self) -> usize {
642        self.d_model
643    }
644
645    fn max_len(&self) -> usize {
646        self.max_len
647    }
648}
649
650/// Factory for creating positional encodings
/// Selects which positional-encoding strategy [`PositionalEncodingFactory`] builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
    /// Fixed sinusoidal encoding (Transformer)
    Sinusoidal,
    /// Learnable position embeddings (BERT)
    Learned,
    /// Rotary position embeddings (RoPE)
    Rotary,
    /// Relative position encoding
    Relative,
}
662
/// Factory for creating positional encodings as boxed trait objects.
pub struct PositionalEncodingFactory;
665
666impl PositionalEncodingFactory {
667    /// Create a positional encoding of the specified type
668    pub fn create<F, R>(
669        pe_type: PositionalEncodingType,
670        d_model: usize,
671        max_len: usize,
672        rng: &mut R,
673    ) -> Box<dyn PositionalEncoding<F> + Send + Sync>
674    where
675        F: Float + Debug + NumAssign + Send + Sync + 'static,
676        R: Rng,
677    {
678        match pe_type {
679            PositionalEncodingType::Sinusoidal => {
680                Box::new(SinusoidalPositionalEncoding::new(d_model, max_len))
681            }
682            PositionalEncodingType::Learned => {
683                Box::new(LearnedPositionalEncoding::new(d_model, max_len, rng))
684            }
685            PositionalEncodingType::Rotary => {
686                Box::new(RotaryPositionalEncoding::default_base(d_model, max_len))
687            }
688            PositionalEncodingType::Relative => {
689                Box::new(RelativePositionalEncoding::new(d_model, max_len, rng))
690            }
691        }
692    }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use scirs2_core::random::SeedableRng;

    // `encode` should return a [seq_len, d_model] table for any valid length.
    #[test]
    fn test_sinusoidal_encoding_shape() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(64, 100);

        let encoding = pe.encode(10);
        assert_eq!(encoding.shape(), &[10, 64]);

        let encoding = pe.encode(50);
        assert_eq!(encoding.shape(), &[50, 64]);
    }

    #[test]
    fn test_sinusoidal_encoding_values() {
        let pe = SinusoidalPositionalEncoding::<f64>::new(4, 10);

        let encoding = pe.encode(3);

        // Position 0 should have sin(0)=0, cos(0)=1 for the first pair
        assert!((encoding[[0, 0]] - 0.0).abs() < 1e-6); // sin(0)
        assert!((encoding[[0, 1]] - 1.0).abs() < 1e-6); // cos(0)

        // Each position should have different values
        assert!((encoding[[0, 0]] - encoding[[1, 0]]).abs() > 1e-10);
    }

    // Applying to a zero tensor should leave exactly the encoding behind.
    #[test]
    fn test_sinusoidal_apply() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(8, 20);

        let input = Array3::zeros((2, 10, 8)); // batch=2, seq=10, d_model=8
        let output = pe.apply(&input).expect("Operation failed");

        assert_eq!(output.shape(), input.shape());

        // Check that encoding was added
        let encoding = pe.encode(10);
        for b in 0..2 {
            for s in 0..10 {
                for d in 0..8 {
                    assert!((output[[b, s, d]] - encoding[[s, d]]).abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_learned_encoding() {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let pe = LearnedPositionalEncoding::<f32>::new(32, 50, &mut rng);

        let encoding = pe.encode(10);
        assert_eq!(encoding.shape(), &[10, 32]);

        // Values should be initialized (not all zero)
        let sum: f32 = encoding.iter().map(|x| x.abs()).sum();
        assert!(sum > 0.1);
    }

    // `from_embeddings` infers d_model/max_len from the table's shape.
    #[test]
    fn test_learned_from_embeddings() {
        let embeddings = Array2::ones((20, 16));
        let pe = LearnedPositionalEncoding::<f32>::from_embeddings(embeddings);

        assert_eq!(pe.d_model(), 16);
        assert_eq!(pe.max_len(), 20);
    }

    #[test]
    fn test_rope_encoding() {
        let pe = RotaryPositionalEncoding::<f32>::default_base(64, 100);

        let encoding = pe.encode(10);
        assert_eq!(encoding.shape(), &[10, 64]);
    }

    #[test]
    fn test_rope_rotate() {
        let pe = RotaryPositionalEncoding::<f64>::default_base(8, 20);

        let input = Array3::ones((1, 5, 8));
        let rotated = pe.rotate(&input, 0).expect("Operation failed");

        assert_eq!(rotated.shape(), input.shape());

        // At position 0, cos(0)=1, sin(0)=0 so rotation is identity
        // Check position 1 or higher where rotation actually occurs
        let mut different = false;
        for pos in 1..5 {
            for i in 0..8 {
                if (rotated[[0, pos, i]] - input[[0, pos, i]]).abs() > 1e-6 {
                    different = true;
                    break;
                }
            }
            if different {
                break;
            }
        }
        assert!(
            different,
            "RoPE should modify input values at non-zero positions"
        );
    }

    // The offset parameter shifts absolute positions (KV-cache use case).
    #[test]
    fn test_rope_with_offset() {
        let pe = RotaryPositionalEncoding::<f32>::default_base(8, 100);

        let input = Array3::ones((1, 10, 8));

        let rotated_0 = pe.rotate(&input, 0).expect("Operation failed");
        let rotated_5 = pe.rotate(&input, 5).expect("Operation failed");

        // Different offsets should give different results
        let mut different = false;
        for s in 0..10 {
            for d in 0..8 {
                if (rotated_0[[0, s, d]] - rotated_5[[0, s, d]]).abs() > 1e-6 {
                    different = true;
                    break;
                }
            }
        }
        assert!(different, "Different offsets should give different results");
    }

    #[test]
    fn test_relative_encoding() {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let pe = RelativePositionalEncoding::<f32>::new(16, 30, &mut rng);

        // Check relative embedding retrieval
        let rel_0 = pe.get_relative_embedding(0);
        assert!(rel_0.is_some());

        let rel_pos = pe.get_relative_embedding(5);
        assert!(rel_pos.is_some());

        let rel_neg = pe.get_relative_embedding(-5);
        assert!(rel_neg.is_some());

        // Out of range should return None
        let out_of_range = pe.get_relative_embedding(100);
        assert!(out_of_range.is_none());
    }

    #[test]
    fn test_relative_attention_bias() {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let pe = RelativePositionalEncoding::<f32>::new(8, 20, &mut rng);

        let bias = pe.get_attention_bias(10, 10).expect("Operation failed");
        assert_eq!(bias.shape(), &[10, 10, 8]);

        // Diagonal should have same values (relative position 0)
        let rel_0 = pe.get_relative_embedding(0).expect("Operation failed");
        for i in 0..10 {
            for d in 0..8 {
                assert!((bias[[i, i, d]] - rel_0[[d]]).abs() < 1e-6);
            }
        }
    }

    // Each factory variant should report the d_model it was built with.
    #[test]
    fn test_factory() {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);

        let sinusoidal = PositionalEncodingFactory::create::<f32, _>(
            PositionalEncodingType::Sinusoidal,
            32,
            100,
            &mut rng,
        );
        assert_eq!(sinusoidal.d_model(), 32);

        let learned = PositionalEncodingFactory::create::<f32, _>(
            PositionalEncodingType::Learned,
            32,
            100,
            &mut rng,
        );
        assert_eq!(learned.d_model(), 32);

        let rotary = PositionalEncodingFactory::create::<f32, _>(
            PositionalEncodingType::Rotary,
            32,
            100,
            &mut rng,
        );
        assert_eq!(rotary.d_model(), 32);

        let relative = PositionalEncodingFactory::create::<f32, _>(
            PositionalEncodingType::Relative,
            32,
            100,
            &mut rng,
        );
        assert_eq!(relative.d_model(), 32);
    }

    #[test]
    fn test_sinusoidal_properties() {
        let pe = SinusoidalPositionalEncoding::<f64>::new(64, 1000);
        let encoding = pe.encode(100);

        // Each position should be unique
        for i in 0..99 {
            let mut same = true;
            for d in 0..64 {
                if (encoding[[i, d]] - encoding[[i + 1, d]]).abs() > 1e-10 {
                    same = false;
                    break;
                }
            }
            assert!(!same, "Adjacent positions should be different");
        }
    }

    // `apply` must reject inputs longer than the precomputed table.
    #[test]
    fn test_max_len_error() {
        let pe = SinusoidalPositionalEncoding::<f32>::new(16, 10);

        let input = Array3::zeros((1, 20, 16)); // seq_len > max_len
        let result = pe.apply(&input);

        assert!(result.is_err());
    }
}