1use crate::error::{NeuralError, Result};
12use crate::layers::{Layer, ParamLayer};
13use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
14use scirs2_core::numeric::{Float, NumAssign};
15use scirs2_core::random::{Distribution, Rng, RngExt, Uniform};
16use std::fmt::Debug;
17use std::sync::RwLock;
18
/// Splits an image into fixed-size, non-overlapping patches and linearly
/// projects each flattened patch into an embedding vector, as used by
/// Vision-Transformer-style models.
///
/// Input is a 4-D tensor `[batch, channels, height, width]`; output is
/// `[batch, num_patches, embed_dim]` (see the `Layer` impl).
pub struct PatchEmbedding<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    // Expected input image size as (height, width).
    image_size: (usize, usize),
    // Patch size as (height, width); must evenly divide `image_size`.
    patch_size: (usize, usize),
    // Number of input channels.
    in_channels: usize,
    // Dimensionality of the output embedding for each patch.
    embed_dim: usize,
    // Patch-grid height: image_size.0 / patch_size.0.
    num_patches_h: usize,
    // Patch-grid width: image_size.1 / patch_size.1.
    num_patches_w: usize,
    // Flattened length of one patch: in_channels * patch_size.0 * patch_size.1.
    patch_dim: usize,

    // Projection weight, shape [embed_dim, patch_dim].
    weight: Array<F, IxDyn>,
    // Projection bias, shape [embed_dim]; only applied when `use_bias` is true.
    bias: Array<F, IxDyn>,
    // Whether the bias term is added during projection.
    use_bias: bool,

    // Gradient buffers written by `backward` and read by `update`; RwLock
    // lets the `&self` backward pass store gradients via interior mutability.
    d_weight: RwLock<Array<F, IxDyn>>,
    d_bias: RwLock<Array<F, IxDyn>>,
    // Patches cached by `forward` for reuse in `backward`; `None` until the
    // first forward pass.
    cached_patches: RwLock<Option<Array<F, IxDyn>>>,
}
62
63impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Debug
64 for PatchEmbedding<F>
65{
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("PatchEmbedding")
68 .field("image_size", &self.image_size)
69 .field("patch_size", &self.patch_size)
70 .field("in_channels", &self.in_channels)
71 .field("embed_dim", &self.embed_dim)
72 .field("num_patches", &(self.num_patches_h * self.num_patches_w))
73 .field("use_bias", &self.use_bias)
74 .finish()
75 }
76}
77
78impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone
79 for PatchEmbedding<F>
80{
81 fn clone(&self) -> Self {
82 Self {
83 image_size: self.image_size,
84 patch_size: self.patch_size,
85 in_channels: self.in_channels,
86 embed_dim: self.embed_dim,
87 num_patches_h: self.num_patches_h,
88 num_patches_w: self.num_patches_w,
89 patch_dim: self.patch_dim,
90 weight: self.weight.clone(),
91 bias: self.bias.clone(),
92 use_bias: self.use_bias,
93 d_weight: RwLock::new(
94 self.d_weight
95 .read()
96 .expect("RwLock poisoned on d_weight read")
97 .clone(),
98 ),
99 d_bias: RwLock::new(
100 self.d_bias
101 .read()
102 .expect("RwLock poisoned on d_bias read")
103 .clone(),
104 ),
105 cached_patches: RwLock::new(
106 self.cached_patches
107 .read()
108 .expect("RwLock poisoned on cached_patches read")
109 .clone(),
110 ),
111 }
112 }
113}
114
115impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> PatchEmbedding<F> {
116 pub fn new<R: Rng>(
126 image_size: (usize, usize),
127 patch_size: (usize, usize),
128 in_channels: usize,
129 embed_dim: usize,
130 use_bias: bool,
131 rng: &mut R,
132 ) -> Result<Self> {
133 if image_size.0 == 0 || image_size.1 == 0 {
134 return Err(NeuralError::InvalidArchitecture(
135 "image_size dimensions must be non-zero".to_string(),
136 ));
137 }
138 if patch_size.0 == 0 || patch_size.1 == 0 {
139 return Err(NeuralError::InvalidArchitecture(
140 "patch_size dimensions must be non-zero".to_string(),
141 ));
142 }
143 if !image_size.0.is_multiple_of(patch_size.0) || !image_size.1.is_multiple_of(patch_size.1)
144 {
145 return Err(NeuralError::InvalidArchitecture(format!(
146 "image_size {:?} must be divisible by patch_size {:?}",
147 image_size, patch_size
148 )));
149 }
150 if in_channels == 0 {
151 return Err(NeuralError::InvalidArchitecture(
152 "in_channels must be non-zero".to_string(),
153 ));
154 }
155 if embed_dim == 0 {
156 return Err(NeuralError::InvalidArchitecture(
157 "embed_dim must be non-zero".to_string(),
158 ));
159 }
160
161 let num_patches_h = image_size.0 / patch_size.0;
162 let num_patches_w = image_size.1 / patch_size.1;
163 let patch_dim = in_channels * patch_size.0 * patch_size.1;
164
165 let fan_in = patch_dim as f64;
168 let fan_out = embed_dim as f64;
169 let bound = f64::sqrt(6.0 / (fan_in + fan_out));
170
171 let uniform = Uniform::new(-bound, bound).map_err(|e| {
172 NeuralError::InvalidArchitecture(format!("Failed to create uniform distribution: {e}"))
173 })?;
174
175 let weight_vec: Vec<F> = (0..(embed_dim * patch_dim))
177 .map(|_| {
178 F::from(uniform.sample(rng))
179 .ok_or_else(|| {
180 NeuralError::InvalidArchitecture(
181 "Failed to convert random value to float type".to_string(),
182 )
183 })
184 .unwrap_or(F::zero())
185 })
186 .collect();
187
188 let weight =
189 Array::from_shape_vec(IxDyn(&[embed_dim, patch_dim]), weight_vec).map_err(|e| {
190 NeuralError::InvalidArchitecture(format!("Failed to construct weight array: {e}"))
191 })?;
192
193 let bias = Array::zeros(IxDyn(&[embed_dim]));
195
196 let d_weight = RwLock::new(Array::zeros(IxDyn(&[embed_dim, patch_dim])));
197 let d_bias = RwLock::new(Array::zeros(IxDyn(&[embed_dim])));
198
199 Ok(Self {
200 image_size,
201 patch_size,
202 in_channels,
203 embed_dim,
204 num_patches_h,
205 num_patches_w,
206 patch_dim,
207 weight,
208 bias,
209 use_bias,
210 d_weight,
211 d_bias,
212 cached_patches: RwLock::new(None),
213 })
214 }
215
216 pub fn num_patches(&self) -> usize {
218 self.num_patches_h * self.num_patches_w
219 }
220
221 pub fn patch_dim(&self) -> usize {
223 self.patch_dim
224 }
225
226 fn extract_patches(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
230 let shape = input.shape();
231 let batch = shape[0];
232 let num_patches = self.num_patches_h * self.num_patches_w;
233
234 let mut patches = Array::zeros(IxDyn(&[batch, num_patches, self.patch_dim]));
235
236 for b in 0..batch {
237 for ph in 0..self.num_patches_h {
238 for pw in 0..self.num_patches_w {
239 let patch_idx = ph * self.num_patches_w + pw;
240 let h_start = ph * self.patch_size.0;
242 let w_start = pw * self.patch_size.1;
243 let mut flat_idx = 0usize;
244 for c in 0..self.in_channels {
245 for dy in 0..self.patch_size.0 {
246 for dx in 0..self.patch_size.1 {
247 patches[[b, patch_idx, flat_idx]] =
248 input[[b, c, h_start + dy, w_start + dx]];
249 flat_idx += 1;
250 }
251 }
252 }
253 }
254 }
255 }
256
257 Ok(patches)
258 }
259
260 fn linear_project(&self, patches: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
265 let batch = patches.shape()[0];
266 let num_patches = patches.shape()[1];
267 let mut output = Array::zeros(IxDyn(&[batch, num_patches, self.embed_dim]));
268
269 for b in 0..batch {
270 for p in 0..num_patches {
271 for e in 0..self.embed_dim {
272 let mut acc = F::zero();
273 for k in 0..self.patch_dim {
274 acc += patches[[b, p, k]] * self.weight[[e, k]];
275 }
276 if self.use_bias {
277 acc += self.bias[e];
278 }
279 output[[b, p, e]] = acc;
280 }
281 }
282 }
283
284 Ok(output)
285 }
286
287 fn validate_input_shape(&self, input: &Array<F, IxDyn>) -> Result<()> {
289 let shape = input.shape();
290 if shape.len() != 4 {
291 return Err(NeuralError::InferenceError(format!(
292 "PatchEmbedding expects 4-D input [batch, channels, height, width], got {:?}",
293 shape
294 )));
295 }
296 if shape[1] != self.in_channels {
297 return Err(NeuralError::InferenceError(format!(
298 "PatchEmbedding: expected {} input channels, got {}",
299 self.in_channels, shape[1]
300 )));
301 }
302 if shape[2] != self.image_size.0 || shape[3] != self.image_size.1 {
303 return Err(NeuralError::InferenceError(format!(
304 "PatchEmbedding: expected image size {:?}, got ({}, {})",
305 self.image_size, shape[2], shape[3]
306 )));
307 }
308 Ok(())
309 }
310}
311
312impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
313 for PatchEmbedding<F>
314{
315 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
316 self.validate_input_shape(input)?;
317
318 let patches = self.extract_patches(input)?;
320
321 {
323 let mut cache = self
324 .cached_patches
325 .write()
326 .expect("RwLock poisoned on cached_patches write");
327 *cache = Some(patches.clone());
328 }
329
330 self.linear_project(&patches)
332 }
333
334 fn backward(
335 &self,
336 _input: &Array<F, IxDyn>,
337 grad_output: &Array<F, IxDyn>,
338 ) -> Result<Array<F, IxDyn>> {
339 let go_shape = grad_output.shape();
341 if go_shape.len() != 3 {
342 return Err(NeuralError::InferenceError(format!(
343 "PatchEmbedding backward: grad_output must be 3-D [batch, num_patches, embed_dim], got {:?}",
344 go_shape
345 )));
346 }
347 let batch = go_shape[0];
348 let num_patches = go_shape[1];
349
350 let patches = {
351 let cache = self
352 .cached_patches
353 .read()
354 .expect("RwLock poisoned on cached_patches read");
355 cache.clone().ok_or_else(|| {
356 NeuralError::InferenceError(
357 "PatchEmbedding backward called before forward — no cached patches".to_string(),
358 )
359 })?
360 };
361
362 let mut d_weight = Array::zeros(IxDyn(&[self.embed_dim, self.patch_dim]));
364 for b in 0..batch {
365 for p in 0..num_patches {
366 for e in 0..self.embed_dim {
367 let go = grad_output[[b, p, e]];
368 for k in 0..self.patch_dim {
369 d_weight[[e, k]] += go * patches[[b, p, k]];
370 }
371 }
372 }
373 }
374
375 let mut d_bias = Array::zeros(IxDyn(&[self.embed_dim]));
377 if self.use_bias {
378 for b in 0..batch {
379 for p in 0..num_patches {
380 for e in 0..self.embed_dim {
381 d_bias[e] += grad_output[[b, p, e]];
382 }
383 }
384 }
385 }
386
387 {
389 let mut dw = self
390 .d_weight
391 .write()
392 .expect("RwLock poisoned on d_weight write");
393 *dw = d_weight;
394 }
395 {
396 let mut db = self
397 .d_bias
398 .write()
399 .expect("RwLock poisoned on d_bias write");
400 *db = d_bias;
401 }
402
403 let mut d_patches = Array::zeros(IxDyn(&[batch, num_patches, self.patch_dim]));
405 for b in 0..batch {
406 for p in 0..num_patches {
407 for k in 0..self.patch_dim {
408 let mut acc = F::zero();
409 for e in 0..self.embed_dim {
410 acc += grad_output[[b, p, e]] * self.weight[[e, k]];
411 }
412 d_patches[[b, p, k]] = acc;
413 }
414 }
415 }
416
417 let mut d_input = Array::zeros(IxDyn(&[
419 batch,
420 self.in_channels,
421 self.image_size.0,
422 self.image_size.1,
423 ]));
424 for b in 0..batch {
425 for ph in 0..self.num_patches_h {
426 for pw in 0..self.num_patches_w {
427 let patch_idx = ph * self.num_patches_w + pw;
428 let h_start = ph * self.patch_size.0;
429 let w_start = pw * self.patch_size.1;
430 let mut flat_idx = 0usize;
431 for c in 0..self.in_channels {
432 for dy in 0..self.patch_size.0 {
433 for dx in 0..self.patch_size.1 {
434 d_input[[b, c, h_start + dy, w_start + dx]] +=
435 d_patches[[b, patch_idx, flat_idx]];
436 flat_idx += 1;
437 }
438 }
439 }
440 }
441 }
442 }
443
444 Ok(d_input)
445 }
446
447 fn update(&mut self, learning_rate: F) -> Result<()> {
448 let d_weight = {
449 self.d_weight
450 .read()
451 .expect("RwLock poisoned on d_weight read")
452 .clone()
453 };
454 let d_bias = {
455 self.d_bias
456 .read()
457 .expect("RwLock poisoned on d_bias read")
458 .clone()
459 };
460
461 for e in 0..self.embed_dim {
463 for k in 0..self.patch_dim {
464 self.weight[[e, k]] -= learning_rate * d_weight[[e, k]];
465 }
466 }
467
468 if self.use_bias {
470 for e in 0..self.embed_dim {
471 self.bias[e] -= learning_rate * d_bias[e];
472 }
473 }
474
475 Ok(())
476 }
477
478 fn as_any(&self) -> &dyn std::any::Any {
479 self
480 }
481
482 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
483 self
484 }
485
486 fn layer_type(&self) -> &str {
487 "PatchEmbedding"
488 }
489
490 fn parameter_count(&self) -> usize {
491 let weight_params = self.embed_dim * self.patch_dim;
492 let bias_params = if self.use_bias { self.embed_dim } else { 0 };
493 weight_params + bias_params
494 }
495
496 fn layer_description(&self) -> String {
497 format!(
498 "type:PatchEmbedding, image_size:{:?}, patch_size:{:?}, in_channels:{}, embed_dim:{}, num_patches:{}, params:{}",
499 self.image_size,
500 self.patch_size,
501 self.in_channels,
502 self.embed_dim,
503 self.num_patches(),
504 self.parameter_count()
505 )
506 }
507
508 fn params(&self) -> Vec<Array<F, IxDyn>> {
509 if self.use_bias {
510 vec![self.weight.clone(), self.bias.clone()]
511 } else {
512 vec![self.weight.clone()]
513 }
514 }
515
516 fn set_params(&mut self, params: &[Array<F, IxDyn>]) -> Result<()> {
517 if params.is_empty() {
518 return Err(NeuralError::InvalidArchitecture(
519 "PatchEmbedding set_params: expected at least 1 parameter (weight)".to_string(),
520 ));
521 }
522 self.weight = params[0].clone();
523 if self.use_bias && params.len() >= 2 {
524 self.bias = params[1].clone();
525 }
526 Ok(())
527 }
528
529 fn inputshape(&self) -> Option<Vec<usize>> {
530 Some(vec![self.in_channels, self.image_size.0, self.image_size.1])
531 }
532
533 fn outputshape(&self) -> Option<Vec<usize>> {
534 Some(vec![self.num_patches(), self.embed_dim])
535 }
536}
537
538impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ParamLayer<F>
539 for PatchEmbedding<F>
540{
541 fn get_parameters(&self) -> Vec<Array<F, IxDyn>> {
542 self.params()
543 }
544
545 fn get_gradients(&self) -> Vec<Array<F, IxDyn>> {
546 let dw = self
547 .d_weight
548 .read()
549 .expect("RwLock poisoned on d_weight read")
550 .clone();
551 if self.use_bias {
552 let db = self
553 .d_bias
554 .read()
555 .expect("RwLock poisoned on d_bias read")
556 .clone();
557 vec![dw, db]
558 } else {
559 vec![dw]
560 }
561 }
562
563 fn set_parameters(&mut self, params: Vec<Array<F, IxDyn>>) -> Result<()> {
564 self.set_params(¶ms)
565 }
566}
567
// SAFETY: all fields are plain owned data or arrays guarded by `RwLock`, and
// the generic bounds already require `F: Send + Sync`.
// NOTE(review): under those bounds the compiler should derive `Send`/`Sync`
// automatically, so these manual unsafe impls appear redundant — confirm and
// consider removing them rather than maintaining unsafe code.
unsafe impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Send for PatchEmbedding<F> {}
unsafe impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Sync for PatchEmbedding<F> {}
571
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    /// Builds an `f64` layer with bias enabled and a fixed all-zero RNG seed.
    fn make_embed(
        image_size: (usize, usize),
        patch_size: (usize, usize),
        in_channels: usize,
        embed_dim: usize,
    ) -> PatchEmbedding<f64> {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let layer = PatchEmbedding::new(
            image_size,
            patch_size,
            in_channels,
            embed_dim,
            true,
            &mut rng,
        );
        layer.expect("Failed to construct PatchEmbedding")
    }

    #[test]
    fn test_patch_embedding_output_shape() {
        let embed = make_embed((8, 8), (2, 2), 3, 32);
        assert_eq!(embed.num_patches(), 16);
        assert_eq!(embed.patch_dim(), 3 * 2 * 2);

        let input = Array::zeros(IxDyn(&[2, 3, 8, 8]));
        let output = embed.forward(&input).expect("Forward pass failed");
        assert_eq!(output.shape(), &[2, 16, 32]);
    }

    #[test]
    fn test_patch_embedding_parameter_count() {
        let embed = make_embed((16, 16), (4, 4), 3, 64);
        // 64 output dims x (3 * 4 * 4 = 48) patch dims, plus 64 bias terms.
        assert_eq!(embed.parameter_count(), 64 * 48 + 64);
    }

    #[test]
    fn test_patch_embedding_backward_shape() {
        let embed = make_embed((8, 8), (2, 2), 3, 32);
        let input = Array::zeros(IxDyn(&[2, 3, 8, 8]));
        let output = embed.forward(&input).expect("Forward failed");

        let grad_out = Array::ones(output.raw_dim());
        let grad_in = embed
            .backward(&input, &grad_out)
            .expect("Backward pass failed");
        assert_eq!(grad_in.shape(), input.shape());
    }

    #[test]
    fn test_patch_embedding_invalid_size() {
        // Height 7 is not divisible by patch height 4, so construction fails.
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let result = PatchEmbedding::<f64>::new((7, 8), (4, 4), 3, 32, true, &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_patch_embedding_update() {
        let mut embed = make_embed((8, 8), (2, 2), 1, 16);
        let input = Array::zeros(IxDyn(&[1, 1, 8, 8]));

        let output = embed.forward(&input).expect("Forward failed");
        let grad_out = Array::ones(output.raw_dim());
        embed.backward(&input, &grad_out).expect("Backward failed");
        embed.update(0.01f64).expect("Update failed");
    }

    #[test]
    fn test_patch_embedding_round_trip_params() {
        let mut embed = make_embed((8, 8), (2, 2), 3, 16);
        let input = Array::ones(IxDyn(&[1, 3, 8, 8]));

        let out_before = embed
            .forward(&input)
            .expect("Forward before round-trip failed");

        // Writing the parameters back unchanged must not alter the output.
        let params = embed.get_parameters();
        embed
            .set_parameters(params.clone())
            .expect("set_parameters failed");

        let out_after = embed
            .forward(&input)
            .expect("Forward after round-trip failed");

        assert_eq!(out_before.shape(), out_after.shape());
        for (a, b) in out_before.iter().zip(out_after.iter()) {
            assert!((a - b).abs() < 1e-12, "round-trip mismatch: {a} vs {b}");
        }

        assert_eq!(params.len(), 2);
        assert_eq!(params[0].len(), 16 * 3 * 2 * 2);
        assert_eq!(params[1].len(), 16);
    }
}