scirs2_neural/layers/embedding.rs

//! Embedding layer implementations
//!
//! This module provides implementations of various embedding layers
//! such as word embeddings, positional embeddings, and patch embeddings for vision.

use ndarray::{Array, ArrayBase, Data, Dimension, Ix1, IxDyn, ScalarOperand};
use num_traits::Float;
use rand::prelude::*;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

use crate::error::{Error, Result};
use crate::layers::Layer;
use crate::utils::initializers;

/// Configuration for the Embedding layer
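///
/// # Examples
///
/// A minimal construction sketch; the import path is assumed and the field values
/// are illustrative only (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_neural::layers::EmbeddingConfig;
///
/// // 1000-entry vocabulary, 64-dimensional vectors, index 0 reserved for padding
/// let config = EmbeddingConfig {
///     num_embeddings: 1000,
///     embedding_dim: 64,
///     padding_idx: Some(0),
///     ..Default::default()
/// };
/// ```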
pub struct EmbeddingConfig {
    /// Number of embeddings in the embedding table
    pub num_embeddings: usize,
    /// Dimension of each embedding vector
    pub embedding_dim: usize,
    /// Optional padding index that will have its embedding vector filled with zeros
    pub padding_idx: Option<usize>,
    /// Maximum norm for embedding vectors
    pub max_norm: Option<f64>,
    /// The p value of the p-norm used when applying max_norm (e.g., 2.0 for the L2 norm)
    pub norm_type: f64,
    /// Whether to scale gradients by the inverse frequency of the indices
    pub scale_grad_by_freq: bool,
    /// Whether to use sparse gradients for the embedding matrix
    pub sparse: bool,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            num_embeddings: 1,
            embedding_dim: 1,
            padding_idx: None,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        }
    }
}

/// Embedding layer that stores embeddings for discrete inputs
///
/// This layer is often used to store word embeddings and retrieve them using indices.
/// The input to the module is a list of indices, and the output is the corresponding
/// embedding vectors.
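///
/// # Examples
///
/// A minimal forward-pass sketch (import paths assumed, not compiled as a doctest);
/// indices are passed as a float array and cast to `usize` internally:
///
/// ```ignore
/// use ndarray::{Array, IxDyn};
/// use scirs2_neural::layers::{Embedding, EmbeddingConfig, Layer};
///
/// let config = EmbeddingConfig {
///     num_embeddings: 100,
///     embedding_dim: 16,
///     ..Default::default()
/// };
/// let embedding = Embedding::<f32>::new(config).unwrap();
///
/// // Two sequences of three token indices -> output shape [2, 3, 16]
/// let indices =
///     Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 0.0, 3.0, 0.0, 4.0]).unwrap();
/// let output = embedding.forward(&indices).unwrap();
/// assert_eq!(output.shape(), &[2, 3, 16]);
/// ```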
pub struct Embedding<F: Float + Debug + ScalarOperand> {
    /// Configuration for the embedding layer
    pub config: EmbeddingConfig,
    /// Weight matrix containing the embeddings
    pub weight: Array<F, IxDyn>,
    /// Gradient of the weight matrix
    weight_grad: Array<F, IxDyn>,
    /// Frequency counter for indices
    freq_counter: Option<Vec<usize>>,
}

impl<F: Float + Debug + ScalarOperand> Embedding<F> {
    /// Create a new Embedding layer with the given configuration
    pub fn new(config: EmbeddingConfig) -> Result<Self> {
        if config.num_embeddings == 0 {
            return Err(Error::InvalidArchitecture(
                "num_embeddings must be greater than 0".to_string(),
            ));
        }
        if config.embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "embedding_dim must be greater than 0".to_string(),
            ));
        }

        // Validate padding_idx
        if let Some(idx) = config.padding_idx {
            if idx >= config.num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "padding_idx ({}) must be less than num_embeddings ({})",
                    idx, config.num_embeddings
                )));
            }
        }

        // Initialize weights with small random values
        let weight_shape = IxDyn(&[config.num_embeddings, config.embedding_dim]);

        // Draw uniform samples in [0, 1) and rescale them
        let mut rng = rand::rng();
        let mut weight = Array::from_shape_fn(weight_shape.clone(), |_| {
            let value: f64 = rng.random();
            // Map the uniform sample to [-0.5, 0.5)
            let scaled_value = (value * 2.0 - 1.0) * 0.5;
            F::from(scaled_value).unwrap()
        });

        // Initialize gradients with zeros
        let weight_grad = Array::zeros(weight_shape.clone());

        // Set padding_idx embeddings to zero if specified
        if let Some(idx) = config.padding_idx {
            let mut slice = weight.slice_mut(ndarray::s![idx, ..]);
            for item in slice.iter_mut() {
                *item = F::zero();
            }
        }

        // Initialize frequency counter if needed
        let freq_counter = if config.scale_grad_by_freq {
            Some(vec![0; config.num_embeddings])
        } else {
            None
        };

        Ok(Self {
            config,
            weight,
            weight_grad,
            freq_counter,
        })
    }

    /// Create an Embedding layer from pretrained embeddings
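    ///
    /// A minimal usage sketch (import paths assumed, not compiled as a doctest):
    ///
    /// ```ignore
    /// use ndarray::{Array, IxDyn};
    /// use scirs2_neural::layers::Embedding;
    ///
    /// // 4 pretrained vectors of dimension 3; index 0 acts as padding
    /// let weights = Array::zeros(IxDyn(&[4, 3]));
    /// let embedding = Embedding::<f32>::from_pretrained(
    ///     weights,  // embeddings
    ///     Some(0),  // padding_idx
    ///     None,     // max_norm
    ///     2.0,      // norm_type
    ///     false,    // scale_grad_by_freq
    ///     false,    // sparse
    /// )
    /// .unwrap();
    /// assert_eq!(embedding.config.num_embeddings, 4);
    /// ```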
    pub fn from_pretrained(
        embeddings: Array<F, IxDyn>,
        padding_idx: Option<usize>,
        max_norm: Option<f64>,
        norm_type: f64,
        scale_grad_by_freq: bool,
        sparse: bool,
    ) -> Result<Self> {
        if embeddings.ndim() != 2 {
            return Err(Error::InvalidArchitecture(
                "Embeddings parameter is expected to be 2-dimensional".to_string(),
            ));
        }

        let shape = embeddings.shape();
        let num_embeddings = shape[0];
        let embedding_dim = shape[1];

        // Validate padding_idx
        if let Some(idx) = padding_idx {
            if idx >= num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "padding_idx ({}) must be less than num_embeddings ({})",
                    idx, num_embeddings
                )));
            }
        }

        let config = EmbeddingConfig {
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
        };

        // Clone the weights
        let weight = embeddings.clone();
        let weight_grad = Array::zeros(IxDyn(&[num_embeddings, embedding_dim]));

        // Initialize frequency counter if needed
        let freq_counter = if scale_grad_by_freq {
            Some(vec![0; num_embeddings])
        } else {
            None
        };

        Ok(Self {
            config,
            weight,
            weight_grad,
            freq_counter,
        })
    }

    /// Reset parameters of the embedding layer
    pub fn reset_parameters(&mut self) -> Result<()> {
        // Re-initialize weights with uniform random values in [0, 1)
        let mut rng = rand::rng();
        for item in self.weight.iter_mut() {
            *item = F::from(rng.random::<f64>()).unwrap();
        }

        // Set padding_idx embeddings to zero if specified
        if let Some(idx) = self.config.padding_idx {
            let mut slice = self.weight.slice_mut(ndarray::s![idx, ..]);
            for item in slice.iter_mut() {
                *item = F::zero();
            }
        }

        // Reset gradients
        self.weight_grad.fill(F::zero());

        // Reset frequency counter if needed
        if let Some(counter) = &mut self.freq_counter {
            counter.iter_mut().for_each(|c| *c = 0);
        }

        Ok(())
    }

    /// Apply max_norm to the embeddings if specified
    fn apply_max_norm(&mut self) -> Result<()> {
        if let Some(max_norm) = self.config.max_norm {
            let norm_type = self.config.norm_type;
            let p = F::from(norm_type).ok_or_else(|| {
                Error::InvalidArchitecture(format!("Invalid norm_type: {}", norm_type))
            })?;
            let max_norm = F::from(max_norm).ok_or_else(|| {
                Error::InvalidArchitecture(format!("Invalid max_norm: {}", max_norm))
            })?;

            // Calculate norms for each embedding vector
            for i in 0..self.config.num_embeddings {
                let mut norm = F::zero();
                // Calculate p-norm
                for j in 0..self.config.embedding_dim {
                    let val = self.weight[[i, j]];
                    if p == F::from(2.0).unwrap() {
                        norm = norm + val * val;
                    } else {
                        norm = norm + val.abs().powf(p);
                    }
                }

                if p == F::from(2.0).unwrap() {
                    norm = norm.sqrt();
                } else {
                    norm = norm.powf(F::one() / p);
                }

                // Apply max_norm if needed
                if norm > max_norm {
                    let scale = max_norm / norm;
                    for j in 0..self.config.embedding_dim {
                        self.weight[[i, j]] = self.weight[[i, j]] * scale;
                    }
                }
            }
        }

        Ok(())
    }

    /// Internal forward pass implementation
    fn forward_impl<D: Dimension>(
        &mut self,
        indices: &ArrayBase<impl Data<Elem = usize>, D>,
    ) -> Result<Array<F, IxDyn>> {
        // Validate indices
        for &idx in indices.iter() {
            if idx >= self.config.num_embeddings {
                return Err(Error::InvalidArchitecture(format!(
                    "Index {} out of bounds for embedding with {} entries",
                    idx, self.config.num_embeddings
                )));
            }
        }

        // Apply max_norm if specified
        self.apply_max_norm()?;

        // Update frequency counter if needed
        if let Some(counter) = &mut self.freq_counter {
            for &idx in indices.iter() {
                counter[idx] += 1;
            }
        }

        // Create output array
        let mut output_shape = Vec::with_capacity(indices.ndim() + 1);
        output_shape.extend_from_slice(indices.shape());
        output_shape.push(self.config.embedding_dim);

        let mut output = Array::zeros(IxDyn(output_shape.as_slice()));

        // Lookup embeddings
        let indices_flat = indices
            .view()
            .into_shape_with_order(IxDyn(&[indices.len()]))
            .unwrap()
            .into_dimensionality::<Ix1>()
            .unwrap();

        for (flat_idx, &idx) in indices_flat.iter().enumerate() {
            // Skip padding indices
            if let Some(padding_idx) = self.config.padding_idx {
                if idx == padding_idx {
                    // Already filled with zeros
                    continue;
                }
            }

            // Compute output index
            let mut output_idx = Vec::with_capacity(indices.ndim() + 1);
            let mut remaining = flat_idx;
            for &dim in indices.shape().iter().rev() {
                output_idx.push(remaining % dim);
                remaining /= dim;
            }
            output_idx.reverse();
            output_idx.push(0); // Starting embedding dimension index

            // Copy embedding to output
            for j in 0..self.config.embedding_dim {
                *output_idx.last_mut().unwrap() = j;
                let emb_val = self.weight[[idx, j]];
                output[IxDyn(output_idx.as_slice())] = emb_val;
            }
        }

        Ok(output)
    }
}

impl<F: Float + Debug + ScalarOperand> Layer<F> for Embedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Convert the float input to integer indices (non-convertible values map to 0)
        let indices = input
            .mapv(|x| x.to_usize().unwrap_or(0))
            .into_dimensionality::<IxDyn>()?;

        // forward_impl needs `&mut self` (max_norm and frequency bookkeeping), but the
        // Layer trait only provides `&self`, so run it on a temporary copy of the layer state
        let mut embedding_mut = Embedding {
            config: EmbeddingConfig {
                num_embeddings: self.config.num_embeddings,
                embedding_dim: self.config.embedding_dim,
                padding_idx: self.config.padding_idx,
                max_norm: self.config.max_norm,
                norm_type: self.config.norm_type,
                scale_grad_by_freq: self.config.scale_grad_by_freq,
                sparse: self.config.sparse,
            },
            weight: self.weight.clone(),
            weight_grad: self.weight_grad.clone(),
            freq_counter: self.freq_counter.clone(),
        };

        embedding_mut.forward_impl(&indices)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Embedding has no meaningful upstream gradient since indices are discrete
        // Return zeros of the same shape as the input (indices)
        let input_shape = &input.shape();
        Ok(Array::zeros(IxDyn(input_shape)))
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update weights using accumulated gradients
        let lr = learning_rate;

        // Handle frequency-based scaling
        if let Some(counter) = &self.freq_counter {
            for (i, &count) in counter.iter().enumerate().take(self.config.num_embeddings) {
                // Skip padding indices
                if let Some(padding_idx) = self.config.padding_idx {
                    if i == padding_idx {
                        continue;
                    }
                }

                let scale = if count > 0 {
                    F::from(1.0 / count as f64).unwrap()
                } else {
                    F::one()
                };

                for j in 0..self.config.embedding_dim {
                    self.weight[[i, j]] =
                        self.weight[[i, j]] - lr * scale * self.weight_grad[[i, j]];
                }
            }
        } else {
            // Standard gradient update
            for i in 0..self.config.num_embeddings {
                // Skip padding indices
                if let Some(padding_idx) = self.config.padding_idx {
                    if i == padding_idx {
                        continue;
                    }
                }

                for j in 0..self.config.embedding_dim {
                    self.weight[[i, j]] = self.weight[[i, j]] - lr * self.weight_grad[[i, j]];
                }
            }
        }

        // Reset gradients
        self.weight_grad.fill(F::zero());

        // Reset frequency counter if needed
        if let Some(counter) = &mut self.freq_counter {
            counter.iter_mut().for_each(|c| *c = 0);
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// Positional Embedding layer for transformers and sequence models
///
/// This layer adds positional information to embeddings to help models
/// understand the position of elements in a sequence.
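///
/// # Examples
///
/// A minimal sketch using fixed sinusoidal embeddings (import paths assumed,
/// not compiled as a doctest):
///
/// ```ignore
/// use ndarray::{Array, IxDyn};
/// use scirs2_neural::layers::{Layer, PositionalEmbedding};
///
/// // Supports sequences up to length 128 with 32-dimensional embeddings
/// let pos_emb = PositionalEmbedding::<f32>::new(128, 32, false).unwrap();
///
/// // Input of shape [batch, seq_len, embedding_dim]; output has the same shape
/// let input = Array::zeros(IxDyn(&[2, 10, 32]));
/// let output = pos_emb.forward(&input).unwrap();
/// assert_eq!(output.shape(), &[2, 10, 32]);
/// ```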
pub struct PositionalEmbedding<F: Float + Debug + ScalarOperand> {
    /// Maximum sequence length supported
    pub max_seq_length: usize,
    /// Dimension of each embedding vector
    pub embedding_dim: usize,
    /// Whether to use learned positional embeddings (true) or fixed sinusoidal (false)
    pub learned: bool,
    /// Weight matrix for learned positional embeddings
    pub weight: Option<Array<F, IxDyn>>,
    /// Gradient of the weight matrix
    weight_grad: Option<Array<F, IxDyn>>,
}

impl<F: Float + Debug + ScalarOperand> PositionalEmbedding<F> {
    /// Create a new PositionalEmbedding layer
    pub fn new(max_seq_length: usize, embedding_dim: usize, learned: bool) -> Result<Self> {
        if max_seq_length == 0 {
            return Err(Error::InvalidArchitecture(
                "max_seq_length must be greater than 0".to_string(),
            ));
        }
        if embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "embedding_dim must be greater than 0".to_string(),
            ));
        }

        if learned {
            // Initialize learned positional embeddings
            let weight_shape = IxDyn(&[max_seq_length, embedding_dim]);
            let weight = Some(initializers::xavier_uniform::<F>(weight_shape.clone())?);
            let weight_grad = Some(Array::zeros(weight_shape));

            Ok(Self {
                max_seq_length,
                embedding_dim,
                learned,
                weight,
                weight_grad,
            })
        } else {
            // Sinusoidal positional embeddings are computed on the fly
            Ok(Self {
                max_seq_length,
                embedding_dim,
                learned,
                weight: None,
                weight_grad: None,
            })
        }
    }

    /// Generate sinusoidal positional embeddings
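    ///
    /// Uses the standard transformer formulation:
    /// PE(pos, 2i)   = sin(pos / 10000^(2i / embedding_dim))
    /// PE(pos, 2i+1) = cos(pos / 10000^(2i / embedding_dim))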
    fn generate_sinusoidal_embeddings(&self, seq_length: usize) -> Result<Array<F, IxDyn>> {
        if seq_length > self.max_seq_length {
            return Err(Error::InvalidArchitecture(format!(
                "Sequence length {} exceeds maximum supported length {}",
                seq_length, self.max_seq_length
            )));
        }

        // Initialize output array
        let mut pos_embeddings = Array::zeros(IxDyn(&[seq_length, self.embedding_dim]));

        // Generate sinusoidal positional embeddings
        for pos in 0..seq_length {
            for i in 0..self.embedding_dim {
                let div_term =
                    F::from((10000.0f64).powf(2.0 * (i / 2) as f64 / self.embedding_dim as f64))
                        .unwrap();

                if i % 2 == 0 {
                    // Sine for even dimensions
                    pos_embeddings[[pos, i]] = F::from(pos as f64 / div_term.to_f64().unwrap())
                        .unwrap()
                        .sin();
                } else {
                    // Cosine for odd dimensions
                    pos_embeddings[[pos, i]] = F::from(pos as f64 / div_term.to_f64().unwrap())
                        .unwrap()
                        .cos();
                }
            }
        }

        Ok(pos_embeddings)
    }
}

impl<F: Float + Debug + ScalarOperand> Layer<F> for PositionalEmbedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Validate input shape - at least 2D with last dimension being embedding_dim
        if input.ndim() < 2 {
            return Err(Error::InvalidArchitecture(
                "Input to PositionalEmbedding must be at least 2D".to_string(),
            ));
        }

        let last_dim = input.shape().last().unwrap();
        if *last_dim != self.embedding_dim {
            return Err(Error::InvalidArchitecture(format!(
                "Input embedding dimension {} doesn't match layer embedding dimension {}",
                last_dim, self.embedding_dim
            )));
        }

        // Get sequence length from the input shape
        let seq_dim = input.ndim() - 2;
        let seq_length = input.shape()[seq_dim];

        if seq_length > self.max_seq_length {
            return Err(Error::InvalidArchitecture(format!(
                "Input sequence length {} exceeds maximum supported length {}",
                seq_length, self.max_seq_length
            )));
        }

        if self.learned {
            // Use learned positional embeddings
            let pos_embeddings = self
                .weight
                .as_ref()
                .unwrap()
                .slice(ndarray::s![0..seq_length, ..]);

            // Add positional embeddings to input
            // Need to broadcast positional embeddings to match input shape
            let mut output = input.clone();

            // Iterate over all batch elements and add positional embeddings
            let batch_shape = &input.shape()[..seq_dim];
            let batch_size: usize = batch_shape.iter().product();

            for batch_idx in 0..batch_size {
                // Calculate the multi-dimensional batch index
                let mut multi_idx = Vec::with_capacity(seq_dim);
                let mut remaining = batch_idx;
                for &dim in batch_shape.iter().rev() {
                    multi_idx.push(remaining % dim);
                    remaining /= dim;
                }
                multi_idx.reverse();

                // For each position in the sequence
                for pos in 0..seq_length {
                    // Full index includes batch indices, sequence position, and embedding dimension
                    let mut full_idx = multi_idx.clone();
                    full_idx.push(pos);

                    // Add positional embeddings to each embedding dimension
                    for dim in 0..self.embedding_dim {
                        full_idx.push(dim);
                        let pos_val = pos_embeddings[[pos, dim]];
                        output[IxDyn(full_idx.as_slice())] =
                            output[IxDyn(full_idx.as_slice())] + pos_val;
                        full_idx.pop();
                    }
                }
            }

            Ok(output)
        } else {
            // Generate sinusoidal positional embeddings
            let pos_embeddings = self.generate_sinusoidal_embeddings(seq_length)?;

            // Add positional embeddings to input
            let mut output = input.clone();

            // Iterate over all batch elements and add positional embeddings
            let batch_shape = &input.shape()[..seq_dim];
            let batch_size: usize = batch_shape.iter().product();

            for batch_idx in 0..batch_size {
                // Calculate the multi-dimensional batch index
                let mut multi_idx = Vec::with_capacity(seq_dim);
                let mut remaining = batch_idx;
                for &dim in batch_shape.iter().rev() {
                    multi_idx.push(remaining % dim);
                    remaining /= dim;
                }
                multi_idx.reverse();

                // For each position in the sequence
                for pos in 0..seq_length {
                    // Full index includes batch indices, sequence position, and embedding dimension
                    let mut full_idx = multi_idx.clone();
                    full_idx.push(pos);

                    // Add positional embeddings to each embedding dimension
                    for dim in 0..self.embedding_dim {
                        full_idx.push(dim);
                        let pos_val = pos_embeddings[[pos, dim]];
                        output[IxDyn(full_idx.as_slice())] =
                            output[IxDyn(full_idx.as_slice())] + pos_val;
                        full_idx.pop();
                    }
                }
            }

            Ok(output)
        }
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // For PositionalEmbedding, gradients flow through directly
        Ok(grad_output.clone())
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Only update weights if using learned positional embeddings
        if self.learned {
            if let (Some(weight), Some(weight_grad)) = (&mut self.weight, &self.weight_grad) {
                // Update weights using accumulated gradients
                let lr = learning_rate;
                for i in 0..self.max_seq_length {
                    for j in 0..self.embedding_dim {
                        weight[[i, j]] = weight[[i, j]] - lr * weight_grad[[i, j]];
                    }
                }

                // Reset gradients
                self.weight_grad = Some(Array::zeros(IxDyn(&[
                    self.max_seq_length,
                    self.embedding_dim,
                ])));
            }
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// Patch Embedding layer for vision transformers
///
/// This layer converts image patches into embeddings for vision transformers.
/// It extracts non-overlapping patches and applies a linear projection (equivalent to a
/// strided convolution) to turn each flattened patch into an embedding vector.
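///
/// # Examples
///
/// A minimal sketch (import paths assumed, not compiled as a doctest):
///
/// ```ignore
/// use ndarray::{Array, IxDyn};
/// use scirs2_neural::layers::{Layer, PatchEmbedding};
///
/// // 32x32 RGB images split into 8x8 patches, projected to 96 dimensions
/// let patch_emb = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true).unwrap();
/// assert_eq!(patch_emb.num_patches(), 16);
///
/// // Input [batch, channels, height, width] -> output [batch, num_patches, embedding_dim]
/// let images = Array::zeros(IxDyn(&[2, 3, 32, 32]));
/// let embeddings = patch_emb.forward(&images).unwrap();
/// assert_eq!(embeddings.shape(), &[2, 16, 96]);
/// ```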
#[derive(Debug, Clone)]
pub struct PatchEmbedding<F: Float + Debug + ScalarOperand + Send + Sync> {
    /// Size of input images (height, width)
    pub image_size: (usize, usize),
    /// Size of patches (height, width)
    pub patch_size: (usize, usize),
    /// Number of input channels (e.g., 3 for RGB)
    pub in_channels: usize,
    /// Dimension of each embedding vector
    pub embedding_dim: usize,
    /// Weight matrix for the patch projection
    pub weight: Array<F, IxDyn>,
    /// Bias vector
    pub bias: Option<Array<F, IxDyn>>,
    /// Gradient of the weight matrix
    weight_grad: Arc<RwLock<Array<F, IxDyn>>>,
    /// Gradient of the bias vector
    bias_grad: Option<Arc<RwLock<Array<F, IxDyn>>>>,
    /// Input cache for backpropagation
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync> PatchEmbedding<F> {
    /// Create a new PatchEmbedding layer
    pub fn new(
        image_size: (usize, usize),
        patch_size: (usize, usize),
        in_channels: usize,
        embedding_dim: usize,
        use_bias: bool,
    ) -> Result<Self> {
        // Validate parameters
        if image_size.0 == 0 || image_size.1 == 0 {
            return Err(Error::InvalidArchitecture(
                "Image height and width must be greater than 0".to_string(),
            ));
        }
        if patch_size.0 == 0 || patch_size.1 == 0 {
            return Err(Error::InvalidArchitecture(
                "Patch height and width must be greater than 0".to_string(),
            ));
        }
        if in_channels == 0 {
            return Err(Error::InvalidArchitecture(
                "Number of input channels must be greater than 0".to_string(),
            ));
        }
        if embedding_dim == 0 {
            return Err(Error::InvalidArchitecture(
                "Embedding dimension must be greater than 0".to_string(),
            ));
        }

        // Check if image is divisible by patch size
        if image_size.0 % patch_size.0 != 0 || image_size.1 % patch_size.1 != 0 {
            return Err(Error::InvalidArchitecture(
                "Image dimensions must be divisible by patch dimensions".to_string(),
            ));
        }

        // Calculate number of patches
        let n_h = image_size.0 / patch_size.0;
        let n_w = image_size.1 / patch_size.1;
        let _num_patches = n_h * n_w;

        // Initialize weights and bias
        // Weight shape: [embedding_dim, in_channels * patch_size.0 * patch_size.1]
        let weight_shape = IxDyn(&[embedding_dim, in_channels * patch_size.0 * patch_size.1]);
        let weight = initializers::xavier_uniform::<F>(weight_shape.clone())?;
        let weight_grad = Arc::new(RwLock::new(Array::zeros(weight_shape)));

        // Bias and its gradient
        let (bias, bias_grad) = if use_bias {
            let bias = Some(Array::zeros(IxDyn(&[embedding_dim])));
            let bias_grad = Some(Arc::new(RwLock::new(Array::zeros(IxDyn(&[embedding_dim])))));
            (bias, bias_grad)
        } else {
            (None, None)
        };

        Ok(Self {
            image_size,
            patch_size,
            in_channels,
            embedding_dim,
            weight,
            bias,
            weight_grad,
            bias_grad,
            input_cache: Arc::new(RwLock::new(None)),
        })
    }

    /// Calculate the number of patches
    pub fn num_patches(&self) -> usize {
        let n_h = self.image_size.0 / self.patch_size.0;
        let n_w = self.image_size.1 / self.patch_size.1;
        n_h * n_w
    }

    /// Extract patches from input images
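    ///
    /// Returns an array of shape [batch_size, num_patches, in_channels * patch_h * patch_w].
    /// Patches are ordered row-major over the patch grid, and each patch is flattened
    /// channel-first (all pixels of channel 0, then channel 1, and so on).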
    fn extract_patches(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Validate input shape
        if input.ndim() != 4 {
            return Err(Error::InvalidArchitecture(
                "Input to PatchEmbedding must be 4D [batch_size, channels, height, width]"
                    .to_string(),
            ));
        }

        let shape = input.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        if channels != self.in_channels {
            return Err(Error::InvalidArchitecture(format!(
                "Input has {} channels, but expected {}",
                channels, self.in_channels
            )));
        }

        if height != self.image_size.0 || width != self.image_size.1 {
            return Err(Error::InvalidArchitecture(format!(
                "Input has shape [{}x{}], but expected [{}x{}]",
                height, width, self.image_size.0, self.image_size.1
            )));
        }

        // Calculate patch grid dimensions
        let n_h = height / self.patch_size.0;
        let n_w = width / self.patch_size.1;
        let num_patches = n_h * n_w;

        // Extract patches and flatten them
        let patch_dim = channels * self.patch_size.0 * self.patch_size.1;
        let mut patches = Array::zeros(IxDyn(&[batch_size, num_patches, patch_dim]));

        for b in 0..batch_size {
            for i in 0..n_h {
                for j in 0..n_w {
                    let patch_idx = i * n_w + j;
                    let h_start = i * self.patch_size.0;
                    let w_start = j * self.patch_size.1;

                    // Flatten the patch
                    let mut flat_idx = 0;
                    for c in 0..channels {
                        for ph in 0..self.patch_size.0 {
                            for pw in 0..self.patch_size.1 {
                                let h_idx = h_start + ph;
                                let w_idx = w_start + pw;
                                patches[[b, patch_idx, flat_idx]] = input[[b, c, h_idx, w_idx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(patches)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for PatchEmbedding<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Extract patches
        let patches = self.extract_patches(input)?;

        // Cache the extracted patches for backpropagation
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(patches.clone());
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        let batch_size = patches.shape()[0];
        let num_patches = patches.shape()[1];

        // Linear projection of patches to embedding dimension
        let mut embeddings = Array::zeros(IxDyn(&[batch_size, num_patches, self.embedding_dim]));

        for b in 0..batch_size {
            for p in 0..num_patches {
                // Matrix multiplication for each patch
                for e in 0..self.embedding_dim {
                    let mut val = F::zero();
                    for i in 0..patches.shape()[2] {
                        val = val + self.weight[[e, i]] * patches[[b, p, i]];
                    }

                    // Add bias if present
                    if let Some(ref bias) = self.bias {
                        val = val + bias[[e]];
                    }

                    embeddings[[b, p, e]] = val;
                }
            }
        }

        Ok(embeddings)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Get cached input from RwLock
        let input_cache_guard = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(Error::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };

        if input_cache_guard.is_none() {
            return Err(Error::InferenceError(
                "Cannot perform backward pass before forward pass".to_string(),
            ));
        }

        let patches = input_cache_guard.as_ref().unwrap();
        let batch_size = patches.shape()[0];
        let num_patches = patches.shape()[1];
        let patch_dim = patches.shape()[2];

        // Validate grad_output shape
        if grad_output.shape() != [batch_size, num_patches, self.embedding_dim] {
            return Err(Error::InvalidArchitecture(format!(
                "Expected grad_output shape [{}, {}, {}], but got {:?}",
                batch_size,
                num_patches,
                self.embedding_dim,
                grad_output.shape()
            )));
        }

        // Compute gradients with respect to weights and bias
        let mut weight_grad = Array::zeros(self.weight.dim());
        let mut bias_grad = if self.bias.is_some() {
            Some(Array::zeros(IxDyn(&[self.embedding_dim])))
        } else {
            None
        };

        for b in 0..batch_size {
            for p in 0..num_patches {
                for e in 0..self.embedding_dim {
                    let grad = grad_output[[b, p, e]];

                    // Gradient for bias
                    if let Some(ref mut bg) = bias_grad {
                        bg[[e]] = bg[[e]] + grad;
                    }

                    // Gradient for weights
                    for i in 0..patch_dim {
                        weight_grad[[e, i]] = weight_grad[[e, i]] + grad * patches[[b, p, i]];
                    }
                }
            }
        }

        // Update accumulated gradients
        if let Ok(mut weight_grad_guard) = self.weight_grad.write() {
            for e in 0..self.embedding_dim {
                for i in 0..patch_dim {
                    weight_grad_guard[[e, i]] = weight_grad_guard[[e, i]] + weight_grad[[e, i]];
                }
            }
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on weight gradients".to_string(),
            ));
        }

        if let (Some(ref bg_acc_lock), Some(ref bg)) = (&self.bias_grad, &bias_grad) {
            if let Ok(mut bg_acc) = bg_acc_lock.write() {
                for e in 0..self.embedding_dim {
                    bg_acc[[e]] = bg_acc[[e]] + bg[[e]];
                }
            } else {
                return Err(Error::InferenceError(
                    "Failed to acquire write lock on bias gradients".to_string(),
                ));
            }
        }

        // Compute gradient with respect to input
        let mut input_grad = Array::zeros(IxDyn(&[
            batch_size,
            self.in_channels,
            self.image_size.0,
            self.image_size.1,
        ]));

        // Calculate gradient for each patch
        let mut patches_grad = Array::zeros(patches.dim());

        for b in 0..batch_size {
            for p in 0..num_patches {
                for i in 0..patch_dim {
                    let mut grad = F::zero();
                    for e in 0..self.embedding_dim {
                        grad = grad + grad_output[[b, p, e]] * self.weight[[e, i]];
                    }
                    patches_grad[[b, p, i]] = grad;
                }
            }
        }

        // Reshape patches gradient back to image space
        let n_h = self.image_size.0 / self.patch_size.0;
        let n_w = self.image_size.1 / self.patch_size.1;

        for b in 0..batch_size {
            for i in 0..n_h {
                for j in 0..n_w {
                    let patch_idx = i * n_w + j;
                    let h_start = i * self.patch_size.0;
                    let w_start = j * self.patch_size.1;

                    // Unflatten the patch gradient
                    let mut flat_idx = 0;
                    for c in 0..self.in_channels {
                        for ph in 0..self.patch_size.0 {
                            for pw in 0..self.patch_size.1 {
                                let h_idx = h_start + ph;
                                let w_idx = w_start + pw;
                                input_grad[[b, c, h_idx, w_idx]] =
                                    patches_grad[[b, patch_idx, flat_idx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(input_grad)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update weights
        let lr = learning_rate;
        let patch_dim = self.weight.shape()[1];

        // Get the weight gradients inside the RwLock
        if let Ok(weight_grad_guard) = self.weight_grad.read() {
            for e in 0..self.embedding_dim {
                for i in 0..patch_dim {
                    self.weight[[e, i]] = self.weight[[e, i]] - lr * weight_grad_guard[[e, i]];
                }
            }
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire read lock on weight gradients".to_string(),
            ));
        }

        // Update bias if present
        if let Some(ref mut bias) = &mut self.bias {
            if let Some(ref bias_grad_lock) = &self.bias_grad {
                if let Ok(bias_grad_guard) = bias_grad_lock.read() {
                    for e in 0..self.embedding_dim {
                        bias[[e]] = bias[[e]] - lr * bias_grad_guard[[e]];
                    }
                } else {
                    return Err(Error::InferenceError(
                        "Failed to acquire read lock on bias gradients".to_string(),
                    ));
                }
            }
        }

        // Reset gradients
        if let Ok(mut weight_grad_guard) = self.weight_grad.write() {
            weight_grad_guard.fill(F::zero());
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on weight gradients".to_string(),
            ));
        }

        if let Some(ref bias_grad_lock) = &self.bias_grad {
            if let Ok(mut bias_grad_guard) = bias_grad_lock.write() {
                bias_grad_guard.fill(F::zero());
            } else {
                return Err(Error::InferenceError(
                    "Failed to acquire write lock on bias gradients".to_string(),
                ));
            }
        }

        // Clear input cache to free memory
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = None;
        } else {
            return Err(Error::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::Rng;

    #[test]
    fn test_embedding_creation() {
        // Create embedding layer
        let config = EmbeddingConfig {
            num_embeddings: 10,
            embedding_dim: 5,
            padding_idx: Some(0),
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        };

        let embedding = Embedding::<f32>::new(config).unwrap();

        // Check dimensions
        assert_eq!(embedding.weight.shape(), &[10, 5]);

        // Check that padding index is zero
        for i in 0..5 {
            assert_eq!(embedding.weight[[0, i]], 0.0);
        }
    }

    #[test]
    fn test_embedding_forward() {
        // Create embedding layer
        let config = EmbeddingConfig {
            num_embeddings: 10,
            embedding_dim: 5,
            padding_idx: Some(0),
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        };

        let mut embedding = Embedding::<f32>::new(config).unwrap();

        // Set weight values for testing
        for i in 0..10 {
            for j in 0..5 {
                embedding.weight[[i, j]] = (i * 10 + j) as f32 / 10.0;
            }
        }

        // Zero out padding index
        for j in 0..5 {
            embedding.weight[[0, j]] = 0.0;
        }

        // Create input indices
        let indices = Array2::from_shape_vec((2, 3), vec![1, 2, 0, 3, 0, 4]).unwrap();
        let indices_dyn = indices.into_dimensionality::<IxDyn>().unwrap();

        // Forward pass
        let output = embedding.forward_impl(&indices_dyn).unwrap();

        // Check output shape
        assert_eq!(output.shape(), &[2, 3, 5]);

        // Check values
        // First batch, first token (index 1)
        for j in 0..5 {
            assert_eq!(output[[0, 0, j]], (10 + j) as f32 / 10.0);
        }

        // First batch, second token (index 2)
        for j in 0..5 {
            assert_eq!(output[[0, 1, j]], (20 + j) as f32 / 10.0);
        }

        // First batch, third token (index 0, padding)
        for j in 0..5 {
            assert_eq!(output[[0, 2, j]], 0.0);
        }
    }

    #[test]
    fn test_positional_embedding() {
        // Test learned positional embeddings
        let pos_emb_learned = PositionalEmbedding::<f32>::new(10, 8, true).unwrap();

        // Check dimensions
        assert!(pos_emb_learned.weight.is_some());
        assert_eq!(pos_emb_learned.weight.as_ref().unwrap().shape(), &[10, 8]);

        // Create dummy input
        let input = Array::from_shape_fn(IxDyn(&[2, 5, 8]), |_| 1.0f32);

        // Forward pass
        let output = pos_emb_learned.forward(&input).unwrap();

        // Check output shape
        assert_eq!(output.shape(), &[2, 5, 8]);

        // Test fixed sinusoidal positional embeddings
        let pos_emb_fixed = PositionalEmbedding::<f32>::new(10, 8, false).unwrap();

        // Check that weight is None for fixed embeddings
        assert!(pos_emb_fixed.weight.is_none());

        // Forward pass
        let output = pos_emb_fixed.forward(&input).unwrap();

        // Check output shape
        assert_eq!(output.shape(), &[2, 5, 8]);
    }

    #[test]
    fn test_patch_embedding() {
        // Create patch embedding layer
        let patch_emb = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true).unwrap();

        // Check dimensions
        assert_eq!(patch_emb.weight.shape(), &[96, 3 * 8 * 8]);
        assert!(patch_emb.bias.is_some());
        assert_eq!(patch_emb.bias.as_ref().unwrap().shape(), &[96]);

        // Check number of patches
        assert_eq!(patch_emb.num_patches(), 16); // 4x4 patches of 8x8 in a 32x32 image

        // Create random input
        let mut rand_gen = rand::rng();
        let input = Array::from_shape_fn(IxDyn(&[2, 3, 32, 32]), |_| rand_gen.random::<f32>());

        // Forward pass
        let output = patch_emb.forward(&input).unwrap();

        // Check output shape [batch_size, num_patches, embedding_dim]
        assert_eq!(output.shape(), &[2, 16, 96]);
    }
}