use ndarray::{Array, ArrayBase, Data, Dimension, Ix1, IxDyn, ScalarOperand};
use num_traits::Float;
use rand::prelude::*;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

use crate::error::{Error, Result};
use crate::layers::Layer;
use crate::utils::initializers;

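/// Configuration for the [`Embedding`] layer.
///
/// The fields mirror the usual embedding-table options: `padding_idx` marks a row
/// that is kept at zero and excluded from updates, `max_norm`/`norm_type` renormalize
/// rows whose p-norm exceeds `max_norm`, and `scale_grad_by_freq` scales updates by
/// the inverse frequency of each index in the current batch. `sparse` is carried in
/// the configuration but does not change the dense update path implemented here.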
pub struct EmbeddingConfig {
    pub num_embeddings: usize,
    pub embedding_dim: usize,
    pub padding_idx: Option<usize>,
    pub max_norm: Option<f64>,
    pub norm_type: f64,
    pub scale_grad_by_freq: bool,
    pub sparse: bool,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            num_embeddings: 1,
            embedding_dim: 1,
            padding_idx: None,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        }
    }
}

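/// Token embedding layer: a lookup table mapping integer indices to dense vectors
/// of size `embedding_dim`. The options mirror those of PyTorch's `nn.Embedding`
/// (padding, max-norm renormalization, frequency-scaled updates).
///
/// A minimal usage sketch (illustrative only; assumes these items are in scope and
/// that indices are passed as floats through the generic [`Layer`] interface, as the
/// `forward` implementation below expects):
///
/// ```ignore
/// let config = EmbeddingConfig {
///     num_embeddings: 100,
///     embedding_dim: 16,
///     ..Default::default()
/// };
/// let embedding = Embedding::<f32>::new(config).unwrap();
///
/// let tokens = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0f32, 5.0, 7.0, 2.0, 0.0, 9.0]).unwrap();
/// let output = embedding.forward(&tokens).unwrap(); // shape [2, 3, 16]
/// ```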
pub struct Embedding<F: Float + Debug + ScalarOperand> {
    pub config: EmbeddingConfig,
    pub weight: Array<F, IxDyn>,
    weight_grad: Array<F, IxDyn>,
    freq_counter: Option<Vec<usize>>,
}

impl<F: Float + Debug + ScalarOperand> Embedding<F> {
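    /// Creates a new embedding table with weights drawn uniformly from `[-0.5, 0.5)`
    /// and the `padding_idx` row (if any) zeroed.
    ///
    /// Returns an error if `num_embeddings` or `embedding_dim` is zero, or if
    /// `padding_idx` is out of range.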
    pub fn new(config: EmbeddingConfig) -> Result<Self> {
        if config.num_embeddings == 0 {
            return Err(Error::InvalidArchitecture(
                "num_embeddings must be greater than 0".to_string(),
            ));
        }
        if config.embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "embedding_dim must be greater than 0".to_string(),
            ));
        }

        if let Some(idx) = config.padding_idx {
            if idx >= config.num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "padding_idx ({}) must be less than num_embeddings ({})",
                    idx, config.num_embeddings
                )));
            }
        }

        let weight_shape = IxDyn(&[config.num_embeddings, config.embedding_dim]);

        // Initialize weights uniformly in [-0.5, 0.5).
        let mut rng = rand::rng();
        let mut weight = Array::from_shape_fn(weight_shape.clone(), |_| {
            let value: f64 = rng.random();
            let scaled_value = (value * 2.0 - 1.0) * 0.5;
            F::from(scaled_value).unwrap()
        });

        let weight_grad = Array::zeros(weight_shape);

        // The padding row is kept at zero.
        if let Some(idx) = config.padding_idx {
            let mut slice = weight.slice_mut(ndarray::s![idx, ..]);
            for item in slice.iter_mut() {
                *item = F::zero();
            }
        }

        let freq_counter = if config.scale_grad_by_freq {
            Some(vec![0; config.num_embeddings])
        } else {
            None
        };

        Ok(Self {
            config,
            weight,
            weight_grad,
            freq_counter,
        })
    }

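    /// Builds an embedding layer from an existing 2-D weight matrix of shape
    /// `[num_embeddings, embedding_dim]`. The values are used as-is; in particular,
    /// the `padding_idx` row is validated but not zeroed.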
    pub fn from_pretrained(
        embeddings: Array<F, IxDyn>,
        padding_idx: Option<usize>,
        max_norm: Option<f64>,
        norm_type: f64,
        scale_grad_by_freq: bool,
        sparse: bool,
    ) -> Result<Self> {
        if embeddings.ndim() != 2 {
            return Err(Error::InvalidArchitecture(
                "Embeddings parameter is expected to be 2-dimensional".to_string(),
            ));
        }

        let num_embeddings = embeddings.shape()[0];
        let embedding_dim = embeddings.shape()[1];

        if let Some(idx) = padding_idx {
            if idx >= num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "padding_idx ({}) must be less than num_embeddings ({})",
                    idx, num_embeddings
                )));
            }
        }

        let config = EmbeddingConfig {
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
        };

        let weight = embeddings;
        let weight_grad = Array::zeros(IxDyn(&[num_embeddings, embedding_dim]));

        let freq_counter = if scale_grad_by_freq {
            Some(vec![0; num_embeddings])
        } else {
            None
        };

        Ok(Self {
            config,
            weight,
            weight_grad,
            freq_counter,
        })
    }

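    /// Re-initializes the weights uniformly in `[0, 1)`, zeroes the `padding_idx`
    /// row, and clears the accumulated gradients and frequency counts.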
    pub fn reset_parameters(&mut self) -> Result<()> {
        let mut rng = rand::rng();
        for item in self.weight.iter_mut() {
            *item = F::from(rng.random::<f64>()).unwrap();
        }

        if let Some(idx) = self.config.padding_idx {
            let mut slice = self.weight.slice_mut(ndarray::s![idx, ..]);
            for item in slice.iter_mut() {
                *item = F::zero();
            }
        }

        self.weight_grad.fill(F::zero());

        if let Some(counter) = &mut self.freq_counter {
            counter.iter_mut().for_each(|c| *c = 0);
        }

        Ok(())
    }

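    /// If `max_norm` is configured, rescales every embedding row whose `norm_type`
    /// p-norm exceeds `max_norm` by the factor `max_norm / norm`, so the row ends up
    /// with norm exactly `max_norm`.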
    fn apply_max_norm(&mut self) -> Result<()> {
        if let Some(max_norm) = self.config.max_norm {
            let norm_type = self.config.norm_type;
            let p = F::from(norm_type).ok_or_else(|| {
                Error::InvalidArchitecture(format!("Invalid norm_type: {}", norm_type))
            })?;
            let max_norm = F::from(max_norm).ok_or_else(|| {
                Error::InvalidArchitecture(format!("Invalid max_norm: {}", max_norm))
            })?;

            for i in 0..self.config.num_embeddings {
                let mut norm = F::zero();
                for j in 0..self.config.embedding_dim {
                    let val = self.weight[[i, j]];
                    if p == F::from(2.0).unwrap() {
                        norm = norm + val * val;
                    } else {
                        norm = norm + val.abs().powf(p);
                    }
                }

                if p == F::from(2.0).unwrap() {
                    norm = norm.sqrt();
                } else {
                    norm = norm.powf(F::one() / p);
                }

                if norm > max_norm {
                    let scale = max_norm / norm;
                    for j in 0..self.config.embedding_dim {
                        self.weight[[i, j]] = self.weight[[i, j]] * scale;
                    }
                }
            }
        }

        Ok(())
    }

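    /// Looks up embeddings for an index array of any dimensionality. The output shape
    /// is the index shape with `embedding_dim` appended; positions equal to
    /// `padding_idx` are left as zeros.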
    fn forward_impl<D: Dimension>(
        &mut self,
        indices: &ArrayBase<impl Data<Elem = usize>, D>,
    ) -> Result<Array<F, IxDyn>> {
        for &idx in indices.iter() {
            if idx >= self.config.num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "Index {} out of bounds for embedding with {} entries",
                    idx, self.config.num_embeddings
                )));
            }
        }

        self.apply_max_norm()?;

        if let Some(counter) = &mut self.freq_counter {
            for &idx in indices.iter() {
                counter[idx] += 1;
            }
        }

        // Output shape is the index shape with the embedding dimension appended.
        let mut output_shape = Vec::with_capacity(indices.ndim() + 1);
        output_shape.extend_from_slice(indices.shape());
        output_shape.push(self.config.embedding_dim);

        let mut output = Array::zeros(IxDyn(output_shape.as_slice()));

        let indices_flat = indices
            .view()
            .into_shape_with_order(IxDyn(&[indices.len()]))
            .unwrap()
            .into_dimensionality::<Ix1>()
            .unwrap();

        for (flat_idx, &idx) in indices_flat.iter().enumerate() {
            // Rows selected by the padding index stay at zero.
            if let Some(padding_idx) = self.config.padding_idx {
                if idx == padding_idx {
                    continue;
                }
            }

            // Convert the flat position back into a multi-dimensional output index.
            let mut output_idx = Vec::with_capacity(indices.ndim() + 1);
            let mut remaining = flat_idx;
            for &dim in indices.shape().iter().rev() {
                output_idx.push(remaining % dim);
                remaining /= dim;
            }
            output_idx.reverse();
            output_idx.push(0);

            for j in 0..self.config.embedding_dim {
                *output_idx.last_mut().unwrap() = j;
                let emb_val = self.weight[[idx, j]];
                output[IxDyn(output_idx.as_slice())] = emb_val;
            }
        }

        Ok(output)
    }
}

impl<F: Float + Debug + ScalarOperand> Layer<F> for Embedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Indices arrive as floats through the generic `Layer` interface; convert them
        // to usize before the lookup (range checking happens in `forward_impl`).
        let indices = input
            .mapv(|x| x.to_usize().unwrap_or(0))
            .into_dimensionality::<IxDyn>()?;

        // `Layer::forward` takes `&self`, but the lookup needs mutable access for
        // max-norm renormalization and frequency counting, so work on a temporary copy.
        let mut embedding_mut = Embedding {
            config: EmbeddingConfig {
                num_embeddings: self.config.num_embeddings,
                embedding_dim: self.config.embedding_dim,
                padding_idx: self.config.padding_idx,
                max_norm: self.config.max_norm,
                norm_type: self.config.norm_type,
                scale_grad_by_freq: self.config.scale_grad_by_freq,
                sparse: self.config.sparse,
            },
            weight: self.weight.clone(),
            weight_grad: self.weight_grad.clone(),
            freq_counter: self.freq_counter.clone(),
        };

        embedding_mut.forward_impl(&indices)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Token indices are discrete, so no gradient flows back to the input.
        Ok(Array::zeros(IxDyn(input.shape())))
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        let lr = learning_rate;

        if let Some(counter) = &self.freq_counter {
            // Scale each row's update by the inverse of how often its index appeared
            // since the last update.
            for (i, &count) in counter.iter().enumerate().take(self.config.num_embeddings) {
                if let Some(padding_idx) = self.config.padding_idx {
                    if i == padding_idx {
                        continue;
                    }
                }

                let scale = if count > 0 {
                    F::from(1.0 / count as f64).unwrap()
                } else {
                    F::one()
                };

                for j in 0..self.config.embedding_dim {
                    self.weight[[i, j]] =
                        self.weight[[i, j]] - lr * scale * self.weight_grad[[i, j]];
                }
            }
        } else {
            for i in 0..self.config.num_embeddings {
                if let Some(padding_idx) = self.config.padding_idx {
                    if i == padding_idx {
                        continue;
                    }
                }

                for j in 0..self.config.embedding_dim {
                    self.weight[[i, j]] = self.weight[[i, j]] - lr * self.weight_grad[[i, j]];
                }
            }
        }

        self.weight_grad.fill(F::zero());

        if let Some(counter) = &mut self.freq_counter {
            counter.iter_mut().for_each(|c| *c = 0);
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

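/// Positional embedding layer that adds position information to a sequence of
/// embeddings, either from a learned table of shape `[max_seq_length, embedding_dim]`
/// or from fixed sinusoidal values computed on the fly.
///
/// A minimal usage sketch (illustrative only, mirroring the unit test below):
///
/// ```ignore
/// let pos_emb = PositionalEmbedding::<f32>::new(10, 8, false).unwrap();
/// let input = Array::from_shape_fn(IxDyn(&[2, 5, 8]), |_| 1.0f32);
/// let output = pos_emb.forward(&input).unwrap(); // same shape, positions added element-wise
/// ```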
pub struct PositionalEmbedding<F: Float + Debug + ScalarOperand> {
    pub max_seq_length: usize,
    pub embedding_dim: usize,
    pub learned: bool,
    pub weight: Option<Array<F, IxDyn>>,
    weight_grad: Option<Array<F, IxDyn>>,
}

impl<F: Float + Debug + ScalarOperand> PositionalEmbedding<F> {
    pub fn new(max_seq_length: usize, embedding_dim: usize, learned: bool) -> Result<Self> {
        if max_seq_length == 0 {
            return Err(Error::InvalidArchitecture(
                "max_seq_length must be greater than 0".to_string(),
            ));
        }
        if embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "embedding_dim must be greater than 0".to_string(),
            ));
        }

        if learned {
            let weight_shape = IxDyn(&[max_seq_length, embedding_dim]);
            let weight = Some(initializers::xavier_uniform::<F>(weight_shape.clone())?);
            let weight_grad = Some(Array::zeros(weight_shape));

            Ok(Self {
                max_seq_length,
                embedding_dim,
                learned,
                weight,
                weight_grad,
            })
        } else {
            Ok(Self {
                max_seq_length,
                embedding_dim,
                learned,
                weight: None,
                weight_grad: None,
            })
        }
    }

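    /// Builds the fixed sinusoidal table used when `learned == false`:
    ///
    /// `PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))`
    /// `PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))`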
    fn generate_sinusoidal_embeddings(&self, seq_length: usize) -> Result<Array<F, IxDyn>> {
        if seq_length > self.max_seq_length {
            return Err(Error::InvalidArchitecture(format!(
                "Sequence length {} exceeds maximum supported length {}",
                seq_length, self.max_seq_length
            )));
        }

        let mut pos_embeddings = Array::zeros(IxDyn(&[seq_length, self.embedding_dim]));

        for pos in 0..seq_length {
            for i in 0..self.embedding_dim {
                let div_term =
                    F::from((10000.0f64).powf(2.0 * (i / 2) as f64 / self.embedding_dim as f64))
                        .unwrap();

                if i % 2 == 0 {
                    pos_embeddings[[pos, i]] = F::from(pos as f64 / div_term.to_f64().unwrap())
                        .unwrap()
                        .sin();
                } else {
                    pos_embeddings[[pos, i]] = F::from(pos as f64 / div_term.to_f64().unwrap())
                        .unwrap()
                        .cos();
                }
            }
        }

        Ok(pos_embeddings)
    }
}

impl<F: Float + Debug + ScalarOperand> Layer<F> for PositionalEmbedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if input.ndim() < 2 {
            return Err(Error::InvalidArchitecture(
                "Input to PositionalEmbedding must be at least 2D".to_string(),
            ));
        }

        let last_dim = input.shape().last().unwrap();
        if *last_dim != self.embedding_dim {
            return Err(Error::InvalidArchitecture(format!(
                "Input embedding dimension {} doesn't match layer embedding dimension {}",
                last_dim, self.embedding_dim
            )));
        }

        // The sequence axis is taken to be the second-to-last dimension.
        let seq_dim = input.ndim() - 2;
        let seq_length = input.shape()[seq_dim];

        if seq_length > self.max_seq_length {
            return Err(Error::InvalidArchitecture(format!(
                "Input sequence length {} exceeds maximum supported length {}",
                seq_length, self.max_seq_length
            )));
        }

        if self.learned {
            // Learned table: add the first `seq_length` rows of the weight matrix.
            let pos_embeddings = self
                .weight
                .as_ref()
                .unwrap()
                .slice(ndarray::s![0..seq_length, ..]);

            let mut output = input.clone();

            let batch_shape = &input.shape()[..seq_dim];
            let batch_size: usize = batch_shape.iter().product();

            for batch_idx in 0..batch_size {
                // Unflatten the batch index into a multi-dimensional prefix.
                let mut multi_idx = Vec::with_capacity(seq_dim);
                let mut remaining = batch_idx;
                for &dim in batch_shape.iter().rev() {
                    multi_idx.push(remaining % dim);
                    remaining /= dim;
                }
                multi_idx.reverse();

                for pos in 0..seq_length {
                    let mut full_idx = multi_idx.clone();
                    full_idx.push(pos);

                    for dim in 0..self.embedding_dim {
                        full_idx.push(dim);
                        let pos_val = pos_embeddings[[pos, dim]];
                        output[IxDyn(full_idx.as_slice())] =
                            output[IxDyn(full_idx.as_slice())] + pos_val;
                        full_idx.pop();
                    }
                }
            }

            Ok(output)
        } else {
            // Fixed table: compute sinusoidal embeddings for the current length.
            let pos_embeddings = self.generate_sinusoidal_embeddings(seq_length)?;

            let mut output = input.clone();

            let batch_shape = &input.shape()[..seq_dim];
            let batch_size: usize = batch_shape.iter().product();

            for batch_idx in 0..batch_size {
                let mut multi_idx = Vec::with_capacity(seq_dim);
                let mut remaining = batch_idx;
                for &dim in batch_shape.iter().rev() {
                    multi_idx.push(remaining % dim);
                    remaining /= dim;
                }
                multi_idx.reverse();

                for pos in 0..seq_length {
                    let mut full_idx = multi_idx.clone();
                    full_idx.push(pos);

                    for dim in 0..self.embedding_dim {
                        full_idx.push(dim);
                        let pos_val = pos_embeddings[[pos, dim]];
                        output[IxDyn(full_idx.as_slice())] =
                            output[IxDyn(full_idx.as_slice())] + pos_val;
                        full_idx.pop();
                    }
                }
            }

            Ok(output)
        }
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Positional values are added element-wise, so the gradient passes through unchanged.
        Ok(grad_output.clone())
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        if self.learned {
            if let (Some(weight), Some(weight_grad)) = (&mut self.weight, &self.weight_grad) {
                let lr = learning_rate;
                for i in 0..self.max_seq_length {
                    for j in 0..self.embedding_dim {
                        weight[[i, j]] = weight[[i, j]] - lr * weight_grad[[i, j]];
                    }
                }

                self.weight_grad = Some(Array::zeros(IxDyn(&[
                    self.max_seq_length,
                    self.embedding_dim,
                ])));
            }
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

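/// Patch embedding layer for Vision Transformer-style models: splits a
/// `[batch, channels, height, width]` image into non-overlapping patches and
/// linearly projects each flattened patch to `embedding_dim`.
///
/// A minimal usage sketch (illustrative only, mirroring the unit test below):
///
/// ```ignore
/// // 32x32 RGB images with 8x8 patches -> 16 patches, each projected to 96 dims.
/// let patch_emb = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true).unwrap();
/// let images: Array<f32, IxDyn> = Array::zeros(IxDyn(&[2, 3, 32, 32]));
/// let tokens = patch_emb.forward(&images).unwrap(); // shape [2, 16, 96]
/// ```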
#[derive(Debug, Clone)]
pub struct PatchEmbedding<F: Float + Debug + ScalarOperand + Send + Sync> {
    pub image_size: (usize, usize),
    pub patch_size: (usize, usize),
    pub in_channels: usize,
    pub embedding_dim: usize,
    pub weight: Array<F, IxDyn>,
    pub bias: Option<Array<F, IxDyn>>,
    weight_grad: Arc<RwLock<Array<F, IxDyn>>>,
    bias_grad: Option<Arc<RwLock<Array<F, IxDyn>>>>,
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync> PatchEmbedding<F> {
    pub fn new(
        image_size: (usize, usize),
        patch_size: (usize, usize),
        in_channels: usize,
        embedding_dim: usize,
        use_bias: bool,
    ) -> Result<Self> {
        if image_size.0 == 0 || image_size.1 == 0 {
            return Err(Error::InvalidArchitecture(
                "Image height and width must be greater than 0".to_string(),
            ));
        }
        if patch_size.0 == 0 || patch_size.1 == 0 {
            return Err(Error::InvalidArchitecture(
                "Patch height and width must be greater than 0".to_string(),
            ));
        }
        if in_channels == 0 {
            return Err(Error::InvalidArchitecture(
                "Number of input channels must be greater than 0".to_string(),
            ));
        }
        if embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "Embedding dimension must be greater than 0".to_string(),
            ));
        }

        if image_size.0 % patch_size.0 != 0 || image_size.1 % patch_size.1 != 0 {
            return Err(Error::InvalidArchitecture(
                "Image dimensions must be divisible by patch dimensions".to_string(),
            ));
        }

        let n_h = image_size.0 / patch_size.0;
        let n_w = image_size.1 / patch_size.1;
        let _num_patches = n_h * n_w;

        let weight_shape = IxDyn(&[embedding_dim, in_channels * patch_size.0 * patch_size.1]);
        let weight = initializers::xavier_uniform::<F>(weight_shape.clone())?;
        let weight_grad = Arc::new(RwLock::new(Array::zeros(weight_shape)));

        let (bias, bias_grad) = if use_bias {
            let bias = Some(Array::zeros(IxDyn(&[embedding_dim])));
            let bias_grad = Some(Arc::new(RwLock::new(Array::zeros(IxDyn(&[embedding_dim])))));
            (bias, bias_grad)
        } else {
            (None, None)
        };

        Ok(Self {
            image_size,
            patch_size,
            in_channels,
            embedding_dim,
            weight,
            bias,
            weight_grad,
            bias_grad,
            input_cache: Arc::new(RwLock::new(None)),
        })
    }

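    /// Number of patches a full image is split into:
    /// `(image_height / patch_height) * (image_width / patch_width)`.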
    pub fn num_patches(&self) -> usize {
        let n_h = self.image_size.0 / self.patch_size.0;
        let n_w = self.image_size.1 / self.patch_size.1;
        n_h * n_w
    }

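    /// Rearranges a `[batch, channels, height, width]` input into
    /// `[batch, num_patches, channels * patch_h * patch_w]`, flattening each patch
    /// channel-first, then row by row within the patch.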
    fn extract_patches(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if input.ndim() != 4 {
            return Err(Error::InvalidArchitecture(
                "Input to PatchEmbedding must be 4D [batch_size, channels, height, width]"
                    .to_string(),
            ));
        }

        let shape = input.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        if channels != self.in_channels {
            return Err(Error::InvalidArchitecture(format!(
                "Input has {} channels, but expected {}",
                channels, self.in_channels
            )));
        }

        if height != self.image_size.0 || width != self.image_size.1 {
            return Err(Error::InvalidArchitecture(format!(
                "Input has shape [{}x{}], but expected [{}x{}]",
                height, width, self.image_size.0, self.image_size.1
            )));
        }

        let n_h = height / self.patch_size.0;
        let n_w = width / self.patch_size.1;
        let num_patches = n_h * n_w;

        let patch_dim = channels * self.patch_size.0 * self.patch_size.1;
        let mut patches = Array::zeros(IxDyn(&[batch_size, num_patches, patch_dim]));

        for b in 0..batch_size {
            for i in 0..n_h {
                for j in 0..n_w {
                    let patch_idx = i * n_w + j;
                    let h_start = i * self.patch_size.0;
                    let w_start = j * self.patch_size.1;

                    let mut flat_idx = 0;
                    for c in 0..channels {
                        for ph in 0..self.patch_size.0 {
                            for pw in 0..self.patch_size.1 {
                                let h_idx = h_start + ph;
                                let w_idx = w_start + pw;
                                patches[[b, patch_idx, flat_idx]] = input[[b, c, h_idx, w_idx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(patches)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for PatchEmbedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let patches = self.extract_patches(input)?;

        // Cache the extracted patches for the backward pass.
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(patches.clone());
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        let batch_size = patches.shape()[0];
        let num_patches = patches.shape()[1];

        let mut embeddings = Array::zeros(IxDyn(&[batch_size, num_patches, self.embedding_dim]));

        // Linear projection: embeddings[b, p, e] = sum_i weight[e, i] * patches[b, p, i] (+ bias[e]).
        for b in 0..batch_size {
            for p in 0..num_patches {
                for e in 0..self.embedding_dim {
                    let mut val = F::zero();
                    for i in 0..patches.shape()[2] {
                        val = val + self.weight[[e, i]] * patches[[b, p, i]];
                    }

                    if let Some(ref bias) = self.bias {
                        val = val + bias[[e]];
                    }

                    embeddings[[b, p, e]] = val;
                }
            }
        }

        Ok(embeddings)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_cache_guard = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(Error::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };

        if input_cache_guard.is_none() {
            return Err(Error::InferenceError(
                "Cannot perform backward pass before forward pass".to_string(),
            ));
        }

        let patches = input_cache_guard.as_ref().unwrap();
        let batch_size = patches.shape()[0];
        let num_patches = patches.shape()[1];
        let patch_dim = patches.shape()[2];

        if grad_output.shape() != [batch_size, num_patches, self.embedding_dim] {
            return Err(Error::InvalidArchitecture(format!(
                "Expected grad_output shape [{}, {}, {}], but got {:?}",
                batch_size,
                num_patches,
                self.embedding_dim,
                grad_output.shape()
            )));
        }

        // Gradients with respect to the projection weight and bias.
        let mut weight_grad = Array::zeros(self.weight.dim());
        let mut bias_grad = if self.bias.is_some() {
            Some(Array::zeros(IxDyn(&[self.embedding_dim])))
        } else {
            None
        };

        for b in 0..batch_size {
            for p in 0..num_patches {
                for e in 0..self.embedding_dim {
                    let grad = grad_output[[b, p, e]];

                    if let Some(ref mut bg) = bias_grad {
                        bg[[e]] = bg[[e]] + grad;
                    }

                    for i in 0..patch_dim {
                        weight_grad[[e, i]] = weight_grad[[e, i]] + grad * patches[[b, p, i]];
                    }
                }
            }
        }

        // Accumulate into the shared gradient buffers.
        if let Ok(mut weight_grad_guard) = self.weight_grad.write() {
            for e in 0..self.embedding_dim {
                for i in 0..patch_dim {
                    weight_grad_guard[[e, i]] = weight_grad_guard[[e, i]] + weight_grad[[e, i]];
                }
            }
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on weight gradients".to_string(),
            ));
        }

        if let (Some(ref bg_acc_lock), Some(ref bg)) = (&self.bias_grad, &bias_grad) {
            if let Ok(mut bg_acc) = bg_acc_lock.write() {
                for e in 0..self.embedding_dim {
                    bg_acc[[e]] = bg_acc[[e]] + bg[[e]];
                }
            } else {
                return Err(Error::InferenceError(
                    "Failed to acquire write lock on bias gradients".to_string(),
                ));
            }
        }

        // Gradient with respect to the input: back-project to patch space, then
        // scatter the patch gradients back to image coordinates.
        let mut input_grad = Array::zeros(IxDyn(&[
            batch_size,
            self.in_channels,
            self.image_size.0,
            self.image_size.1,
        ]));

        let mut patches_grad = Array::zeros(patches.dim());

        for b in 0..batch_size {
            for p in 0..num_patches {
                for i in 0..patch_dim {
                    let mut grad = F::zero();
                    for e in 0..self.embedding_dim {
                        grad = grad + grad_output[[b, p, e]] * self.weight[[e, i]];
                    }
                    patches_grad[[b, p, i]] = grad;
                }
            }
        }

        let n_h = self.image_size.0 / self.patch_size.0;
        let n_w = self.image_size.1 / self.patch_size.1;

        for b in 0..batch_size {
            for i in 0..n_h {
                for j in 0..n_w {
                    let patch_idx = i * n_w + j;
                    let h_start = i * self.patch_size.0;
                    let w_start = j * self.patch_size.1;

                    let mut flat_idx = 0;
                    for c in 0..self.in_channels {
                        for ph in 0..self.patch_size.0 {
                            for pw in 0..self.patch_size.1 {
                                let h_idx = h_start + ph;
                                let w_idx = w_start + pw;
                                input_grad[[b, c, h_idx, w_idx]] =
                                    patches_grad[[b, patch_idx, flat_idx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(input_grad)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        let lr = learning_rate;
        let patch_dim = self.weight.shape()[1];

        if let Ok(weight_grad_guard) = self.weight_grad.read() {
            for e in 0..self.embedding_dim {
                for i in 0..patch_dim {
                    self.weight[[e, i]] = self.weight[[e, i]] - lr * weight_grad_guard[[e, i]];
                }
            }
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire read lock on weight gradients".to_string(),
            ));
        }

        if let Some(ref mut bias) = &mut self.bias {
            if let Some(ref bias_grad_lock) = &self.bias_grad {
                if let Ok(bias_grad_guard) = bias_grad_lock.read() {
                    for e in 0..self.embedding_dim {
                        bias[[e]] = bias[[e]] - lr * bias_grad_guard[[e]];
                    }
                } else {
                    return Err(Error::InferenceError(
                        "Failed to acquire read lock on bias gradients".to_string(),
                    ));
                }
            }
        }

        if let Ok(mut weight_grad_guard) = self.weight_grad.write() {
            weight_grad_guard.fill(F::zero());
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on weight gradients".to_string(),
            ));
        }

        if let Some(ref bias_grad_lock) = &self.bias_grad {
            if let Ok(mut bias_grad_guard) = bias_grad_lock.write() {
                bias_grad_guard.fill(F::zero());
            } else {
                return Err(Error::InferenceError(
                    "Failed to acquire write lock on bias gradients".to_string(),
                ));
            }
        }

        if let Ok(mut cache) = self.input_cache.write() {
            *cache = None;
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::Rng;

    #[test]
    fn test_embedding_creation() {
        let config = EmbeddingConfig {
            num_embeddings: 10,
            embedding_dim: 5,
            padding_idx: Some(0),
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        };

        let embedding = Embedding::<f32>::new(config).unwrap();

        assert_eq!(embedding.weight.shape(), &[10, 5]);

        // The padding row must be initialized to zero.
        for i in 0..5 {
            assert_eq!(embedding.weight[[0, i]], 0.0);
        }
    }

    #[test]
    fn test_embedding_forward() {
        let config = EmbeddingConfig {
            num_embeddings: 10,
            embedding_dim: 5,
            padding_idx: Some(0),
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        };

        let mut embedding = Embedding::<f32>::new(config).unwrap();

        // Fill the table with known values so lookups can be checked exactly.
        for i in 0..10 {
            for j in 0..5 {
                embedding.weight[[i, j]] = (i * 10 + j) as f32 / 10.0;
            }
        }

        for j in 0..5 {
            embedding.weight[[0, j]] = 0.0;
        }

        let indices = Array2::from_shape_vec((2, 3), vec![1, 2, 0, 3, 0, 4]).unwrap();
        let indices_dyn = indices.into_dimensionality::<IxDyn>().unwrap();

        let output = embedding.forward_impl(&indices_dyn).unwrap();

        assert_eq!(output.shape(), &[2, 3, 5]);

        // Index 1 -> row 1 of the table.
        for j in 0..5 {
            assert_eq!(output[[0, 0, j]], (10 + j) as f32 / 10.0);
        }

        // Index 2 -> row 2 of the table.
        for j in 0..5 {
            assert_eq!(output[[0, 1, j]], (20 + j) as f32 / 10.0);
        }

        // Index 0 is the padding index, so the output stays zero.
        for j in 0..5 {
            assert_eq!(output[[0, 2, j]], 0.0);
        }
    }

    #[test]
    fn test_positional_embedding() {
        let pos_emb_learned = PositionalEmbedding::<f32>::new(10, 8, true).unwrap();

        assert!(pos_emb_learned.weight.is_some());
        assert_eq!(pos_emb_learned.weight.as_ref().unwrap().shape(), &[10, 8]);

        let input = Array::from_shape_fn(IxDyn(&[2, 5, 8]), |_| 1.0f32);

        let output = pos_emb_learned.forward(&input).unwrap();

        assert_eq!(output.shape(), &[2, 5, 8]);

        let pos_emb_fixed = PositionalEmbedding::<f32>::new(10, 8, false).unwrap();

        assert!(pos_emb_fixed.weight.is_none());

        let output = pos_emb_fixed.forward(&input).unwrap();

        assert_eq!(output.shape(), &[2, 5, 8]);
    }

    #[test]
    fn test_patch_embedding() {
        let patch_emb = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true).unwrap();

        assert_eq!(patch_emb.weight.shape(), &[96, 3 * 8 * 8]);
        assert!(patch_emb.bias.is_some());
        assert_eq!(patch_emb.bias.as_ref().unwrap().shape(), &[96]);

        // A 32x32 image with 8x8 patches yields a 4x4 grid of patches.
        assert_eq!(patch_emb.num_patches(), 16);

        let mut rand_gen = rand::rng();
        let input = Array::from_shape_fn(IxDyn(&[2, 3, 32, 32]), |_| rand_gen.random::<f32>());

        let output = patch_emb.forward(&input).unwrap();

        assert_eq!(output.shape(), &[2, 16, 96]);
    }
}