// tensorlogic_trustformers/position.rs

1//! Position encoding implementations for transformer models.
2//!
3//! This module provides various position encoding strategies that can be
4//! compiled to TensorLogic einsum graphs:
5//!
6//! 1. **Sinusoidal Encoding**: Fixed position encodings using sin/cos functions
7//!    (from "Attention Is All You Need")
8//! 2. **Learned Encoding**: Trainable position embeddings
9//! 3. **Relative Position**: Relative position biases for attention
10//!
11//! ## Sinusoidal Position Encoding
12//!
13//! ```text
14//! PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
15//! PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
16//! ```
17//!
18//! Where:
19//! - `pos` = position in sequence
20//! - `i` = dimension index
21//! - `d_model` = model dimension
22
23use serde::{Deserialize, Serialize};
24use tensorlogic_ir::{EinsumGraph, EinsumNode};
25
26use crate::error::{Result, TrustformerError};
27
/// Configuration for position encodings.
///
/// Shared by every encoding implementation in this module; the active
/// strategy is selected via [`PositionEncodingType`]. Use the constructor
/// helpers (`sinusoidal`, `learned`, `relative`, `rotary`, `alibi`) and call
/// [`PositionEncodingConfig::validate`] before building graphs.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PositionEncodingConfig {
    /// Model dimension
    pub d_model: usize,
    /// Maximum sequence length
    pub max_seq_len: usize,
    /// Encoding type
    pub encoding_type: PositionEncodingType,
    /// Dropout probability
    pub dropout: f64,
}
40
/// Type of position encoding.
///
/// Each variant carries the hyperparameters specific to that strategy;
/// dimensions shared by all strategies live in [`PositionEncodingConfig`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum PositionEncodingType {
    /// Sinusoidal (fixed) position encoding
    Sinusoidal {
        /// Base for frequency computation (default: 10000.0)
        base: f64,
    },
    /// Learned position embedding
    Learned,
    /// Relative position encoding
    Relative {
        /// Number of relative position buckets
        num_buckets: usize,
        /// Maximum relative distance
        max_distance: usize,
    },
    /// Rotary Position Embedding (RoPE) - used in LLaMA, GPT-NeoX
    Rotary {
        /// Base for frequency computation (default: 10000.0)
        base: f64,
        /// Scaling factor for long sequences (default: 1.0)
        scaling_factor: f64,
    },
    /// ALiBi (Attention with Linear Biases) - used in BLOOM
    Alibi {
        /// Number of attention heads
        n_heads: usize,
        /// Maximum sequence length
        max_seq_len: usize,
    },
}
73
74impl PositionEncodingConfig {
75    /// Create a new sinusoidal position encoding configuration
76    pub fn sinusoidal(d_model: usize, max_seq_len: usize) -> Self {
77        Self {
78            d_model,
79            max_seq_len,
80            encoding_type: PositionEncodingType::Sinusoidal { base: 10000.0 },
81            dropout: 0.0,
82        }
83    }
84
85    /// Create a new learned position encoding configuration
86    pub fn learned(d_model: usize, max_seq_len: usize) -> Self {
87        Self {
88            d_model,
89            max_seq_len,
90            encoding_type: PositionEncodingType::Learned,
91            dropout: 0.0,
92        }
93    }
94
95    /// Create a new relative position encoding configuration
96    pub fn relative(d_model: usize, num_buckets: usize, max_distance: usize) -> Self {
97        Self {
98            d_model,
99            max_seq_len: 0, // Not used for relative encoding
100            encoding_type: PositionEncodingType::Relative {
101                num_buckets,
102                max_distance,
103            },
104            dropout: 0.0,
105        }
106    }
107
108    /// Create a new rotary position encoding (RoPE) configuration
109    pub fn rotary(d_model: usize, max_seq_len: usize) -> Self {
110        Self {
111            d_model,
112            max_seq_len,
113            encoding_type: PositionEncodingType::Rotary {
114                base: 10000.0,
115                scaling_factor: 1.0,
116            },
117            dropout: 0.0,
118        }
119    }
120
121    /// Create RoPE with custom base and scaling
122    pub fn rotary_scaled(
123        d_model: usize,
124        max_seq_len: usize,
125        base: f64,
126        scaling_factor: f64,
127    ) -> Self {
128        Self {
129            d_model,
130            max_seq_len,
131            encoding_type: PositionEncodingType::Rotary {
132                base,
133                scaling_factor,
134            },
135            dropout: 0.0,
136        }
137    }
138
139    /// Create a new ALiBi position encoding configuration
140    pub fn alibi(d_model: usize, n_heads: usize, max_seq_len: usize) -> Self {
141        Self {
142            d_model,
143            max_seq_len,
144            encoding_type: PositionEncodingType::Alibi {
145                n_heads,
146                max_seq_len,
147            },
148            dropout: 0.0,
149        }
150    }
151
152    /// Set dropout probability
153    pub fn with_dropout(mut self, dropout: f64) -> Self {
154        self.dropout = dropout;
155        self
156    }
157
158    /// Validate configuration
159    pub fn validate(&self) -> Result<()> {
160        if self.d_model == 0 {
161            return Err(TrustformerError::InvalidDimension {
162                expected: 1,
163                got: 0,
164                context: "d_model must be positive".to_string(),
165            });
166        }
167
168        if !(0.0..=1.0).contains(&self.dropout) {
169            return Err(TrustformerError::InvalidDimension {
170                expected: 1,
171                got: 0,
172                context: format!("dropout must be in [0,1], got {}", self.dropout),
173            });
174        }
175
176        match &self.encoding_type {
177            PositionEncodingType::Sinusoidal { base } => {
178                if *base <= 0.0 {
179                    return Err(TrustformerError::InvalidDimension {
180                        expected: 1,
181                        got: 0,
182                        context: "base must be positive".to_string(),
183                    });
184                }
185            }
186            PositionEncodingType::Relative {
187                num_buckets,
188                max_distance,
189            } => {
190                if *num_buckets == 0 {
191                    return Err(TrustformerError::InvalidDimension {
192                        expected: 1,
193                        got: 0,
194                        context: "num_buckets must be positive".to_string(),
195                    });
196                }
197                if *max_distance == 0 {
198                    return Err(TrustformerError::InvalidDimension {
199                        expected: 1,
200                        got: 0,
201                        context: "max_distance must be positive".to_string(),
202                    });
203                }
204            }
205            PositionEncodingType::Learned => {
206                if self.max_seq_len == 0 {
207                    return Err(TrustformerError::InvalidDimension {
208                        expected: 1,
209                        got: 0,
210                        context: "max_seq_len must be positive for learned encoding".to_string(),
211                    });
212                }
213            }
214            PositionEncodingType::Rotary {
215                base,
216                scaling_factor,
217            } => {
218                if *base <= 0.0 {
219                    return Err(TrustformerError::InvalidDimension {
220                        expected: 1,
221                        got: 0,
222                        context: "RoPE base must be positive".to_string(),
223                    });
224                }
225                if *scaling_factor <= 0.0 {
226                    return Err(TrustformerError::InvalidDimension {
227                        expected: 1,
228                        got: 0,
229                        context: "RoPE scaling_factor must be positive".to_string(),
230                    });
231                }
232                if self.max_seq_len == 0 {
233                    return Err(TrustformerError::InvalidDimension {
234                        expected: 1,
235                        got: 0,
236                        context: "max_seq_len must be positive for RoPE".to_string(),
237                    });
238                }
239                if !self.d_model.is_multiple_of(2) {
240                    return Err(TrustformerError::InvalidDimension {
241                        expected: 1,
242                        got: 0,
243                        context: "d_model must be even for RoPE".to_string(),
244                    });
245                }
246            }
247            PositionEncodingType::Alibi {
248                n_heads,
249                max_seq_len,
250            } => {
251                if *n_heads == 0 {
252                    return Err(TrustformerError::InvalidDimension {
253                        expected: 1,
254                        got: 0,
255                        context: "n_heads must be positive for ALiBi".to_string(),
256                    });
257                }
258                if *max_seq_len == 0 {
259                    return Err(TrustformerError::InvalidDimension {
260                        expected: 1,
261                        got: 0,
262                        context: "max_seq_len must be positive for ALiBi".to_string(),
263                    });
264                }
265            }
266        }
267
268        Ok(())
269    }
270}
271
/// Sinusoidal position encoding.
///
/// Fixed (non-trainable) encoding from "Attention Is All You Need"; see the
/// module-level docs for the PE(pos, i) formulas.
#[derive(Clone, Debug)]
pub struct SinusoidalPositionEncoding {
    /// Configuration
    pub config: PositionEncodingConfig,
}
278
279impl SinusoidalPositionEncoding {
280    /// Create a new sinusoidal position encoding
281    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
282        config.validate()?;
283        match config.encoding_type {
284            PositionEncodingType::Sinusoidal { .. } => Ok(Self { config }),
285            _ => Err(TrustformerError::InvalidDimension {
286                expected: 0,
287                got: 1,
288                context: "Expected Sinusoidal encoding type".to_string(),
289            }),
290        }
291    }
292
293    /// Build einsum graph for sinusoidal position encoding
294    ///
295    /// Input tensors:
296    /// - 0: x (input) [batch, seq_len, d_model]
297    /// - 1: position_ids [batch, seq_len] (optional, will use 0..seq_len if not provided)
298    ///
299    /// Output tensors:
300    /// - output: [batch, seq_len, d_model] (x + position_encoding)
301    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
302        // The sinusoidal encoding is computed as:
303        // PE(pos, 2i) = sin(pos / base^(2i/d_model))
304        // PE(pos, 2i+1) = cos(pos / base^(2i/d_model))
305
306        // For einsum representation, we add a pre-computed tensor
307        let pe_tensor = graph.add_tensor("sinusoidal_pe");
308
309        // Add position encoding to input
310        // einsum("bsd,sd->bsd", x, pe) (broadcast addition)
311        let output_tensor = graph.add_tensor("x_with_pe");
312        let add_node = EinsumNode::elem_binary("add", 0, pe_tensor, output_tensor);
313        graph.add_node(add_node)?;
314
315        // Apply dropout if configured
316        if self.config.dropout > 0.0 {
317            let dropout_tensor = graph.add_tensor("pe_dropout_output");
318            let dropout_node = EinsumNode::elem_unary(
319                format!("dropout_{}", self.config.dropout),
320                output_tensor,
321                dropout_tensor,
322            );
323            graph.add_node(dropout_node)?;
324            Ok(vec![dropout_tensor])
325        } else {
326            Ok(vec![output_tensor])
327        }
328    }
329
330    /// Get the base frequency for encoding
331    pub fn base(&self) -> f64 {
332        match self.config.encoding_type {
333            PositionEncodingType::Sinusoidal { base } => base,
334            _ => 10000.0,
335        }
336    }
337}
338
/// Learned position encoding.
///
/// Trainable position embeddings looked up by position id; requires a
/// positive `max_seq_len` in the configuration.
#[derive(Clone, Debug)]
pub struct LearnedPositionEncoding {
    /// Configuration
    pub config: PositionEncodingConfig,
}
345
346impl LearnedPositionEncoding {
347    /// Create a new learned position encoding
348    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
349        config.validate()?;
350        match config.encoding_type {
351            PositionEncodingType::Learned => Ok(Self { config }),
352            _ => Err(TrustformerError::InvalidDimension {
353                expected: 0,
354                got: 1,
355                context: "Expected Learned encoding type".to_string(),
356            }),
357        }
358    }
359
360    /// Build einsum graph for learned position encoding
361    ///
362    /// Input tensors:
363    /// - 0: x (input) [batch, seq_len, d_model]
364    /// - 1: position_embeddings [max_seq_len, d_model] (learned parameter)
365    /// - 2: position_ids [batch, seq_len] (optional)
366    ///
367    /// Output tensors:
368    /// - output: [batch, seq_len, d_model] (x + position_embedding)
369    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
370        // Lookup position embeddings
371        // For simplicity, we use direct indexing which maps to gather operation
372        let pe_lookup = graph.add_tensor("pe_lookup");
373        let lookup_node = EinsumNode::elem_unary("gather_pos_emb", 1, pe_lookup);
374        graph.add_node(lookup_node)?;
375
376        // Add position encoding to input
377        let output_tensor = graph.add_tensor("x_with_learned_pe");
378        let add_node = EinsumNode::elem_binary("add", 0, pe_lookup, output_tensor);
379        graph.add_node(add_node)?;
380
381        // Apply dropout if configured
382        if self.config.dropout > 0.0 {
383            let dropout_tensor = graph.add_tensor("learned_pe_dropout_output");
384            let dropout_node = EinsumNode::elem_unary(
385                format!("dropout_{}", self.config.dropout),
386                output_tensor,
387                dropout_tensor,
388            );
389            graph.add_node(dropout_node)?;
390            Ok(vec![dropout_tensor])
391        } else {
392            Ok(vec![output_tensor])
393        }
394    }
395
396    /// Get maximum sequence length
397    pub fn max_seq_len(&self) -> usize {
398        self.config.max_seq_len
399    }
400}
401
/// Relative position encoding.
///
/// Adds learned per-head biases to attention scores based on bucketed
/// relative distances (T5-style), rather than modifying the input embeddings.
#[derive(Clone, Debug)]
pub struct RelativePositionEncoding {
    /// Configuration
    pub config: PositionEncodingConfig,
}
408
409impl RelativePositionEncoding {
410    /// Create a new relative position encoding
411    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
412        config.validate()?;
413        match config.encoding_type {
414            PositionEncodingType::Relative { .. } => Ok(Self { config }),
415            _ => Err(TrustformerError::InvalidDimension {
416                expected: 0,
417                got: 1,
418                context: "Expected Relative encoding type".to_string(),
419            }),
420        }
421    }
422
423    /// Build einsum graph for relative position bias
424    ///
425    /// Input tensors:
426    /// - 0: attention_scores [batch, n_heads, seq_len, seq_len]
427    /// - 1: relative_position_bias [n_heads, num_buckets]
428    /// - 2: relative_position_indices [seq_len, seq_len] (bucket indices)
429    ///
430    /// Output tensors:
431    /// - output: [batch, n_heads, seq_len, seq_len] (scores + bias)
432    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
433        // Lookup relative position bias based on indices
434        let bias_lookup = graph.add_tensor("rel_pos_bias_lookup");
435        let lookup_node = EinsumNode::elem_unary("gather_rel_bias", 1, bias_lookup);
436        graph.add_node(lookup_node)?;
437
438        // Add bias to attention scores
439        // einsum("bhqk,hqk->bhqk", scores, bias) (broadcast addition)
440        let output_tensor = graph.add_tensor("scores_with_rel_bias");
441        let add_node = EinsumNode::elem_binary("add", 0, bias_lookup, output_tensor);
442        graph.add_node(add_node)?;
443
444        Ok(vec![output_tensor])
445    }
446
447    /// Get number of relative position buckets
448    pub fn num_buckets(&self) -> usize {
449        match self.config.encoding_type {
450            PositionEncodingType::Relative { num_buckets, .. } => num_buckets,
451            _ => 0,
452        }
453    }
454
455    /// Get maximum relative distance
456    pub fn max_distance(&self) -> usize {
457        match self.config.encoding_type {
458            PositionEncodingType::Relative { max_distance, .. } => max_distance,
459            _ => 0,
460        }
461    }
462}
463
/// Rotary Position Embedding (RoPE)
///
/// Used in models like LLaMA, GPT-NeoX, PaLM. RoPE encodes position by rotating
/// query and key vectors in the complex plane, providing natural relative position
/// information without adding extra parameters.
///
/// Reference: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
/// <https://arxiv.org/abs/2104.09864>
#[derive(Clone, Debug)]
pub struct RotaryPositionEncoding {
    /// Configuration
    pub config: PositionEncodingConfig,
}
477
478impl RotaryPositionEncoding {
479    /// Create a new rotary position encoding
480    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
481        config.validate()?;
482        match config.encoding_type {
483            PositionEncodingType::Rotary { .. } => Ok(Self { config }),
484            _ => Err(TrustformerError::InvalidDimension {
485                expected: 0,
486                got: 1,
487                context: "Expected Rotary encoding type".to_string(),
488            }),
489        }
490    }
491
492    /// Build einsum graph for RoPE
493    ///
494    /// RoPE applies rotation to query and key vectors in attention:
495    /// - Splits d_model into pairs of dimensions
496    /// - Rotates each pair by position-dependent angle
497    /// - Preserves relative position information
498    ///
499    /// Input tensors:
500    /// - 0: x (input) [batch, seq_len, d_model]
501    /// - 1: cos_cached [max_seq_len, d_model/2] (precomputed cosines)
502    /// - 2: sin_cached [max_seq_len, d_model/2] (precomputed sines)
503    ///
504    /// Output tensors:
505    /// - output: [batch, seq_len, d_model] (rotated embeddings)
506    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
507        // RoPE rotation formula:
508        // x_rot = [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos]
509
510        // Split input into even and odd indices
511        let x_even = graph.add_tensor("rope_x_even");
512        let x_odd = graph.add_tensor("rope_x_odd");
513        let split_node = EinsumNode::elem_unary("split_even_odd", 0, x_even);
514        graph.add_node(split_node)?;
515
516        // Apply rotation using cached cos/sin values
517        // x_even * cos
518        let even_cos = graph.add_tensor("rope_even_cos");
519        let even_cos_node = EinsumNode::elem_binary("mul", x_even, 1, even_cos);
520        graph.add_node(even_cos_node)?;
521
522        // x_odd * sin
523        let odd_sin = graph.add_tensor("rope_odd_sin");
524        let odd_sin_node = EinsumNode::elem_binary("mul", x_odd, 2, odd_sin);
525        graph.add_node(odd_sin_node)?;
526
527        // First half: x_even * cos - x_odd * sin
528        let rotated_0 = graph.add_tensor("rope_rotated_0");
529        let sub_node = EinsumNode::elem_binary("sub", even_cos, odd_sin, rotated_0);
530        graph.add_node(sub_node)?;
531
532        // x_even * sin
533        let even_sin = graph.add_tensor("rope_even_sin");
534        let even_sin_node = EinsumNode::elem_binary("mul", x_even, 2, even_sin);
535        graph.add_node(even_sin_node)?;
536
537        // x_odd * cos
538        let odd_cos = graph.add_tensor("rope_odd_cos");
539        let odd_cos_node = EinsumNode::elem_binary("mul", x_odd, 1, odd_cos);
540        graph.add_node(odd_cos_node)?;
541
542        // Second half: x_even * sin + x_odd * cos
543        let rotated_1 = graph.add_tensor("rope_rotated_1");
544        let add_node = EinsumNode::elem_binary("add", even_sin, odd_cos, rotated_1);
545        graph.add_node(add_node)?;
546
547        // Concatenate rotated halves
548        let output_tensor = graph.add_tensor("rope_output");
549        let concat_node = EinsumNode::elem_binary("concat", rotated_0, rotated_1, output_tensor);
550        graph.add_node(concat_node)?;
551
552        Ok(vec![output_tensor])
553    }
554
555    /// Get the base frequency for RoPE
556    pub fn base(&self) -> f64 {
557        match self.config.encoding_type {
558            PositionEncodingType::Rotary { base, .. } => base,
559            _ => 10000.0,
560        }
561    }
562
563    /// Get the scaling factor for long sequences
564    pub fn scaling_factor(&self) -> f64 {
565        match self.config.encoding_type {
566            PositionEncodingType::Rotary { scaling_factor, .. } => scaling_factor,
567            _ => 1.0,
568        }
569    }
570}
571
/// ALiBi (Attention with Linear Biases)
///
/// Used in models like BLOOM. Instead of adding position embeddings to inputs,
/// ALiBi adds a bias to attention scores that linearly penalizes distance.
/// This allows extrapolation to longer sequences than seen during training.
///
/// Reference: "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
/// <https://arxiv.org/abs/2108.12409>
#[derive(Clone, Debug)]
pub struct AlibiPositionEncoding {
    /// Configuration
    pub config: PositionEncodingConfig,
}
585
586impl AlibiPositionEncoding {
587    /// Create a new ALiBi position encoding
588    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
589        config.validate()?;
590        match config.encoding_type {
591            PositionEncodingType::Alibi { .. } => Ok(Self { config }),
592            _ => Err(TrustformerError::InvalidDimension {
593                expected: 0,
594                got: 1,
595                context: "Expected Alibi encoding type".to_string(),
596            }),
597        }
598    }
599
600    /// Build einsum graph for ALiBi bias
601    ///
602    /// ALiBi adds linear biases to attention scores based on query-key distance:
603    /// `bias(i, j) = -m * |i - j|`
604    /// where m is a head-specific slope
605    ///
606    /// Input tensors:
607    /// - 0: attention_scores `[batch, n_heads, seq_len, seq_len]`
608    /// - 1: alibi_slopes `[n_heads]` (precomputed slopes, one per head)
609    /// - 2: distance_matrix `[seq_len, seq_len]` (`|i - j|` for all positions)
610    ///
611    /// Output tensors:
612    /// - output: `[batch, n_heads, seq_len, seq_len]` (scores with ALiBi bias)
613    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
614        // Compute -m * |i - j| for each head
615        // slopes: [n_heads, 1, 1]
616        // distance_matrix: [seq_len, seq_len]
617        // bias = -slopes * distance_matrix: [n_heads, seq_len, seq_len]
618
619        let slopes_expanded = graph.add_tensor("alibi_slopes_expanded");
620        let expand_node = EinsumNode::elem_unary("expand_dims", 1, slopes_expanded);
621        graph.add_node(expand_node)?;
622
623        let bias = graph.add_tensor("alibi_bias");
624        let bias_node = EinsumNode::elem_binary("mul", slopes_expanded, 2, bias);
625        graph.add_node(bias_node)?;
626
627        let neg_bias = graph.add_tensor("alibi_neg_bias");
628        let neg_node = EinsumNode::elem_unary("neg", bias, neg_bias);
629        graph.add_node(neg_node)?;
630
631        // Add bias to attention scores
632        // scores: [batch, n_heads, seq_len, seq_len]
633        // bias: [n_heads, seq_len, seq_len] (broadcasts over batch)
634        let output_tensor = graph.add_tensor("scores_with_alibi");
635        let add_node = EinsumNode::elem_binary("add", 0, neg_bias, output_tensor);
636        graph.add_node(add_node)?;
637
638        Ok(vec![output_tensor])
639    }
640
641    /// Get the number of attention heads
642    pub fn n_heads(&self) -> usize {
643        match self.config.encoding_type {
644            PositionEncodingType::Alibi { n_heads, .. } => n_heads,
645            _ => 0,
646        }
647    }
648
649    /// Compute ALiBi slopes for each attention head
650    ///
651    /// Slopes are computed as: m_i = 2^(-8i/n) for i in 1..n_heads
652    /// This gives different rates of distance penalty per head
653    pub fn compute_slopes(&self) -> Vec<f64> {
654        let n = self.n_heads();
655        (1..=n)
656            .map(|i| 2_f64.powf(-8.0 * (i as f64) / (n as f64)))
657            .collect()
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    // --- Configuration constructors -------------------------------------

    #[test]
    fn test_sinusoidal_config_creation() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        // NOTE(review): float literals in patterns are deprecated
        // (illegal_floating_point_literal_pattern); consider a match guard.
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Sinusoidal { base: 10000.0 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_learned_config_creation() {
        let config = PositionEncodingConfig::learned(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Learned
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_relative_config_creation() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        assert_eq!(config.d_model, 512);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Relative {
                num_buckets: 32,
                max_distance: 128
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_with_dropout() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(0.1);
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    // --- Encoding constructors (type checking) --------------------------

    #[test]
    fn test_sinusoidal_encoding_creation() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        let encoding = SinusoidalPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.config.d_model, 512);
        assert_eq!(encoding.base(), 10000.0);
    }

    #[test]
    fn test_learned_encoding_creation() {
        let config = PositionEncodingConfig::learned(512, 2048);
        let encoding = LearnedPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.max_seq_len(), 2048);
    }

    #[test]
    fn test_relative_encoding_creation() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        let encoding = RelativePositionEncoding::new(config).unwrap();
        assert_eq!(encoding.num_buckets(), 32);
        assert_eq!(encoding.max_distance(), 128);
    }

    // --- Graph building (smoke tests: non-empty graph, one output) ------

    #[test]
    fn test_sinusoidal_graph_building() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        let encoding = SinusoidalPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_learned_graph_building() {
        let config = PositionEncodingConfig::learned(512, 2048);
        let encoding = LearnedPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("position_embeddings");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_relative_bias_graph_building() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        let encoding = RelativePositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("attention_scores");
        graph.add_tensor("relative_position_bias");
        graph.add_tensor("relative_position_indices");

        let outputs = encoding.build_bias_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    // --- Validation failures --------------------------------------------

    #[test]
    fn test_invalid_config_zero_dimension() {
        let mut config = PositionEncodingConfig::sinusoidal(0, 2048);
        assert!(config.validate().is_err());

        config = PositionEncodingConfig::learned(512, 0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_dropout() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(1.5);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_wrong_encoding_type() {
        // Constructors reject a config whose encoding type does not match.
        let config = PositionEncodingConfig::learned(512, 2048);
        let result = SinusoidalPositionEncoding::new(config);
        assert!(result.is_err());
    }

    // --- RoPE -----------------------------------------------------------

    #[test]
    fn test_rotary_config_creation() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        // NOTE(review): float literals in patterns are deprecated
        // (illegal_floating_point_literal_pattern); consider a match guard.
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Rotary {
                base: 10000.0,
                scaling_factor: 1.0
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_rotary_scaled_config() {
        let config = PositionEncodingConfig::rotary_scaled(512, 4096, 10000.0, 2.0);
        assert_eq!(config.max_seq_len, 4096);
        match config.encoding_type {
            PositionEncodingType::Rotary {
                base,
                scaling_factor,
            } => {
                assert!((base - 10000.0).abs() < 1e-10);
                assert!((scaling_factor - 2.0).abs() < 1e-10);
            }
            _ => panic!("Expected Rotary encoding type"),
        }
    }

    #[test]
    fn test_rotary_encoding_creation() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        let encoding = RotaryPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.config.d_model, 512);
        assert_eq!(encoding.base(), 10000.0);
        assert_eq!(encoding.scaling_factor(), 1.0);
    }

    #[test]
    fn test_rotary_graph_building() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        let encoding = RotaryPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("cos_cached");
        graph.add_tensor("sin_cached");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_rotary_requires_even_d_model() {
        let config = PositionEncodingConfig::rotary(513, 2048); // Odd d_model
        assert!(config.validate().is_err());
    }

    // --- ALiBi ----------------------------------------------------------

    #[test]
    fn test_alibi_config_creation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Alibi {
                n_heads: 8,
                max_seq_len: 2048
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_alibi_encoding_creation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.n_heads(), 8);
    }

    #[test]
    fn test_alibi_slopes_computation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();
        let slopes = encoding.compute_slopes();

        assert_eq!(slopes.len(), 8);
        // Slopes should be monotonically decreasing
        for i in 1..slopes.len() {
            assert!(slopes[i] < slopes[i - 1]);
        }
        // The largest slope is 2^(-8/n) = 0.5 for n = 8, so it lies in (0, 1)
        assert!(slopes[0] < 1.0);
        assert!(slopes[0] > 0.0);
    }

    #[test]
    fn test_alibi_graph_building() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("attention_scores");
        graph.add_tensor("alibi_slopes");
        graph.add_tensor("distance_matrix");

        let outputs = encoding.build_bias_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_alibi_invalid_zero_heads() {
        let config = PositionEncodingConfig::alibi(512, 0, 2048);
        assert!(config.validate().is_err());
    }
}