scirs2_neural/layers/patch_embed.rs

//! Patch Embedding layer for Vision Transformers
//!
//! Implements the patch embedding operation used in ViT and related architectures.
//! An input image `[batch, channels, height, width]` is divided into non-overlapping
//! `patch_size × patch_size` patches, which are then linearly projected to an
//! embedding space of dimension `embed_dim`, yielding `[batch, num_patches, embed_dim]`.
//!
//! Reference: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
//! Dosovitskiy et al. (2020). <https://arxiv.org/abs/2010.11929>
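//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`; the import paths assume this crate's
//! public re-exports and may differ):
//!
//! ```rust,ignore
//! use scirs2_neural::layers::{Layer, PatchEmbedding};
//! use scirs2_core::ndarray::{Array, IxDyn};
//! use scirs2_core::random::{rngs::SmallRng, SeedableRng};
//!
//! let mut rng = SmallRng::from_seed([0u8; 32]);
//! // 32×32 RGB image, 8×8 patches → (32/8) * (32/8) = 16 patches, each projected to 64 dims
//! let layer = PatchEmbedding::<f64>::new((32, 32), (8, 8), 3, 64, true, &mut rng)
//!     .expect("valid configuration");
//! let input = Array::<f64, _>::zeros(IxDyn(&[1, 3, 32, 32]));
//! let output = layer.forward(&input).expect("forward pass");
//! assert_eq!(output.shape(), &[1, 16, 64]);
//! ```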

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Distribution, Rng, RngExt, Uniform};
use std::fmt::Debug;
use std::sync::RwLock;

/// Patch Embedding layer.
///
/// Divides an image into fixed-size patches and projects each patch to an embedding
/// vector via a learned linear projection (optionally including a bias term).
///
/// # Shape
/// - Input: `[batch_size, in_channels, image_height, image_width]`
/// - Output: `[batch_size, num_patches, embed_dim]`
///   where `num_patches = (image_height / patch_h) * (image_width / patch_w)`
///
/// # Parameters
/// - `weight`: shape `[embed_dim, in_channels * patch_h * patch_w]` — the projection matrix
/// - `bias` (optional): shape `[embed_dim]` — per-embedding bias
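///
/// # Example
///
/// A training-step sketch (a free-standing fragment, marked `ignore`; `layer`,
/// `input`, and `grad_out` are assumed to be in scope):
///
/// ```rust,ignore
/// let output = layer.forward(&input)?;              // [batch, num_patches, embed_dim]
/// let grad_in = layer.backward(&input, &grad_out)?; // same shape as `input`
/// layer.update(0.01)?;                              // plain SGD step on cached gradients
/// ```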
pub struct PatchEmbedding<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Image size `(height, width)` expected as input
    image_size: (usize, usize),
    /// Patch size `(height, width)`
    patch_size: (usize, usize),
    /// Number of input channels
    in_channels: usize,
    /// Output embedding dimension
    embed_dim: usize,
    /// Number of patches along height
    num_patches_h: usize,
    /// Number of patches along width
    num_patches_w: usize,
    /// Flat patch dimension: `in_channels * patch_h * patch_w`
    patch_dim: usize,

    /// Projection weight matrix, shape `[embed_dim, patch_dim]`
    weight: Array<F, IxDyn>,
    /// Projection bias, shape `[embed_dim]` (optional but always allocated as zeros)
    bias: Array<F, IxDyn>,
    /// Whether a bias term is used
    use_bias: bool,

    /// Gradient for weight
    d_weight: RwLock<Array<F, IxDyn>>,
    /// Gradient for bias
    d_bias: RwLock<Array<F, IxDyn>>,
    /// Cached flattened patches from the most recent forward pass (for backward)
    cached_patches: RwLock<Option<Array<F, IxDyn>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Debug
    for PatchEmbedding<F>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PatchEmbedding")
            .field("image_size", &self.image_size)
            .field("patch_size", &self.patch_size)
            .field("in_channels", &self.in_channels)
            .field("embed_dim", &self.embed_dim)
            .field("num_patches", &(self.num_patches_h * self.num_patches_w))
            .field("use_bias", &self.use_bias)
            .finish()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone
    for PatchEmbedding<F>
{
    fn clone(&self) -> Self {
        Self {
            image_size: self.image_size,
            patch_size: self.patch_size,
            in_channels: self.in_channels,
            embed_dim: self.embed_dim,
            num_patches_h: self.num_patches_h,
            num_patches_w: self.num_patches_w,
            patch_dim: self.patch_dim,
            weight: self.weight.clone(),
            bias: self.bias.clone(),
            use_bias: self.use_bias,
            d_weight: RwLock::new(
                self.d_weight
                    .read()
                    .expect("RwLock poisoned on d_weight read")
                    .clone(),
            ),
            d_bias: RwLock::new(
                self.d_bias
                    .read()
                    .expect("RwLock poisoned on d_bias read")
                    .clone(),
            ),
            cached_patches: RwLock::new(
                self.cached_patches
                    .read()
                    .expect("RwLock poisoned on cached_patches read")
                    .clone(),
            ),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> PatchEmbedding<F> {
    /// Create a new PatchEmbedding layer.
    ///
    /// # Arguments
    /// * `image_size` — expected input image dimensions `(height, width)`
    /// * `patch_size` — patch dimensions `(height, width)` — must divide `image_size` evenly
    /// * `in_channels` — number of input image channels
    /// * `embed_dim` — output embedding dimension
    /// * `use_bias` — whether to include a learnable bias term
    /// * `rng` — random number generator for weight initialisation
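    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InvalidArchitecture` if any dimension is zero or if
    /// `image_size` is not evenly divisible by `patch_size`.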
    pub fn new<R: Rng>(
        image_size: (usize, usize),
        patch_size: (usize, usize),
        in_channels: usize,
        embed_dim: usize,
        use_bias: bool,
        rng: &mut R,
    ) -> Result<Self> {
        if image_size.0 == 0 || image_size.1 == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "image_size dimensions must be non-zero".to_string(),
            ));
        }
        if patch_size.0 == 0 || patch_size.1 == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "patch_size dimensions must be non-zero".to_string(),
            ));
        }
        if !image_size.0.is_multiple_of(patch_size.0) || !image_size.1.is_multiple_of(patch_size.1)
        {
            return Err(NeuralError::InvalidArchitecture(format!(
                "image_size {:?} must be divisible by patch_size {:?}",
                image_size, patch_size
            )));
        }
        if in_channels == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "in_channels must be non-zero".to_string(),
            ));
        }
        if embed_dim == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "embed_dim must be non-zero".to_string(),
            ));
        }

        let num_patches_h = image_size.0 / patch_size.0;
        let num_patches_w = image_size.1 / patch_size.1;
        let patch_dim = in_channels * patch_size.0 * patch_size.1;

        // Xavier (Glorot) uniform initialisation: uniform(-bound, bound)
        // with bound = sqrt(6 / (fan_in + fan_out))
        let fan_in = patch_dim as f64;
        let fan_out = embed_dim as f64;
        let bound = f64::sqrt(6.0 / (fan_in + fan_out));
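        // For example (hypothetical sizes): patch_dim = 48 and embed_dim = 64
        // give bound = sqrt(6 / 112) ≈ 0.231.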

        let uniform = Uniform::new(-bound, bound).map_err(|e| {
            NeuralError::InvalidArchitecture(format!("Failed to create uniform distribution: {e}"))
        })?;

        // Weight: [embed_dim, patch_dim]
        let weight_vec: Vec<F> = (0..(embed_dim * patch_dim))
            .map(|_| {
                // `F::from` can only fail for exotic float types; fall back to
                // zero rather than panicking during initialisation.
                F::from(uniform.sample(rng)).unwrap_or_else(F::zero)
            })
            .collect();

        let weight =
            Array::from_shape_vec(IxDyn(&[embed_dim, patch_dim]), weight_vec).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to construct weight array: {e}"))
            })?;

        // Bias: [embed_dim], initialised to zero
        let bias = Array::zeros(IxDyn(&[embed_dim]));

        let d_weight = RwLock::new(Array::zeros(IxDyn(&[embed_dim, patch_dim])));
        let d_bias = RwLock::new(Array::zeros(IxDyn(&[embed_dim])));

        Ok(Self {
            image_size,
            patch_size,
            in_channels,
            embed_dim,
            num_patches_h,
            num_patches_w,
            patch_dim,
            weight,
            bias,
            use_bias,
            d_weight,
            d_bias,
            cached_patches: RwLock::new(None),
        })
    }

    /// Total number of patches produced from one image
    pub fn num_patches(&self) -> usize {
        self.num_patches_h * self.num_patches_w
    }

    /// Flat patch vector dimension: `in_channels * patch_h * patch_w`
    pub fn patch_dim(&self) -> usize {
        self.patch_dim
    }

    /// Extract and flatten patches from the input image batch.
    ///
    /// Returns an array of shape `[batch, num_patches, patch_dim]`.
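    ///
    /// Within a patch, elements are flattened channel-major:
    /// `flat_idx = c * patch_h * patch_w + dy * patch_w + dx`.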
    fn extract_patches(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let shape = input.shape();
        let batch = shape[0];
        let num_patches = self.num_patches_h * self.num_patches_w;

        let mut patches = Array::zeros(IxDyn(&[batch, num_patches, self.patch_dim]));

        for b in 0..batch {
            for ph in 0..self.num_patches_h {
                for pw in 0..self.num_patches_w {
                    let patch_idx = ph * self.num_patches_w + pw;
                    // Top-left pixel of this patch in the image
                    let h_start = ph * self.patch_size.0;
                    let w_start = pw * self.patch_size.1;
                    let mut flat_idx = 0usize;
                    for c in 0..self.in_channels {
                        for dy in 0..self.patch_size.0 {
                            for dx in 0..self.patch_size.1 {
                                patches[[b, patch_idx, flat_idx]] =
                                    input[[b, c, h_start + dy, w_start + dx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(patches)
    }

    /// Apply the linear projection `patches @ weight.T + bias` over all patches.
    ///
    /// Input shape: `[batch, num_patches, patch_dim]`
    /// Output shape: `[batch, num_patches, embed_dim]`
    fn linear_project(&self, patches: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let batch = patches.shape()[0];
        let num_patches = patches.shape()[1];
        let mut output = Array::zeros(IxDyn(&[batch, num_patches, self.embed_dim]));

        for b in 0..batch {
            for p in 0..num_patches {
                for e in 0..self.embed_dim {
                    let mut acc = F::zero();
                    for k in 0..self.patch_dim {
                        acc += patches[[b, p, k]] * self.weight[[e, k]];
                    }
                    if self.use_bias {
                        acc += self.bias[e];
                    }
                    output[[b, p, e]] = acc;
                }
            }
        }

        Ok(output)
    }

    /// Validate input tensor shape for this layer.
    fn validate_input_shape(&self, input: &Array<F, IxDyn>) -> Result<()> {
        let shape = input.shape();
        if shape.len() != 4 {
            return Err(NeuralError::InferenceError(format!(
                "PatchEmbedding expects 4-D input [batch, channels, height, width], got {:?}",
                shape
            )));
        }
        if shape[1] != self.in_channels {
            return Err(NeuralError::InferenceError(format!(
                "PatchEmbedding: expected {} input channels, got {}",
                self.in_channels, shape[1]
            )));
        }
        if shape[2] != self.image_size.0 || shape[3] != self.image_size.1 {
            return Err(NeuralError::InferenceError(format!(
                "PatchEmbedding: expected image size {:?}, got ({}, {})",
                self.image_size, shape[2], shape[3]
            )));
        }
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
    for PatchEmbedding<F>
{
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        self.validate_input_shape(input)?;

        // Extract patches: [batch, num_patches, patch_dim]
        let patches = self.extract_patches(input)?;

        // Cache flattened patches for backward pass
        {
            let mut cache = self
                .cached_patches
                .write()
                .expect("RwLock poisoned on cached_patches write");
            *cache = Some(patches.clone());
        }

        // Linear projection: [batch, num_patches, embed_dim]
        self.linear_project(&patches)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // grad_output shape: [batch, num_patches, embed_dim]
        let go_shape = grad_output.shape();
        if go_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "PatchEmbedding backward: grad_output must be 3-D [batch, num_patches, embed_dim], got {:?}",
                go_shape
            )));
        }
        let batch = go_shape[0];
        let num_patches = go_shape[1];

        let patches = {
            let cache = self
                .cached_patches
                .read()
                .expect("RwLock poisoned on cached_patches read");
            cache.clone().ok_or_else(|| {
                NeuralError::InferenceError(
                    "PatchEmbedding backward called before forward — no cached patches".to_string(),
                )
            })?
        };

        // Gradient w.r.t. weight: d_weight[e, k] = sum_{b,p} grad_output[b,p,e] * patches[b,p,k]
        let mut d_weight = Array::zeros(IxDyn(&[self.embed_dim, self.patch_dim]));
        for b in 0..batch {
            for p in 0..num_patches {
                for e in 0..self.embed_dim {
                    let go = grad_output[[b, p, e]];
                    for k in 0..self.patch_dim {
                        d_weight[[e, k]] += go * patches[[b, p, k]];
                    }
                }
            }
        }

        // Gradient w.r.t. bias: d_bias[e] = sum_{b,p} grad_output[b,p,e]
        let mut d_bias = Array::zeros(IxDyn(&[self.embed_dim]));
        if self.use_bias {
            for b in 0..batch {
                for p in 0..num_patches {
                    for e in 0..self.embed_dim {
                        d_bias[e] += grad_output[[b, p, e]];
                    }
                }
            }
        }

        // Store gradients
        {
            let mut dw = self
                .d_weight
                .write()
                .expect("RwLock poisoned on d_weight write");
            *dw = d_weight;
        }
        {
            let mut db = self
                .d_bias
                .write()
                .expect("RwLock poisoned on d_bias write");
            *db = d_bias;
        }

        // Gradient w.r.t. input patches: d_patches[b,p,k] = sum_e grad_output[b,p,e] * weight[e,k]
        let mut d_patches = Array::zeros(IxDyn(&[batch, num_patches, self.patch_dim]));
        for b in 0..batch {
            for p in 0..num_patches {
                for k in 0..self.patch_dim {
                    let mut acc = F::zero();
                    for e in 0..self.embed_dim {
                        acc += grad_output[[b, p, e]] * self.weight[[e, k]];
                    }
                    d_patches[[b, p, k]] = acc;
                }
            }
        }

        // Scatter gradient back into image-shaped tensor [batch, channels, H, W]
        let mut d_input = Array::zeros(IxDyn(&[
            batch,
            self.in_channels,
            self.image_size.0,
            self.image_size.1,
        ]));
        for b in 0..batch {
            for ph in 0..self.num_patches_h {
                for pw in 0..self.num_patches_w {
                    let patch_idx = ph * self.num_patches_w + pw;
                    let h_start = ph * self.patch_size.0;
                    let w_start = pw * self.patch_size.1;
                    let mut flat_idx = 0usize;
                    for c in 0..self.in_channels {
                        for dy in 0..self.patch_size.0 {
                            for dx in 0..self.patch_size.1 {
                                d_input[[b, c, h_start + dy, w_start + dx]] +=
                                    d_patches[[b, patch_idx, flat_idx]];
                                flat_idx += 1;
                            }
                        }
                    }
                }
            }
        }

        Ok(d_input)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        let d_weight = {
            self.d_weight
                .read()
                .expect("RwLock poisoned on d_weight read")
                .clone()
        };
        let d_bias = {
            self.d_bias
                .read()
                .expect("RwLock poisoned on d_bias read")
                .clone()
        };

        // SGD update for weight
        for e in 0..self.embed_dim {
            for k in 0..self.patch_dim {
                self.weight[[e, k]] -= learning_rate * d_weight[[e, k]];
            }
        }

        // SGD update for bias
        if self.use_bias {
            for e in 0..self.embed_dim {
                self.bias[e] -= learning_rate * d_bias[e];
            }
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "PatchEmbedding"
    }

    fn parameter_count(&self) -> usize {
        let weight_params = self.embed_dim * self.patch_dim;
        let bias_params = if self.use_bias { self.embed_dim } else { 0 };
        weight_params + bias_params
    }

    fn layer_description(&self) -> String {
        format!(
            "type:PatchEmbedding, image_size:{:?}, patch_size:{:?}, in_channels:{}, embed_dim:{}, num_patches:{}, params:{}",
            self.image_size,
            self.patch_size,
            self.in_channels,
            self.embed_dim,
            self.num_patches(),
            self.parameter_count()
        )
    }

    fn params(&self) -> Vec<Array<F, IxDyn>> {
        if self.use_bias {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![self.weight.clone()]
        }
    }

    fn set_params(&mut self, params: &[Array<F, IxDyn>]) -> Result<()> {
        if params.is_empty() {
            return Err(NeuralError::InvalidArchitecture(
                "PatchEmbedding set_params: expected at least 1 parameter (weight)".to_string(),
            ));
        }
        self.weight = params[0].clone();
        if self.use_bias && params.len() >= 2 {
            self.bias = params[1].clone();
        }
        Ok(())
    }

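    // Note: these report per-sample shapes; the leading batch dimension is omitted.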
    fn inputshape(&self) -> Option<Vec<usize>> {
        Some(vec![self.in_channels, self.image_size.0, self.image_size.1])
    }

    fn outputshape(&self) -> Option<Vec<usize>> {
        Some(vec![self.num_patches(), self.embed_dim])
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ParamLayer<F>
    for PatchEmbedding<F>
{
    fn get_parameters(&self) -> Vec<Array<F, IxDyn>> {
        self.params()
    }

    fn get_gradients(&self) -> Vec<Array<F, IxDyn>> {
        let dw = self
            .d_weight
            .read()
            .expect("RwLock poisoned on d_weight read")
            .clone();
        if self.use_bias {
            let db = self
                .d_bias
                .read()
                .expect("RwLock poisoned on d_bias read")
                .clone();
            vec![dw, db]
        } else {
            vec![dw]
        }
    }

    fn set_parameters(&mut self, params: Vec<Array<F, IxDyn>>) -> Result<()> {
        self.set_params(&params)
    }
}

// Safety: PatchEmbedding is safe to send across threads; interior mutability is via RwLock
unsafe impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Send for PatchEmbedding<F> {}
unsafe impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Sync for PatchEmbedding<F> {}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    fn make_embed(
        image_size: (usize, usize),
        patch_size: (usize, usize),
        in_channels: usize,
        embed_dim: usize,
    ) -> PatchEmbedding<f64> {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        PatchEmbedding::new(
            image_size,
            patch_size,
            in_channels,
            embed_dim,
            true,
            &mut rng,
        )
        .expect("Failed to construct PatchEmbedding")
    }

    #[test]
    fn test_patch_embedding_output_shape() {
        // 8×8 image, 2×2 patches → 4×4 = 16 patches
        let layer = make_embed((8, 8), (2, 2), 3, 32);
        assert_eq!(layer.num_patches(), 16);
        assert_eq!(layer.patch_dim(), 3 * 2 * 2);

        let batch = 2usize;
        let input = Array::zeros(IxDyn(&[batch, 3, 8, 8]));
        let output = layer.forward(&input).expect("Forward pass failed");
        assert_eq!(output.shape(), &[batch, 16, 32]);
    }

    #[test]
    fn test_patch_embedding_parameter_count() {
        let layer = make_embed((16, 16), (4, 4), 3, 64);
        // patch_dim = 3*4*4 = 48; weight = 64*48 = 3072; bias = 64
        assert_eq!(layer.parameter_count(), 64 * 48 + 64);
    }

    #[test]
    fn test_patch_embedding_backward_shape() {
        let layer = make_embed((8, 8), (2, 2), 3, 32);
        let batch = 2usize;
        let input = Array::zeros(IxDyn(&[batch, 3, 8, 8]));
        let output = layer.forward(&input).expect("Forward failed");
        let grad_out = Array::ones(output.raw_dim());
        let grad_in = layer
            .backward(&input, &grad_out)
            .expect("Backward pass failed");
        // Gradient must match input shape
        assert_eq!(grad_in.shape(), input.shape());
    }

    #[test]
    fn test_patch_embedding_invalid_size() {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        // 7 is not divisible by 4
        let result = PatchEmbedding::<f64>::new((7, 8), (4, 4), 3, 32, true, &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_patch_embedding_update() {
        let mut layer = make_embed((8, 8), (2, 2), 1, 16);
        let input = Array::zeros(IxDyn(&[1, 1, 8, 8]));
        let output = layer.forward(&input).expect("Forward failed");
        let grad_out = Array::ones(output.raw_dim());
        layer.backward(&input, &grad_out).expect("Backward failed");
        layer.update(0.01f64).expect("Update failed");
    }
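
    #[test]
    fn test_patch_embedding_known_projection() {
        // A sanity-check sketch (not in the original suite): install an
        // all-ones weight and zero bias via set_params, feed an all-ones
        // image, and verify every output equals patch_dim (the dot product
        // of a flattened ones-patch with a ones-row of the weight).
        let mut layer = make_embed((4, 4), (2, 2), 1, 3);
        let patch_dim = layer.patch_dim(); // 1 * 2 * 2 = 4
        let weight = Array::ones(IxDyn(&[3, patch_dim]));
        let bias = Array::zeros(IxDyn(&[3]));
        layer
            .set_params(&[weight, bias])
            .expect("set_params failed");

        let input = Array::ones(IxDyn(&[1, 1, 4, 4]));
        let output = layer.forward(&input).expect("Forward failed");
        assert_eq!(output.shape(), &[1, 4, 3]); // [batch, num_patches, embed_dim]
        for v in output.iter() {
            assert!((v - patch_dim as f64).abs() < 1e-12);
        }
    }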

    #[test]
    fn test_patch_embedding_round_trip_params() {
        // get_parameters → set_parameters must preserve forward output identically
        let mut layer = make_embed((8, 8), (2, 2), 3, 16);
        let input = Array::ones(IxDyn(&[1, 3, 8, 8]));

        let out_before = layer
            .forward(&input)
            .expect("Forward before round-trip failed");

        // Extract parameters, set them back
        let params = layer.get_parameters();
        layer
            .set_parameters(params.clone())
            .expect("set_parameters failed");

        let out_after = layer
            .forward(&input)
            .expect("Forward after round-trip failed");

        // Outputs should be identical; compare with a tight tolerance
        assert_eq!(out_before.shape(), out_after.shape());
        for (a, b) in out_before.iter().zip(out_after.iter()) {
            assert!((a - b).abs() < 1e-12, "round-trip mismatch: {a} vs {b}");
        }

        // Parameter count consistency
        // make_embed((8,8),(2,2),3,16): embed_dim=16, patch_dim=3*2*2=12
        // weight: 16*12=192 elements; bias: 16 elements
        assert_eq!(params.len(), 2); // weight + bias
        assert_eq!(params[0].len(), 16 * 3 * 2 * 2); // embed_dim * patch_dim
        assert_eq!(params[1].len(), 16); // embed_dim
    }
}
681}