use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode};

use crate::error::{Result, TrustformerError};

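/// Configuration shared by all position encoding variants: model width,
/// maximum sequence length, the concrete encoding scheme, and the dropout
/// applied after the encoding is added.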
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PositionEncodingConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Maximum sequence length the encoding is built for.
    pub max_seq_len: usize,
    /// Which position encoding scheme to use.
    pub encoding_type: PositionEncodingType,
    /// Dropout probability applied after adding the encoding, in `[0, 1]`.
    pub dropout: f64,
}

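/// The supported position encoding schemes: fixed sinusoidal, learned
/// embeddings, bucketed relative bias, rotary (RoPE), and ALiBi.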
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum PositionEncodingType {
    /// Fixed sinusoidal encoding added to the input embeddings.
    Sinusoidal {
        /// Base of the frequency progression (typically 10000.0).
        base: f64,
    },
    /// Learned position embedding table of size `max_seq_len`.
    Learned,
    /// Bucketed relative position bias added to attention scores.
    Relative {
        /// Number of relative-position buckets.
        num_buckets: usize,
        /// Maximum relative distance represented by the buckets.
        max_distance: usize,
    },
    /// Rotary position embedding (RoPE).
    Rotary {
        /// Base of the rotation frequency progression.
        base: f64,
        /// Position scaling factor (1.0 means no scaling).
        scaling_factor: f64,
    },
    /// ALiBi: linear attention biases proportional to query/key distance.
    Alibi {
        /// Number of attention heads (one slope per head).
        n_heads: usize,
        /// Maximum sequence length covered by the bias.
        max_seq_len: usize,
    },
}

impl PositionEncodingConfig {
    /// Sinusoidal encoding with the conventional base of 10000.0 and no dropout.
    pub fn sinusoidal(d_model: usize, max_seq_len: usize) -> Self {
        Self {
            d_model,
            max_seq_len,
            encoding_type: PositionEncodingType::Sinusoidal { base: 10000.0 },
            dropout: 0.0,
        }
    }

    /// Learned position embeddings with no dropout.
    pub fn learned(d_model: usize, max_seq_len: usize) -> Self {
        Self {
            d_model,
            max_seq_len,
            encoding_type: PositionEncodingType::Learned,
            dropout: 0.0,
        }
    }

    /// Bucketed relative position bias. `max_seq_len` is unused by this
    /// scheme and is left at 0.
    pub fn relative(d_model: usize, num_buckets: usize, max_distance: usize) -> Self {
        Self {
            d_model,
            max_seq_len: 0, // not needed for relative encoding
            encoding_type: PositionEncodingType::Relative {
                num_buckets,
                max_distance,
            },
            dropout: 0.0,
        }
    }

    /// Rotary encoding (RoPE) with base 10000.0 and no position scaling.
    pub fn rotary(d_model: usize, max_seq_len: usize) -> Self {
        Self {
            d_model,
            max_seq_len,
            encoding_type: PositionEncodingType::Rotary {
                base: 10000.0,
                scaling_factor: 1.0,
            },
            dropout: 0.0,
        }
    }

    /// Rotary encoding with an explicit base and scaling factor, e.g. for
    /// context-length extension.
    pub fn rotary_scaled(
        d_model: usize,
        max_seq_len: usize,
        base: f64,
        scaling_factor: f64,
    ) -> Self {
        Self {
            d_model,
            max_seq_len,
            encoding_type: PositionEncodingType::Rotary {
                base,
                scaling_factor,
            },
            dropout: 0.0,
        }
    }

    /// ALiBi attention biases for `n_heads` heads.
    pub fn alibi(d_model: usize, n_heads: usize, max_seq_len: usize) -> Self {
        Self {
            d_model,
            max_seq_len,
            encoding_type: PositionEncodingType::Alibi {
                n_heads,
                max_seq_len,
            },
            dropout: 0.0,
        }
    }

    /// Sets the dropout probability applied after the encoding (builder style).
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Validates the configuration, returning an error for zero dimensions,
    /// out-of-range dropout, or parameters that are invalid for the selected
    /// encoding type.
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_model must be positive".to_string(),
            });
        }

        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("dropout must be in [0,1], got {}", self.dropout),
            });
        }

        match &self.encoding_type {
            PositionEncodingType::Sinusoidal { base } => {
                if *base <= 0.0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "base must be positive".to_string(),
                    });
                }
            }
            PositionEncodingType::Relative {
                num_buckets,
                max_distance,
            } => {
                if *num_buckets == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "num_buckets must be positive".to_string(),
                    });
                }
                if *max_distance == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "max_distance must be positive".to_string(),
                    });
                }
            }
            PositionEncodingType::Learned => {
                if self.max_seq_len == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "max_seq_len must be positive for learned encoding".to_string(),
                    });
                }
            }
            PositionEncodingType::Rotary {
                base,
                scaling_factor,
            } => {
                if *base <= 0.0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "RoPE base must be positive".to_string(),
                    });
                }
                if *scaling_factor <= 0.0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "RoPE scaling_factor must be positive".to_string(),
                    });
                }
                if self.max_seq_len == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "max_seq_len must be positive for RoPE".to_string(),
                    });
                }
                if !self.d_model.is_multiple_of(2) {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "d_model must be even for RoPE".to_string(),
                    });
                }
            }
            PositionEncodingType::Alibi {
                n_heads,
                max_seq_len,
            } => {
                if *n_heads == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "n_heads must be positive for ALiBi".to_string(),
                    });
                }
                if *max_seq_len == 0 {
                    return Err(TrustformerError::InvalidDimension {
                        expected: 1,
                        got: 0,
                        context: "max_seq_len must be positive for ALiBi".to_string(),
                    });
                }
            }
        }

        Ok(())
    }
}

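/// Fixed sinusoidal position encoding. The sin/cos table itself is expected
/// to be precomputed and provided as a graph tensor; this type only wires the
/// `x + PE` addition (and optional dropout) into an `EinsumGraph`.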
#[derive(Clone, Debug)]
pub struct SinusoidalPositionEncoding {
    pub config: PositionEncodingConfig,
}

impl SinusoidalPositionEncoding {
    /// Creates the encoding, rejecting configs whose `encoding_type` is not
    /// `Sinusoidal`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        match config.encoding_type {
            PositionEncodingType::Sinusoidal { .. } => Ok(Self { config }),
            _ => Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: "Expected Sinusoidal encoding type".to_string(),
            }),
        }
    }

    /// Appends `x + PE` (and optional dropout) to `graph`. The input `x` is
    /// assumed to already be registered as tensor 0; the positional table is
    /// added here as `sinusoidal_pe`.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let pe_tensor = graph.add_tensor("sinusoidal_pe");

        // x (tensor 0) + positional encoding
        let output_tensor = graph.add_tensor("x_with_pe");
        let add_node = EinsumNode::elem_binary("add", 0, pe_tensor, output_tensor);
        graph.add_node(add_node)?;

        if self.config.dropout > 0.0 {
            let dropout_tensor = graph.add_tensor("pe_dropout_output");
            let dropout_node = EinsumNode::elem_unary(
                format!("dropout_{}", self.config.dropout),
                output_tensor,
                dropout_tensor,
            );
            graph.add_node(dropout_node)?;
            Ok(vec![dropout_tensor])
        } else {
            Ok(vec![output_tensor])
        }
    }

    /// The sinusoidal frequency base (10000.0 by default).
    pub fn base(&self) -> f64 {
        match self.config.encoding_type {
            PositionEncodingType::Sinusoidal { base } => base,
            _ => 10000.0,
        }
    }
}

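/// Learned (trainable) position embeddings. The embedding table is expected
/// to be registered in the graph beforehand (tensor 1 in the tests); this
/// type wires the gather and the `x + PE` addition.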
#[derive(Clone, Debug)]
pub struct LearnedPositionEncoding {
    pub config: PositionEncodingConfig,
}

impl LearnedPositionEncoding {
    /// Creates the encoding, rejecting configs whose `encoding_type` is not
    /// `Learned`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        match config.encoding_type {
            PositionEncodingType::Learned => Ok(Self { config }),
            _ => Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: "Expected Learned encoding type".to_string(),
            }),
        }
    }

    /// Appends the embedding lookup and `x + PE` addition (plus optional
    /// dropout) to `graph`. Assumes `x` is tensor 0 and the position
    /// embedding table is tensor 1.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Gather rows of the embedding table (tensor 1) for each position.
        let pe_lookup = graph.add_tensor("pe_lookup");
        let lookup_node = EinsumNode::elem_unary("gather_pos_emb", 1, pe_lookup);
        graph.add_node(lookup_node)?;

        // x (tensor 0) + gathered position embeddings
        let output_tensor = graph.add_tensor("x_with_learned_pe");
        let add_node = EinsumNode::elem_binary("add", 0, pe_lookup, output_tensor);
        graph.add_node(add_node)?;

        if self.config.dropout > 0.0 {
            let dropout_tensor = graph.add_tensor("learned_pe_dropout_output");
            let dropout_node = EinsumNode::elem_unary(
                format!("dropout_{}", self.config.dropout),
                output_tensor,
                dropout_tensor,
            );
            graph.add_node(dropout_node)?;
            Ok(vec![dropout_tensor])
        } else {
            Ok(vec![output_tensor])
        }
    }

    /// Maximum number of positions the embedding table covers.
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }
}

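/// Bucketed relative position bias added to the attention scores rather than
/// to the token embeddings. The bias table is expected to be registered in
/// the graph before `build_bias_graph` is called.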
#[derive(Clone, Debug)]
pub struct RelativePositionEncoding {
    pub config: PositionEncodingConfig,
}

impl RelativePositionEncoding {
    /// Creates the encoding, rejecting configs whose `encoding_type` is not
    /// `Relative`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        match config.encoding_type {
            PositionEncodingType::Relative { .. } => Ok(Self { config }),
            _ => Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: "Expected Relative encoding type".to_string(),
            }),
        }
    }

    /// Appends the bias lookup and `scores + bias` addition to `graph`.
    /// Assumes the attention scores are tensor 0 and the bias table is tensor 1.
    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Gather the per-bucket biases (tensor 1) for each query/key offset.
        let bias_lookup = graph.add_tensor("rel_pos_bias_lookup");
        let lookup_node = EinsumNode::elem_unary("gather_rel_bias", 1, bias_lookup);
        graph.add_node(lookup_node)?;

        // attention scores (tensor 0) + relative position bias
        let output_tensor = graph.add_tensor("scores_with_rel_bias");
        let add_node = EinsumNode::elem_binary("add", 0, bias_lookup, output_tensor);
        graph.add_node(add_node)?;

        Ok(vec![output_tensor])
    }

    /// Number of relative-position buckets.
    pub fn num_buckets(&self) -> usize {
        match self.config.encoding_type {
            PositionEncodingType::Relative { num_buckets, .. } => num_buckets,
            _ => 0,
        }
    }

    /// Maximum relative distance covered by the buckets.
    pub fn max_distance(&self) -> usize {
        match self.config.encoding_type {
            PositionEncodingType::Relative { max_distance, .. } => max_distance,
            _ => 0,
        }
    }
}

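/// Rotary position embedding (RoPE). Even/odd feature halves are rotated by a
/// position-dependent angle using precomputed cos/sin caches, which are
/// expected to be registered in the graph (tensors 1 and 2 in the tests).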
#[derive(Clone, Debug)]
pub struct RotaryPositionEncoding {
    pub config: PositionEncodingConfig,
}

impl RotaryPositionEncoding {
    /// Creates the encoding, rejecting configs whose `encoding_type` is not
    /// `Rotary`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        match config.encoding_type {
            PositionEncodingType::Rotary { .. } => Ok(Self { config }),
            _ => Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: "Expected Rotary encoding type".to_string(),
            }),
        }
    }

    /// Appends the rotation `(x_even, x_odd) -> (x_even*cos - x_odd*sin,
    /// x_even*sin + x_odd*cos)` to `graph`. Assumes `x` is tensor 0 and the
    /// cos/sin caches are tensors 1 and 2.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Split the input (tensor 0) into even- and odd-indexed feature halves.
        let x_even = graph.add_tensor("rope_x_even");
        let x_odd = graph.add_tensor("rope_x_odd");
        let split_node = EinsumNode::elem_unary("split_even_odd", 0, x_even);
        graph.add_node(split_node)?;

        // x_even * cos (cos cache is tensor 1)
        let even_cos = graph.add_tensor("rope_even_cos");
        let even_cos_node = EinsumNode::elem_binary("mul", x_even, 1, even_cos);
        graph.add_node(even_cos_node)?;

        // x_odd * sin (sin cache is tensor 2)
        let odd_sin = graph.add_tensor("rope_odd_sin");
        let odd_sin_node = EinsumNode::elem_binary("mul", x_odd, 2, odd_sin);
        graph.add_node(odd_sin_node)?;

        // First rotated half: x_even * cos - x_odd * sin
        let rotated_0 = graph.add_tensor("rope_rotated_0");
        let sub_node = EinsumNode::elem_binary("sub", even_cos, odd_sin, rotated_0);
        graph.add_node(sub_node)?;

        // x_even * sin
        let even_sin = graph.add_tensor("rope_even_sin");
        let even_sin_node = EinsumNode::elem_binary("mul", x_even, 2, even_sin);
        graph.add_node(even_sin_node)?;

        // x_odd * cos
        let odd_cos = graph.add_tensor("rope_odd_cos");
        let odd_cos_node = EinsumNode::elem_binary("mul", x_odd, 1, odd_cos);
        graph.add_node(odd_cos_node)?;

        // Second rotated half: x_even * sin + x_odd * cos
        let rotated_1 = graph.add_tensor("rope_rotated_1");
        let add_node = EinsumNode::elem_binary("add", even_sin, odd_cos, rotated_1);
        graph.add_node(add_node)?;

        // Concatenate the two rotated halves back into the full feature dimension.
        let output_tensor = graph.add_tensor("rope_output");
        let concat_node = EinsumNode::elem_binary("concat", rotated_0, rotated_1, output_tensor);
        graph.add_node(concat_node)?;

        Ok(vec![output_tensor])
    }

    /// Rotation frequency base (10000.0 by default).
    pub fn base(&self) -> f64 {
        match self.config.encoding_type {
            PositionEncodingType::Rotary { base, .. } => base,
            _ => 10000.0,
        }
    }

    /// Position scaling factor (1.0 means no scaling).
    pub fn scaling_factor(&self) -> f64 {
        match self.config.encoding_type {
            PositionEncodingType::Rotary { scaling_factor, .. } => scaling_factor,
            _ => 1.0,
        }
    }
}

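/// ALiBi (Attention with Linear Biases): a per-head slope times the
/// query/key distance is subtracted from the attention scores instead of
/// modifying the embeddings. The slopes and distance matrix are expected to
/// be registered in the graph (tensors 1 and 2 in the tests).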
#[derive(Clone, Debug)]
pub struct AlibiPositionEncoding {
    pub config: PositionEncodingConfig,
}

impl AlibiPositionEncoding {
    /// Creates the encoding, rejecting configs whose `encoding_type` is not
    /// `Alibi`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        match config.encoding_type {
            PositionEncodingType::Alibi { .. } => Ok(Self { config }),
            _ => Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: "Expected Alibi encoding type".to_string(),
            }),
        }
    }

    /// Appends `scores - slopes * distance` to `graph`. Assumes the attention
    /// scores are tensor 0, the per-head slopes are tensor 1, and the
    /// distance matrix is tensor 2.
    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Broadcast the per-head slopes (tensor 1) over the score matrix.
        let slopes_expanded = graph.add_tensor("alibi_slopes_expanded");
        let expand_node = EinsumNode::elem_unary("expand_dims", 1, slopes_expanded);
        graph.add_node(expand_node)?;

        // slope * distance (distance matrix is tensor 2)
        let bias = graph.add_tensor("alibi_bias");
        let bias_node = EinsumNode::elem_binary("mul", slopes_expanded, 2, bias);
        graph.add_node(bias_node)?;

        // Negate so that larger distances are penalized.
        let neg_bias = graph.add_tensor("alibi_neg_bias");
        let neg_node = EinsumNode::elem_unary("neg", bias, neg_bias);
        graph.add_node(neg_node)?;

        // attention scores (tensor 0) + negative bias
        let output_tensor = graph.add_tensor("scores_with_alibi");
        let add_node = EinsumNode::elem_binary("add", 0, neg_bias, output_tensor);
        graph.add_node(add_node)?;

        Ok(vec![output_tensor])
    }

    /// Number of attention heads (one slope per head).
    pub fn n_heads(&self) -> usize {
        match self.config.encoding_type {
            PositionEncodingType::Alibi { n_heads, .. } => n_heads,
            _ => 0,
        }
    }

    /// Per-head slopes `2^(-8i/n)` for `i = 1..=n`, a decreasing geometric
    /// sequence.
    pub fn compute_slopes(&self) -> Vec<f64> {
        let n = self.n_heads();
        (1..=n)
            .map(|i| 2_f64.powf(-8.0 * (i as f64) / (n as f64)))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sinusoidal_config_creation() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Sinusoidal { base: 10000.0 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_learned_config_creation() {
        let config = PositionEncodingConfig::learned(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Learned
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_relative_config_creation() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        assert_eq!(config.d_model, 512);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Relative {
                num_buckets: 32,
                max_distance: 128
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_with_dropout() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(0.1);
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_sinusoidal_encoding_creation() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        let encoding = SinusoidalPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.config.d_model, 512);
        assert_eq!(encoding.base(), 10000.0);
    }

    #[test]
    fn test_learned_encoding_creation() {
        let config = PositionEncodingConfig::learned(512, 2048);
        let encoding = LearnedPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.max_seq_len(), 2048);
    }

    #[test]
    fn test_relative_encoding_creation() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        let encoding = RelativePositionEncoding::new(config).unwrap();
        assert_eq!(encoding.num_buckets(), 32);
        assert_eq!(encoding.max_distance(), 128);
    }

    #[test]
    fn test_sinusoidal_graph_building() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048);
        let encoding = SinusoidalPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_learned_graph_building() {
        let config = PositionEncodingConfig::learned(512, 2048);
        let encoding = LearnedPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("position_embeddings");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_relative_bias_graph_building() {
        let config = PositionEncodingConfig::relative(512, 32, 128);
        let encoding = RelativePositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("attention_scores");
        graph.add_tensor("relative_position_bias");
        graph.add_tensor("relative_position_indices");

        let outputs = encoding.build_bias_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_invalid_config_zero_dimension() {
        let mut config = PositionEncodingConfig::sinusoidal(0, 2048);
        assert!(config.validate().is_err());

        config = PositionEncodingConfig::learned(512, 0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_dropout() {
        let config = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(1.5);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_wrong_encoding_type() {
        let config = PositionEncodingConfig::learned(512, 2048);
        let result = SinusoidalPositionEncoding::new(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_rotary_config_creation() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Rotary {
                base: 10000.0,
                scaling_factor: 1.0
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_rotary_scaled_config() {
        let config = PositionEncodingConfig::rotary_scaled(512, 4096, 10000.0, 2.0);
        assert_eq!(config.max_seq_len, 4096);
        match config.encoding_type {
            PositionEncodingType::Rotary {
                base,
                scaling_factor,
            } => {
                assert!((base - 10000.0).abs() < 1e-10);
                assert!((scaling_factor - 2.0).abs() < 1e-10);
            }
            _ => panic!("Expected Rotary encoding type"),
        }
    }

    #[test]
    fn test_rotary_encoding_creation() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        let encoding = RotaryPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.config.d_model, 512);
        assert_eq!(encoding.base(), 10000.0);
        assert_eq!(encoding.scaling_factor(), 1.0);
    }

    #[test]
    fn test_rotary_graph_building() {
        let config = PositionEncodingConfig::rotary(512, 2048);
        let encoding = RotaryPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("cos_cached");
        graph.add_tensor("sin_cached");

        let outputs = encoding.build_encoding_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_rotary_requires_even_d_model() {
        // an odd d_model cannot be split into rotation pairs
        let config = PositionEncodingConfig::rotary(513, 2048);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_alibi_config_creation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.max_seq_len, 2048);
        assert!(matches!(
            config.encoding_type,
            PositionEncodingType::Alibi {
                n_heads: 8,
                max_seq_len: 2048
            }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_alibi_encoding_creation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();
        assert_eq!(encoding.n_heads(), 8);
    }

    #[test]
    fn test_alibi_slopes_computation() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();
        let slopes = encoding.compute_slopes();

        assert_eq!(slopes.len(), 8);
        for i in 1..slopes.len() {
            assert!(slopes[i] < slopes[i - 1]);
        }
        assert!(slopes[0] < 1.0);
        assert!(slopes[0] > 0.0);
    }

    #[test]
    fn test_alibi_graph_building() {
        let config = PositionEncodingConfig::alibi(512, 8, 2048);
        let encoding = AlibiPositionEncoding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("attention_scores");
        graph.add_tensor("alibi_slopes");
        graph.add_tensor("distance_matrix");

        let outputs = encoding.build_bias_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_alibi_invalid_zero_heads() {
        let config = PositionEncodingConfig::alibi(512, 0, 2048);
        assert!(config.validate().is_err());
    }
}