scirs2_neural/utils/positional_encoding.rs

//! Positional encoding utilities for transformer models
//!
//! This module provides implementations of positional encoding techniques
//! used in transformer architectures to incorporate sequence order information.
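//!
//! # Example
//!
//! A minimal usage sketch; the import path below assumes this module is exposed
//! as `scirs2_neural::utils::positional_encoding`:
//!
//! ```ignore
//! use scirs2_neural::utils::positional_encoding::{
//!     PositionalEncodingFactory, PositionalEncodingType,
//! };
//! use ndarray::Array3;
//!
//! // Sinusoidal encoding for sequences of up to 128 tokens with 64-dimensional embeddings
//! let pos_enc = PositionalEncodingFactory::create::<f32>(
//!     PositionalEncodingType::Sinusoidal,
//!     128,
//!     64,
//! )
//! .unwrap();
//!
//! // Token embeddings shaped [batch, seq_len, d_model]
//! let embeddings = Array3::<f32>::zeros((2, 16, 64)).into_dyn();
//! let encoded = pos_enc.forward(&embeddings).unwrap();
//! assert_eq!(encoded.shape(), embeddings.shape());
//! ```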

use crate::error::{NeuralError, Result};
use ndarray::{Array, IxDyn};
use num_traits::Float;
use std::fmt::Debug;

/// Types of positional encoding available
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PositionalEncodingType {
    /// Sinusoidal positional encoding from the "Attention Is All You Need" paper
    Sinusoidal,
    /// Learned positional embeddings, which are trained along with the model
    Learned,
    /// Relative positional encoding that focuses on relative distances
    Relative,
}

/// Factory for creating positional encodings
pub struct PositionalEncodingFactory;

impl PositionalEncodingFactory {
    /// Create a positional encoding
    ///
    /// # Arguments
    ///
    /// * `encoding_type` - Type of positional encoding to create
    /// * `max_len` - Maximum sequence length
    /// * `d_model` - Model embedding dimension
    ///
    /// # Returns
    ///
    /// * Box containing the positional encoding implementation
    pub fn create<F: Float + Debug + 'static>(
        encoding_type: PositionalEncodingType,
        max_len: usize,
        d_model: usize,
    ) -> Result<Box<dyn PositionalEncoding<F>>> {
        match encoding_type {
            PositionalEncodingType::Sinusoidal => Ok(Box::new(SinusoidalPositionalEncoding::new(
                max_len, d_model,
            )?)),
            PositionalEncodingType::Learned => {
                Ok(Box::new(LearnedPositionalEncoding::new(max_len, d_model)))
            }
            PositionalEncodingType::Relative => {
                Ok(Box::new(RelativePositionalEncoding::new(max_len, d_model)))
            }
        }
    }
}

/// Trait for positional encoding implementations
pub trait PositionalEncoding<F: Float + Debug> {
    /// Apply positional encoding to input embeddings
    ///
    /// # Arguments
    ///
    /// * `embeddings` - Input embeddings [batch, seq_len, d_model]
    ///
    /// # Returns
    ///
    /// * Embeddings with positional encoding added
    fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;

    /// Get the positional encoding matrix directly
    ///
    /// # Arguments
    ///
    /// * `seq_len` - Sequence length to generate encodings for
    ///
    /// # Returns
    ///
    /// * Positional encoding matrix [seq_len, d_model]
    fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>>;

    /// Update learnable parameters if any
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - Learning rate for the update
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Get learnable parameters if any
    ///
    /// # Returns
    ///
    /// * Vector of parameters as arrays
    fn params(&self) -> Vec<Array<F, IxDyn>> {
        Vec::new() // Default implementation returns empty vector
    }

    /// Set training mode (does nothing by default)
    fn set_training(&mut self, _training: bool) {
        // Default implementation does nothing
    }

    /// Get training mode (false by default)
    fn is_training(&self) -> bool {
        false
    }
}

/// Sinusoidal positional encoding from the "Attention Is All You Need" paper
///
/// Uses sine and cosine functions of different frequencies to encode position.
/// The advantage is that this can extrapolate to longer sequences than those
/// seen during training.
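///
/// For position `pos` and dimension pair index `i` (with `0 <= i < d_model / 2`),
/// the encoding follows the formula from the paper:
///
/// ```text
/// PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
/// PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
/// ```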
#[derive(Debug, Clone)]
pub struct SinusoidalPositionalEncoding<F: Float + Debug> {
    /// Maximum sequence length
    max_len: usize,
    /// Model embedding dimension
    d_model: usize,
    /// Pre-computed encoding matrix [max_len, d_model]
    encoding: Array<F, IxDyn>,
}

impl<F: Float + Debug + 'static> SinusoidalPositionalEncoding<F> {
    /// Create a new sinusoidal positional encoding
    ///
    /// # Arguments
    ///
    /// * `max_len` - Maximum sequence length
    /// * `d_model` - Model embedding dimension
    ///
    /// # Returns
    ///
    /// * A new sinusoidal positional encoding
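    ///
    /// # Example
    ///
    /// A minimal sketch; `d_model` must be even or construction fails:
    ///
    /// ```ignore
    /// let pe = SinusoidalPositionalEncoding::<f32>::new(128, 64).unwrap();
    /// assert!(SinusoidalPositionalEncoding::<f32>::new(128, 63).is_err());
    /// ```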
    pub fn new(max_len: usize, d_model: usize) -> Result<Self> {
        if d_model % 2 != 0 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Model dimension ({}) must be even for sinusoidal positional encoding",
                d_model
            )));
        }

        let mut encoding = Array::<F, _>::zeros((max_len, d_model));

        for pos in 0..max_len {
            for i in 0..d_model / 2 {
                let div_term = F::from(10000.0)
                    .unwrap()
                    .powf(F::from(2.0 * i as f64 / d_model as f64).unwrap());
                let angle = F::from(pos as f64).unwrap() / div_term;

                // Use sin for even indices: PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
                encoding[[pos, 2 * i]] = angle.sin();

                // Use cos for odd indices: PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
                encoding[[pos, 2 * i + 1]] = angle.cos();
            }
        }

        Ok(Self {
            max_len,
            d_model,
            encoding: encoding.into_dyn(),
        })
    }
}

impl<F: Float + Debug + 'static> PositionalEncoding<F> for SinusoidalPositionalEncoding<F> {
    fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if embeddings.ndim() != 3 {
            return Err(NeuralError::InferenceError(
                "Embeddings must have 3 dimensions [batch, seq_len, d_model]".to_string(),
            ));
        }

        let embed_shape = embeddings.shape();
        let seq_len = embed_shape[1];

        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // Get positional encoding for the current sequence length
        let pos_encoding = self.get_encoding(seq_len)?;

        // Create a mutable copy of the embeddings
        let mut output = embeddings.clone();

        // Add positional encoding to each batch element
        for batch_idx in 0..embed_shape[0] {
            let mut batch_slice = output.slice_mut(ndarray::s![batch_idx, .., ..]);

            // Add positional encoding
            for pos in 0..seq_len {
                for dim in 0..self.d_model {
                    batch_slice[[pos, dim]] = batch_slice[[pos, dim]] + pos_encoding[[pos, dim]];
                }
            }
        }

        Ok(output)
    }

    fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Requested sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // Return a slice of the pre-computed encoding, with proper dimension type
        Ok(self
            .encoding
            .slice(ndarray::s![0..seq_len, ..])
            .to_owned()
            .into_dyn())
    }

    fn update(&mut self, _learning_rate: F) -> Result<()> {
        // Sinusoidal encoding has no learnable parameters
        Ok(())
    }

    /// Set training mode (does nothing for sinusoidal encoding)
    fn set_training(&mut self, _training: bool) {
        // Sinusoidal encoding has no training mode
    }

    /// Get training mode (always false for sinusoidal encoding)
    fn is_training(&self) -> bool {
        false
    }
}

/// Learned positional embeddings
///
/// Uses a lookup table of learned position embeddings. These can potentially
/// capture more complex position patterns but don't extrapolate well to unseen positions.
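///
/// # Example
///
/// A minimal sketch of the learnable interface (`params` and `update` come from the
/// [`PositionalEncoding`] trait, which must be in scope):
///
/// ```ignore
/// let mut pe = LearnedPositionalEncoding::<f32>::new(128, 64);
/// assert_eq!(pe.params()[0].shape(), &[128, 64]); // the learnable position table
/// pe.update(0.01).unwrap(); // simplified gradient step using the stored gradients
/// ```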
pub struct LearnedPositionalEncoding<F: Float + Debug> {
    /// Maximum sequence length
    max_len: usize,
    /// Model embedding dimension
    d_model: usize,
    /// Learnable position embeddings [max_len, d_model]
    weights: Array<F, IxDyn>,
    /// Gradient of weights
    dweights: Array<F, IxDyn>,
}

impl<F: Float + Debug + 'static> LearnedPositionalEncoding<F> {
    /// Create a new learned positional encoding
    ///
    /// # Arguments
    ///
    /// * `max_len` - Maximum sequence length
    /// * `d_model` - Model embedding dimension
    ///
    /// # Returns
    ///
    /// * A new learned positional encoding
    pub fn new(max_len: usize, d_model: usize) -> Self {
        // Initialize to a small constant value (a full implementation would use random initialization)
        let init_scale = F::from(0.02).unwrap();
        let weights = Array::<F, _>::from_elem((max_len, d_model), init_scale);
        let dweights = Array::<F, _>::zeros((max_len, d_model));

        Self {
            max_len,
            d_model,
            weights: weights.into_dyn(),
            dweights: dweights.into_dyn(),
        }
    }
}

impl<F: Float + Debug + 'static> PositionalEncoding<F> for LearnedPositionalEncoding<F> {
    fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if embeddings.ndim() != 3 {
            return Err(NeuralError::InferenceError(
                "Embeddings must have 3 dimensions [batch, seq_len, d_model]".to_string(),
            ));
        }

        let embed_shape = embeddings.shape();
        let seq_len = embed_shape[1];

        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // Get positional encoding for the current sequence length
        let pos_encoding = self.get_encoding(seq_len)?;

        // Create a mutable copy of the embeddings
        let mut output = embeddings.clone();

        // Add positional encoding to each batch element
        for batch_idx in 0..embed_shape[0] {
            let mut batch_slice = output.slice_mut(ndarray::s![batch_idx, .., ..]);

            // Add positional encoding
            for pos in 0..seq_len {
                for dim in 0..self.d_model {
                    batch_slice[[pos, dim]] = batch_slice[[pos, dim]] + pos_encoding[[pos, dim]];
                }
            }
        }

        Ok(output)
    }

    fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Requested sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // Return a slice of the weights, with proper dimension type
        Ok(self
            .weights
            .slice(ndarray::s![0..seq_len, ..])
            .to_owned()
            .into_dyn())
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update weights using the stored gradients
        // This is a simplified implementation - in practice, an optimizer would be used
        let small_change = F::from(0.001).unwrap();
        let lr = learning_rate * small_change;

        for i in 0..self.max_len {
            for j in 0..self.d_model {
                self.weights[[i, j]] = self.weights[[i, j]] - lr * self.dweights[[i, j]];
            }
        }

        Ok(())
    }

    fn params(&self) -> Vec<Array<F, IxDyn>> {
        // Return the learnable weights
        vec![self.weights.clone()]
    }
}

/// Relative positional encoding
///
/// Implements a form of relative positional encoding that can capture
/// pairwise positional relationships efficiently.
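///
/// The internal table stores one embedding for every relative offset in
/// `-(max_len-1)..=(max_len-1)`, so an attention layer can look up the entry for a
/// query/key pair from their relative offset (see the private helper
/// `rel_pos_to_index` below).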
pub struct RelativePositionalEncoding<F: Float + Debug> {
    /// Maximum sequence length
    max_len: usize,
    /// Model embedding dimension
    d_model: usize,
    /// Relative position embeddings [2*max_len-1, d_model]
    weights: Array<F, IxDyn>,
    /// Gradient of weights
    dweights: Array<F, IxDyn>,
}

impl<F: Float + Debug + 'static> RelativePositionalEncoding<F> {
    /// Create a new relative positional encoding
    ///
    /// # Arguments
    ///
    /// * `max_len` - Maximum sequence length
    /// * `d_model` - Model embedding dimension
    ///
    /// # Returns
    ///
    /// * A new relative positional encoding
    pub fn new(max_len: usize, d_model: usize) -> Self {
        // For relative positions, we need 2*max_len-1 different embeddings
        // to represent all possible relative positions from -(max_len-1) to +(max_len-1)
        let rel_size = 2 * max_len - 1;

        // Initialize to a small constant value (a full implementation would use random initialization)
        let init_scale = F::from(0.02).unwrap();
        let weights = Array::<F, _>::from_elem((rel_size, d_model), init_scale);
        let dweights = Array::<F, _>::zeros((rel_size, d_model));

        Self {
            max_len,
            d_model,
            weights: weights.into_dyn(),
            dweights: dweights.into_dyn(),
        }
    }

    /// Get the relative position index
    ///
    /// Convert a relative position to the corresponding index in the weights matrix
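    ///
    /// For example, with `max_len = 4` the mapping is `-3 -> 0`, `0 -> 3`, `+3 -> 6`.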
    #[allow(dead_code)]
    fn rel_pos_to_index(&self, rel_pos: isize) -> usize {
        // Convert relative position to an index
        // rel_pos = -max_len+1 -> index = 0
        // rel_pos = 0 -> index = max_len-1
        // rel_pos = max_len-1 -> index = 2*max_len-2
        (rel_pos + self.max_len as isize - 1) as usize
    }
}

impl<F: Float + Debug + 'static> PositionalEncoding<F> for RelativePositionalEncoding<F> {
    fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if embeddings.ndim() < 2 {
            return Err(NeuralError::InferenceError(
                "Embeddings must have at least 2 dimensions [batch, seq_len, ...]".to_string(),
            ));
        }

        let embed_shape = embeddings.shape();
        let seq_len = embed_shape[1];

        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // For relative positions, we typically don't add them directly to the embeddings
        // but instead use them in the attention mechanism
        // For compatibility with the positional encoding interface, we'll just return the input
        Ok(embeddings.clone())
    }

    fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
        if seq_len > self.max_len {
            return Err(NeuralError::InferenceError(format!(
                "Requested sequence length ({}) exceeds maximum length ({})",
                seq_len, self.max_len
            )));
        }

        // For relative positional encoding, we'd need to compute an attention bias
        // based on all pairwise distances between positions.
        // For simplicity, we just return a zero matrix matching the interface.
        let encoding = Array::<F, _>::zeros((seq_len, self.d_model));
        Ok(encoding.into_dyn())
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update weights using the stored gradients
        // This is a simplified implementation - in practice, an optimizer would be used
        let small_change = F::from(0.001).unwrap();
        let lr = learning_rate * small_change;

        let rel_size = 2 * self.max_len - 1;
        for i in 0..rel_size {
            for j in 0..self.d_model {
                self.weights[[i, j]] = self.weights[[i, j]] - lr * self.dweights[[i, j]];
            }
        }

        Ok(())
    }

    fn params(&self) -> Vec<Array<F, IxDyn>> {
        // Return the learnable weights
        vec![self.weights.clone()]
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array3;

    #[test]
    fn test_sinusoidal_encoding_shape() {
        let max_len = 100;
        let d_model = 64;

        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(max_len, d_model).unwrap();

        // Get encoding for a specific sequence length
        let encoding = pos_enc.get_encoding(50).unwrap();

        // Check shape
        assert_eq!(encoding.shape(), &[50, d_model]);
    }

    #[test]
    fn test_sinusoidal_encoding_properties() {
        let max_len = 100;
        let d_model = 64;

        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(max_len, d_model).unwrap();

        // Get encoding for a specific sequence length
        let encoding = pos_enc.get_encoding(max_len).unwrap();

        // Check that different positions have different encodings
        let pos0 = encoding.slice(ndarray::s![0, ..]).to_owned();
        let pos1 = encoding.slice(ndarray::s![1, ..]).to_owned();

        // At least one element should be different
        let mut all_equal = true;
        for i in 0..d_model {
            if (pos0[i] - pos1[i]).abs() > 1e-10 {
                all_equal = false;
                break;
            }
        }

        assert!(
            !all_equal,
            "Positions 0 and 1 should have different encodings"
        );
    }

    #[test]
    fn test_positional_encoding_factory() {
        let max_len = 100;
        let d_model = 64;

        // Create different types of encodings
        let sinusoidal = PositionalEncodingFactory::create::<f64>(
            PositionalEncodingType::Sinusoidal,
            max_len,
            d_model,
        )
        .unwrap();

        let learned = PositionalEncodingFactory::create::<f64>(
            PositionalEncodingType::Learned,
            max_len,
            d_model,
        )
        .unwrap();

        let relative = PositionalEncodingFactory::create::<f64>(
            PositionalEncodingType::Relative,
            max_len,
            d_model,
        )
        .unwrap();

        // Create a test batch
        let batch_size = 2;
        let seq_len = 10;
        let embeddings = Array3::<f64>::zeros((batch_size, seq_len, d_model)).into_dyn();

        // Verify that all encodings can process the embeddings
        let _ = sinusoidal.forward(&embeddings).unwrap();
        let _ = learned.forward(&embeddings).unwrap();
        let _ = relative.forward(&embeddings).unwrap();
    }

    #[test]
    fn test_sinusoidal_encoding_addition() {
        let max_len = 100;
        let d_model = 64;

        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(max_len, d_model).unwrap();

        // Create a batch with non-zero values
        let batch_size = 2;
        let seq_len = 10;
        let embeddings = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 1.0).into_dyn();

        // Get encoding
        let encoding = pos_enc.get_encoding(seq_len).unwrap();

        // Apply positional encoding
        let output = pos_enc.forward(&embeddings).unwrap();

        // Verify that the embeddings were modified by the encoding
        for b in 0..batch_size {
            for s in 0..seq_len {
                for d in 0..d_model {
                    // output[b,s,d] should be embeddings[b,s,d] + encoding[s,d]
                    assert_relative_eq!(output[[b, s, d]], 1.0 + encoding[[s, d]], epsilon = 1e-10);
                }
            }
        }
    }
}