1use crate::error::{NeuralError, Result};
7use ndarray::{Array, IxDyn};
8use num_traits::Float;
9use std::fmt::Debug;
10
/// Selects which positional-encoding implementation
/// [`PositionalEncodingFactory::create`] should build.
///
/// The enum is a plain, fieldless tag, so it is `Copy` and can be compared,
/// hashed and used as a key in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PositionalEncodingType {
    /// Fixed table built from sine and cosine functions.
    Sinusoidal,
    /// Trainable table with one row per absolute position.
    Learned,
    /// Trainable table with one row per relative offset.
    Relative,
}
21
22pub struct PositionalEncodingFactory;
24
25impl PositionalEncodingFactory {
26 pub fn create<F: Float + Debug + 'static>(
38 encoding_type: PositionalEncodingType,
39 max_len: usize,
40 d_model: usize,
41 ) -> Result<Box<dyn PositionalEncoding<F>>> {
42 match encoding_type {
43 PositionalEncodingType::Sinusoidal => Ok(Box::new(SinusoidalPositionalEncoding::new(
44 max_len, d_model,
45 )?)),
46 PositionalEncodingType::Learned => {
47 Ok(Box::new(LearnedPositionalEncoding::new(max_len, d_model)))
48 }
49 PositionalEncodingType::Relative => {
50 Ok(Box::new(RelativePositionalEncoding::new(max_len, d_model)))
51 }
52 }
53 }
54}
55
56pub trait PositionalEncoding<F: Float + Debug> {
58 fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
68
69 fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>>;
79
80 fn update(&mut self, learning_rate: F) -> Result<()>;
86
87 fn params(&self) -> Vec<Array<F, IxDyn>> {
93 Vec::new() }
95
96 fn set_training(&mut self, _training: bool) {
98 }
100
101 fn is_training(&self) -> bool {
103 false
104 }
105}
106
107#[derive(Debug, Clone)]
113pub struct SinusoidalPositionalEncoding<F: Float + Debug> {
114 max_len: usize,
116 d_model: usize,
118 encoding: Array<F, IxDyn>,
120}
121
122impl<F: Float + Debug + 'static> SinusoidalPositionalEncoding<F> {
123 pub fn new(max_len: usize, d_model: usize) -> Result<Self> {
134 if d_model % 2 != 0 {
135 return Err(NeuralError::InvalidArchitecture(format!(
136 "Model dimension ({}) must be even for sinusoidal positional encoding",
137 d_model
138 )));
139 }
140
141 let mut encoding = Array::<F, _>::zeros((max_len, d_model));
142
143 for pos in 0..max_len {
144 for i in 0..d_model / 2 {
145 let div_term = F::from(10000.0)
146 .unwrap()
147 .powf(F::from(2.0 * i as f64 / d_model as f64).unwrap());
148
149 encoding[[pos, 2 * i]] = F::from(pos as f64).unwrap().sin() / div_term;
151
152 encoding[[pos, 2 * i + 1]] = F::from(pos as f64).unwrap().cos() / div_term;
154 }
155 }
156
157 Ok(Self {
158 max_len,
159 d_model,
160 encoding: encoding.into_dyn(),
161 })
162 }
163}
164
165impl<F: Float + Debug + 'static> PositionalEncoding<F> for SinusoidalPositionalEncoding<F> {
166 fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
167 if embeddings.ndim() < 2 {
168 return Err(NeuralError::InferenceError(
169 "Embeddings must have at least 2 dimensions [batch, seq_len, ...]".to_string(),
170 ));
171 }
172
173 let embed_shape = embeddings.shape();
174 let seq_len = embed_shape[1];
175
176 if seq_len > self.max_len {
177 return Err(NeuralError::InferenceError(format!(
178 "Sequence length ({}) exceeds maximum length ({})",
179 seq_len, self.max_len
180 )));
181 }
182
183 let pos_encoding = self.get_encoding(seq_len)?;
185
186 let mut output = embeddings.clone();
188
189 for batch_idx in 0..embed_shape[0] {
191 let mut batch_slice = output.slice_mut(ndarray::s![batch_idx, .., ..]);
192
193 for pos in 0..seq_len {
195 for dim in 0..self.d_model {
196 batch_slice[[pos, dim]] = batch_slice[[pos, dim]] + pos_encoding[[pos, dim]];
197 }
198 }
199 }
200
201 Ok(output)
202 }
203
204 fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
205 if seq_len > self.max_len {
206 return Err(NeuralError::InferenceError(format!(
207 "Requested sequence length ({}) exceeds maximum length ({})",
208 seq_len, self.max_len
209 )));
210 }
211
212 Ok(self
214 .encoding
215 .slice(ndarray::s![0..seq_len, ..])
216 .to_owned()
217 .into_dyn())
218 }
219
220 fn update(&mut self, _learning_rate: F) -> Result<()> {
221 Ok(())
223 }
224
225 fn set_training(&mut self, _training: bool) {
227 }
229
230 fn is_training(&self) -> bool {
232 false
233 }
234}
235
236pub struct LearnedPositionalEncoding<F: Float + Debug> {
241 max_len: usize,
243 d_model: usize,
245 weights: Array<F, IxDyn>,
247 dweights: Array<F, IxDyn>,
249}
250
251impl<F: Float + Debug + 'static> LearnedPositionalEncoding<F> {
252 pub fn new(max_len: usize, d_model: usize) -> Self {
263 let init_scale = F::from(0.02).unwrap();
265 let weights = Array::<F, _>::from_elem((max_len, d_model), init_scale);
266 let dweights = Array::<F, _>::zeros((max_len, d_model));
267
268 Self {
269 max_len,
270 d_model,
271 weights: weights.into_dyn(),
272 dweights: dweights.into_dyn(),
273 }
274 }
275}
276
277impl<F: Float + Debug + 'static> PositionalEncoding<F> for LearnedPositionalEncoding<F> {
278 fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
279 if embeddings.ndim() < 2 {
280 return Err(NeuralError::InferenceError(
281 "Embeddings must have at least 2 dimensions [batch, seq_len, ...]".to_string(),
282 ));
283 }
284
285 let embed_shape = embeddings.shape();
286 let seq_len = embed_shape[1];
287
288 if seq_len > self.max_len {
289 return Err(NeuralError::InferenceError(format!(
290 "Sequence length ({}) exceeds maximum length ({})",
291 seq_len, self.max_len
292 )));
293 }
294
295 let pos_encoding = self.get_encoding(seq_len)?;
297
298 let mut output = embeddings.clone();
300
301 for batch_idx in 0..embed_shape[0] {
303 let mut batch_slice = output.slice_mut(ndarray::s![batch_idx, .., ..]);
304
305 for pos in 0..seq_len {
307 for dim in 0..self.d_model {
308 batch_slice[[pos, dim]] = batch_slice[[pos, dim]] + pos_encoding[[pos, dim]];
309 }
310 }
311 }
312
313 Ok(output)
314 }
315
316 fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
317 if seq_len > self.max_len {
318 return Err(NeuralError::InferenceError(format!(
319 "Requested sequence length ({}) exceeds maximum length ({})",
320 seq_len, self.max_len
321 )));
322 }
323
324 Ok(self
326 .weights
327 .slice(ndarray::s![0..seq_len, ..])
328 .to_owned()
329 .into_dyn())
330 }
331
332 fn update(&mut self, learning_rate: F) -> Result<()> {
333 let small_change = F::from(0.001).unwrap();
336 let lr = learning_rate * small_change;
337
338 for i in 0..self.max_len {
339 for j in 0..self.d_model {
340 self.weights[[i, j]] = self.weights[[i, j]] - lr * self.dweights[[i, j]];
341 }
342 }
343
344 Ok(())
345 }
346
347 fn params(&self) -> Vec<Array<F, IxDyn>> {
348 vec![self.weights.clone()]
350 }
351}
352
353pub struct RelativePositionalEncoding<F: Float + Debug> {
358 max_len: usize,
360 d_model: usize,
362 weights: Array<F, IxDyn>,
364 dweights: Array<F, IxDyn>,
366}
367
368impl<F: Float + Debug + 'static> RelativePositionalEncoding<F> {
369 pub fn new(max_len: usize, d_model: usize) -> Self {
380 let rel_size = 2 * max_len - 1;
383
384 let init_scale = F::from(0.02).unwrap();
386 let weights = Array::<F, _>::from_elem((rel_size, d_model), init_scale);
387 let dweights = Array::<F, _>::zeros((rel_size, d_model));
388
389 Self {
390 max_len,
391 d_model,
392 weights: weights.into_dyn(),
393 dweights: dweights.into_dyn(),
394 }
395 }
396
397 #[allow(dead_code)]
401 fn rel_pos_to_index(&self, rel_pos: isize) -> usize {
402 (rel_pos + self.max_len as isize - 1) as usize
407 }
408}
409
410impl<F: Float + Debug + 'static> PositionalEncoding<F> for RelativePositionalEncoding<F> {
411 fn forward(&self, embeddings: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
412 if embeddings.ndim() < 2 {
413 return Err(NeuralError::InferenceError(
414 "Embeddings must have at least 2 dimensions [batch, seq_len, ...]".to_string(),
415 ));
416 }
417
418 let embed_shape = embeddings.shape();
419 let seq_len = embed_shape[1];
420
421 if seq_len > self.max_len {
422 return Err(NeuralError::InferenceError(format!(
423 "Sequence length ({}) exceeds maximum length ({})",
424 seq_len, self.max_len
425 )));
426 }
427
428 Ok(embeddings.clone())
432 }
433
434 fn get_encoding(&self, seq_len: usize) -> Result<Array<F, IxDyn>> {
435 if seq_len > self.max_len {
436 return Err(NeuralError::InferenceError(format!(
437 "Requested sequence length ({}) exceeds maximum length ({})",
438 seq_len, self.max_len
439 )));
440 }
441
442 let encoding = Array::<F, _>::zeros((seq_len, self.d_model));
446 Ok(encoding.into_dyn())
447 }
448
449 fn update(&mut self, learning_rate: F) -> Result<()> {
450 let small_change = F::from(0.001).unwrap();
453 let lr = learning_rate * small_change;
454
455 let rel_size = 2 * self.max_len - 1;
456 for i in 0..rel_size {
457 for j in 0..self.d_model {
458 self.weights[[i, j]] = self.weights[[i, j]] - lr * self.dweights[[i, j]];
459 }
460 }
461
462 Ok(())
463 }
464
465 fn params(&self) -> Vec<Array<F, IxDyn>> {
466 vec![self.weights.clone()]
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array3;

    /// Table length shared by every test.
    const MAX_LEN: usize = 100;
    /// Feature width shared by every test.
    const D_MODEL: usize = 64;

    /// Requesting a prefix of the table yields `[seq_len, d_model]`.
    #[test]
    fn test_sinusoidal_encoding_shape() {
        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(MAX_LEN, D_MODEL).unwrap();

        let encoding = pos_enc.get_encoding(50).unwrap();

        assert_eq!(encoding.shape(), &[50, D_MODEL]);
    }

    /// Neighbouring positions must not share an identical encoding.
    #[test]
    fn test_sinusoidal_encoding_properties() {
        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(MAX_LEN, D_MODEL).unwrap();
        let encoding = pos_enc.get_encoding(MAX_LEN).unwrap();

        let first = encoding.slice(ndarray::s![0, ..]).to_owned();
        let second = encoding.slice(ndarray::s![1, ..]).to_owned();

        let differs = first
            .iter()
            .zip(second.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);

        assert!(
            differs,
            "Positions 0 and 1 should have different encodings"
        );
    }

    /// Every encoding kind can be built by the factory and run on a batch.
    #[test]
    fn test_positional_encoding_factory() {
        let batch_size = 2;
        let seq_len = 10;
        let embeddings = Array3::<f64>::zeros((batch_size, seq_len, D_MODEL)).into_dyn();

        let kinds = [
            PositionalEncodingType::Sinusoidal,
            PositionalEncodingType::Learned,
            PositionalEncodingType::Relative,
        ];

        for &kind in &kinds {
            let encoder = PositionalEncodingFactory::create::<f64>(kind, MAX_LEN, D_MODEL).unwrap();
            let _ = encoder.forward(&embeddings).unwrap();
        }
    }

    /// `forward` adds the table returned by `get_encoding` to each batch entry.
    #[test]
    fn test_sinusoidal_encoding_addition() {
        let pos_enc = SinusoidalPositionalEncoding::<f64>::new(MAX_LEN, D_MODEL).unwrap();

        let batch_size = 2;
        let seq_len = 10;
        let embeddings = Array3::<f64>::from_elem((batch_size, seq_len, D_MODEL), 1.0).into_dyn();

        let encoding = pos_enc.get_encoding(seq_len).unwrap();
        let output = pos_enc.forward(&embeddings).unwrap();

        for b in 0..batch_size {
            for s in 0..seq_len {
                for d in 0..D_MODEL {
                    let expected = 1.0 + encoding[[s, d]];
                    assert_relative_eq!(output[[b, s, d]], expected, epsilon = 1e-10);
                }
            }
        }
    }
}