//! Reusable CNN / Vision-Transformer building blocks for `yscv_model` (blocks.rs).
1use yscv_kernels::{LayerNormLastDimParams, layer_norm_last_dim, matmul_2d};
2use yscv_tensor::Tensor;
3
4use crate::{ModelError, SequentialModel, TransformerEncoderBlock};
5
/// Adds the convolutional path of a ResNet-style residual block to a
/// SequentialModel (inference-mode).
///
/// Structure appended: Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN.
/// NOTE(review): the skip-connection add itself is NOT appended here —
/// SequentialModel appears to be a linear chain of layers, so the identity
/// shortcut must be applied by the caller/runtime; confirm intended usage.
/// Input/output channels must match (`channels`). Operates in NHWC.
pub fn add_residual_block(
    model: &mut SequentialModel,
    channels: usize,
    epsilon: f32,
) -> Result<(), ModelError> {
    // 3x3 kernels throughout, stride 1. The `_zero` suffix presumably means
    // zero-initialized weights (matches the zero-filled tensors used elsewhere
    // in this file) — weights are expected to be loaded afterwards.
    let kh = 3;
    let kw = 3;

    model.add_conv2d_zero(channels, channels, kh, kw, 1, 1, true)?;
    model.add_batch_norm2d_identity(channels, epsilon)?;
    model.add_relu();
    model.add_conv2d_zero(channels, channels, kh, kw, 1, 1, true)?;
    model.add_batch_norm2d_identity(channels, epsilon)?;

    Ok(())
}
26
27/// Adds a MobileNetV2-style inverted bottleneck block to a SequentialModel.
28///
29/// Structure: Conv1x1 (expand) -> BN -> ReLU -> DepthwiseConv3x3 -> BN -> ReLU -> Conv1x1 (project) -> BN.
30/// Since we don't have depthwise as a layer yet, we approximate with a grouped conv emulation.
31/// For now this builds a standard bottleneck: 1x1 expand -> 3x3 conv -> 1x1 project.
32pub fn add_bottleneck_block(
33    model: &mut SequentialModel,
34    in_channels: usize,
35    expand_channels: usize,
36    out_channels: usize,
37    stride: usize,
38    epsilon: f32,
39) -> Result<(), ModelError> {
40    // 1x1 pointwise expansion
41    model.add_conv2d_zero(in_channels, expand_channels, 1, 1, 1, 1, false)?;
42    model.add_batch_norm2d_identity(expand_channels, epsilon)?;
43    model.add_relu();
44
45    // 3x3 spatial convolution
46    model.add_conv2d_zero(
47        expand_channels,
48        expand_channels,
49        3,
50        3,
51        stride,
52        stride,
53        false,
54    )?;
55    model.add_batch_norm2d_identity(expand_channels, epsilon)?;
56    model.add_relu();
57
58    // 1x1 pointwise projection
59    model.add_conv2d_zero(expand_channels, out_channels, 1, 1, 1, 1, false)?;
60    model.add_batch_norm2d_identity(out_channels, epsilon)?;
61
62    Ok(())
63}
64
65/// Builds a simple CNN classifier architecture for NHWC input.
66///
67/// Architecture: [Conv->BN->ReLU->MaxPool] x stages -> GlobalAvgPool -> Flatten -> Linear.
68/// This is a convenient builder for common CV classification tasks.
69pub fn build_simple_cnn_classifier(
70    model: &mut SequentialModel,
71    graph: &mut yscv_autograd::Graph,
72    input_channels: usize,
73    num_classes: usize,
74    stage_channels: &[usize],
75    epsilon: f32,
76) -> Result<(), ModelError> {
77    let mut ch = input_channels;
78    for &out_ch in stage_channels {
79        model.add_conv2d_zero(ch, out_ch, 3, 3, 1, 1, true)?;
80        model.add_batch_norm2d_identity(out_ch, epsilon)?;
81        model.add_relu();
82        model.add_max_pool2d(2, 2, 2, 2)?;
83        ch = out_ch;
84    }
85    model.add_global_avg_pool2d();
86    model.add_flatten();
87
88    let weight = Tensor::from_vec(vec![ch, num_classes], vec![0.0; ch * num_classes])?;
89    let bias = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;
90    model.add_linear(graph, ch, num_classes, weight, bias)?;
91
92    Ok(())
93}
94
/// Squeeze-and-Excitation block (inference-mode).
///
/// Channel attention: GlobalAvgPool -> FC(reduce) -> ReLU -> FC(expand) -> Sigmoid -> scale.
/// `reduction_ratio` controls bottleneck (typically 4 or 16).
pub struct SqueezeExciteBlock {
    /// Reduction FC weight: `[channels, channels / reduction]`.
    pub fc_reduce_w: Tensor,
    /// Reduction FC bias: `[channels / reduction]`.
    pub fc_reduce_b: Tensor,
    /// Expansion FC weight: `[channels / reduction, channels]`.
    pub fc_expand_w: Tensor,
    /// Expansion FC bias: `[channels]`.
    pub fc_expand_b: Tensor,
    /// Number of input/output channels being re-weighted.
    pub channels: usize,
    /// Bottleneck width: `(channels / reduction_ratio).max(1)`.
    pub reduced: usize,
}
107
108impl SqueezeExciteBlock {
109    pub fn new(channels: usize, reduction_ratio: usize) -> Result<Self, ModelError> {
110        let reduced = (channels / reduction_ratio).max(1);
111        Ok(Self {
112            fc_reduce_w: Tensor::from_vec(vec![channels, reduced], vec![0.0; channels * reduced])?,
113            fc_reduce_b: Tensor::from_vec(vec![reduced], vec![0.0; reduced])?,
114            fc_expand_w: Tensor::from_vec(vec![reduced, channels], vec![0.0; reduced * channels])?,
115            fc_expand_b: Tensor::from_vec(vec![channels], vec![0.0; channels])?,
116            channels,
117            reduced,
118        })
119    }
120
121    /// Forward: input `[N,H,W,C]` -> scaled `[N,H,W,C]`.
122    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
123        let shape = input.shape();
124        let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
125        let data = input.data();
126
127        // Global average pool -> [N, C]
128        let hw = (h * w) as f32;
129        let mut pooled = vec![0.0f32; n * c];
130        for b in 0..n {
131            for ch in 0..c {
132                let mut sum = 0.0f32;
133                for y in 0..h {
134                    for x in 0..w {
135                        sum += data[((b * h + y) * w + x) * c + ch];
136                    }
137                }
138                pooled[b * c + ch] = sum / hw;
139            }
140        }
141        let pooled_t = Tensor::from_vec(vec![n, c], pooled)?;
142
143        // FC reduce + ReLU
144        let reduced = yscv_kernels::matmul_2d(&pooled_t, &self.fc_reduce_w)?;
145        let reduced = reduced.add(&self.fc_reduce_b.unsqueeze(0)?)?;
146        let reduced_data: Vec<f32> = reduced.data().iter().map(|&v| v.max(0.0)).collect();
147        let reduced = Tensor::from_vec(vec![n, self.reduced], reduced_data)?;
148
149        // FC expand + Sigmoid
150        let expanded = yscv_kernels::matmul_2d(&reduced, &self.fc_expand_w)?;
151        let expanded = expanded.add(&self.fc_expand_b.unsqueeze(0)?)?;
152        let scale_data: Vec<f32> = expanded
153            .data()
154            .iter()
155            .map(|&v| 1.0 / (1.0 + (-v).exp()))
156            .collect();
157
158        // Scale input channels
159        let mut out = Vec::with_capacity(n * h * w * c);
160        for b in 0..n {
161            for y in 0..h {
162                for x in 0..w {
163                    for ch in 0..c {
164                        out.push(data[((b * h + y) * w + x) * c + ch] * scale_data[b * c + ch]);
165                    }
166                }
167            }
168        }
169        Tensor::from_vec(shape.to_vec(), out).map_err(Into::into)
170    }
171}
172
173/// MBConv block (EfficientNet / MobileNetV2 inverted residual, inference-mode).
174///
175/// Structure: expand 1x1 -> BN -> activation -> depthwise 3x3 -> BN -> activation -> SE -> project 1x1 -> BN.
176/// Supports optional skip connection when stride=1 and in_channels=out_channels.
177pub struct MbConvBlock {
178    pub expand_conv: Option<crate::Conv2dLayer>,
179    pub expand_bn: Option<crate::BatchNorm2dLayer>,
180    pub depthwise_w: Tensor, // [kh, kw, expanded_ch, 1]
181    pub depthwise_bn: crate::BatchNorm2dLayer,
182    pub se: Option<SqueezeExciteBlock>,
183    pub project_conv: crate::Conv2dLayer,
184    pub project_bn: crate::BatchNorm2dLayer,
185    pub use_residual: bool,
186    pub expanded_ch: usize,
187}
188
189impl MbConvBlock {
190    #[allow(clippy::too_many_arguments)]
191    pub fn new(
192        in_channels: usize,
193        out_channels: usize,
194        expand_ratio: usize,
195        kernel_size: usize,
196        stride: usize,
197        se_ratio: Option<usize>,
198        epsilon: f32,
199    ) -> Result<Self, ModelError> {
200        let expanded_ch = in_channels * expand_ratio;
201        let use_residual = stride == 1 && in_channels == out_channels;
202
203        let (expand_conv, expand_bn) = if expand_ratio != 1 {
204            let w = Tensor::from_vec(
205                vec![1, 1, in_channels, expanded_ch],
206                vec![0.0; in_channels * expanded_ch],
207            )?;
208            let b = Tensor::from_vec(vec![expanded_ch], vec![0.0; expanded_ch])?;
209            (
210                Some(crate::Conv2dLayer::new(
211                    in_channels,
212                    expanded_ch,
213                    1,
214                    1,
215                    1,
216                    1,
217                    w,
218                    Some(b),
219                )?),
220                Some(crate::BatchNorm2dLayer::identity_init(
221                    expanded_ch,
222                    epsilon,
223                )?),
224            )
225        } else {
226            (None, None)
227        };
228
229        let depthwise_w = Tensor::from_vec(
230            vec![kernel_size, kernel_size, expanded_ch, 1],
231            vec![0.0; kernel_size * kernel_size * expanded_ch],
232        )?;
233        let depthwise_bn = crate::BatchNorm2dLayer::identity_init(expanded_ch, epsilon)?;
234
235        let se = se_ratio
236            .map(|r| SqueezeExciteBlock::new(expanded_ch, r))
237            .transpose()?;
238
239        let proj_w = Tensor::from_vec(
240            vec![1, 1, expanded_ch, out_channels],
241            vec![0.0; expanded_ch * out_channels],
242        )?;
243        let proj_b = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
244        let project_conv =
245            crate::Conv2dLayer::new(expanded_ch, out_channels, 1, 1, 1, 1, proj_w, Some(proj_b))?;
246        let project_bn = crate::BatchNorm2dLayer::identity_init(out_channels, epsilon)?;
247
248        Ok(Self {
249            expand_conv,
250            expand_bn,
251            depthwise_w,
252            depthwise_bn,
253            se,
254            project_conv,
255            project_bn,
256            use_residual,
257            expanded_ch,
258        })
259    }
260
261    /// Forward inference: input `[N,H,W,C_in]` -> `[N,H',W',C_out]`.
262    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
263        let mut x = input.clone();
264
265        // Expand phase
266        if let (Some(conv), Some(bn)) = (&self.expand_conv, &self.expand_bn) {
267            x = conv.forward_inference(&x)?;
268            x = bn.forward_inference(&x)?;
269            let data: Vec<f32> = x.data().iter().map(|&v| v.clamp(0.0, 6.0)).collect();
270            x = Tensor::from_vec(x.shape().to_vec(), data)?;
271        }
272
273        // Depthwise conv (using kernel directly)
274        x = yscv_kernels::depthwise_conv2d_nhwc(&x, &self.depthwise_w, None, 1, 1)?;
275        x = self.depthwise_bn.forward_inference(&x)?;
276        let data: Vec<f32> = x.data().iter().map(|&v| v.clamp(0.0, 6.0)).collect();
277        x = Tensor::from_vec(x.shape().to_vec(), data)?;
278
279        // SE
280        if let Some(se) = &self.se {
281            x = se.forward(&x)?;
282        }
283
284        // Project
285        x = self.project_conv.forward_inference(&x)?;
286        x = self.project_bn.forward_inference(&x)?;
287
288        // Residual skip
289        if self.use_residual {
290            x = x.add(input)?;
291        }
292
293        Ok(x)
294    }
295}
296
297/// Builds a ResNet-like feature extractor (no final classifier).
298///
299/// Architecture: initial Conv7x7->BN->ReLU->MaxPool, then residual stages.
300pub fn build_resnet_feature_extractor(
301    model: &mut SequentialModel,
302    input_channels: usize,
303    stage_channels: &[usize],
304    blocks_per_stage: usize,
305    epsilon: f32,
306) -> Result<(), ModelError> {
307    let initial_ch = stage_channels.first().copied().unwrap_or(64);
308
309    // Stem: Conv7x7 stride 2 -> BN -> ReLU -> MaxPool
310    model.add_conv2d_zero(input_channels, initial_ch, 7, 7, 2, 2, true)?;
311    model.add_batch_norm2d_identity(initial_ch, epsilon)?;
312    model.add_relu();
313    model.add_max_pool2d(3, 3, 2, 2)?;
314
315    let mut ch = initial_ch;
316    for &stage_ch in stage_channels {
317        if stage_ch != ch {
318            // Channel transition: 1x1 conv to match dimensions
319            model.add_conv2d_zero(ch, stage_ch, 1, 1, 1, 1, false)?;
320            model.add_batch_norm2d_identity(stage_ch, epsilon)?;
321            model.add_relu();
322        }
323        for _ in 0..blocks_per_stage {
324            add_residual_block(model, stage_ch, epsilon)?;
325        }
326        ch = stage_ch;
327    }
328
329    model.add_global_avg_pool2d();
330    model.add_flatten();
331
332    Ok(())
333}
334
/// UNet encoder stage (inference-mode, NHWC).
///
/// Structure: Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU.
/// (Downsampling/pooling between stages is handled by the caller.)
pub struct UNetEncoderStage {
    /// First 3x3 convolution (`in_ch` -> `out_ch`).
    conv1: crate::Conv2dLayer,
    /// Batch norm after `conv1`.
    bn1: crate::BatchNorm2dLayer,
    /// Second 3x3 convolution (`out_ch` -> `out_ch`).
    conv2: crate::Conv2dLayer,
    /// Batch norm after `conv2`.
    bn2: crate::BatchNorm2dLayer,
}
345
346impl UNetEncoderStage {
347    pub fn new(in_ch: usize, out_ch: usize, epsilon: f32) -> Result<Self, ModelError> {
348        let w1 = Tensor::from_vec(vec![3, 3, in_ch, out_ch], vec![0.0; 9 * in_ch * out_ch])?;
349        let b1 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
350        let w2 = Tensor::from_vec(vec![3, 3, out_ch, out_ch], vec![0.0; 9 * out_ch * out_ch])?;
351        let b2 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
352        Ok(Self {
353            conv1: crate::Conv2dLayer::new(in_ch, out_ch, 3, 3, 1, 1, w1, Some(b1))?,
354            bn1: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
355            conv2: crate::Conv2dLayer::new(out_ch, out_ch, 3, 3, 1, 1, w2, Some(b2))?,
356            bn2: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
357        })
358    }
359
360    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
361        let x = self.conv1.forward_inference(input)?;
362        let x = self.bn1.forward_inference(&x)?;
363        let x = relu_nhwc(&x)?;
364        let x = self.conv2.forward_inference(&x)?;
365        let x = self.bn2.forward_inference(&x)?;
366        relu_nhwc(&x)
367    }
368}
369
/// UNet decoder stage (inference-mode, NHWC).
///
/// Structure: nearest-neighbor 2x upsample -> cat(skip) -> Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU.
pub struct UNetDecoderStage {
    /// First 3x3 conv; consumes the upsample+skip concatenation (`in_ch + skip_ch` channels).
    conv1: crate::Conv2dLayer,
    /// Batch norm after `conv1`.
    bn1: crate::BatchNorm2dLayer,
    /// Second 3x3 conv (`out_ch` -> `out_ch`).
    conv2: crate::Conv2dLayer,
    /// Batch norm after `conv2`.
    bn2: crate::BatchNorm2dLayer,
}
379
380impl UNetDecoderStage {
381    pub fn new(
382        in_ch: usize,
383        skip_ch: usize,
384        out_ch: usize,
385        epsilon: f32,
386    ) -> Result<Self, ModelError> {
387        let cat_ch = in_ch + skip_ch;
388        let w1 = Tensor::from_vec(vec![3, 3, cat_ch, out_ch], vec![0.0; 9 * cat_ch * out_ch])?;
389        let b1 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
390        let w2 = Tensor::from_vec(vec![3, 3, out_ch, out_ch], vec![0.0; 9 * out_ch * out_ch])?;
391        let b2 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
392        Ok(Self {
393            conv1: crate::Conv2dLayer::new(cat_ch, out_ch, 3, 3, 1, 1, w1, Some(b1))?,
394            bn1: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
395            conv2: crate::Conv2dLayer::new(out_ch, out_ch, 3, 3, 1, 1, w2, Some(b2))?,
396            bn2: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
397        })
398    }
399
400    /// Forward: `upsampled` is the feature from the lower level, `skip` from the encoder.
401    pub fn forward(&self, upsampled: &Tensor, skip: &Tensor) -> Result<Tensor, ModelError> {
402        let up = upsample_nearest_2x_nhwc(upsampled)?;
403        let cat = cat_nhwc_channel(&up, skip)?;
404        let x = self.conv1.forward_inference(&cat)?;
405        let x = self.bn1.forward_inference(&x)?;
406        let x = relu_nhwc(&x)?;
407        let x = self.conv2.forward_inference(&x)?;
408        let x = self.bn2.forward_inference(&x)?;
409        relu_nhwc(&x)
410    }
411}
412
/// Feature Pyramid Network lateral + top-down pathway (inference-mode, NHWC).
///
/// Reduces each backbone level to `out_channels` via 1x1 conv, then top-down
/// merges with 2x nearest-neighbor upsample + element-wise add + 3x3 smoothing.
pub struct FpnNeck {
    /// One 1x1 conv per input level, projecting to the shared channel width.
    lateral_convs: Vec<crate::Conv2dLayer>,
    /// One 3x3 conv per level, applied after the top-down merge.
    smooth_convs: Vec<crate::Conv2dLayer>,
    /// Number of pyramid levels (= `in_channels.len()` at construction).
    num_levels: usize,
}
422
423impl FpnNeck {
424    pub fn new(in_channels: &[usize], out_channels: usize) -> Result<Self, ModelError> {
425        let mut lateral_convs = Vec::with_capacity(in_channels.len());
426        let mut smooth_convs = Vec::with_capacity(in_channels.len());
427        for &ch in in_channels {
428            let w = Tensor::from_vec(vec![1, 1, ch, out_channels], vec![0.0; ch * out_channels])?;
429            let b = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
430            lateral_convs.push(crate::Conv2dLayer::new(
431                ch,
432                out_channels,
433                1,
434                1,
435                1,
436                1,
437                w,
438                Some(b),
439            )?);
440
441            let w3 = Tensor::from_vec(
442                vec![3, 3, out_channels, out_channels],
443                vec![0.0; 9 * out_channels * out_channels],
444            )?;
445            let b3 = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
446            smooth_convs.push(crate::Conv2dLayer::new(
447                out_channels,
448                out_channels,
449                3,
450                3,
451                1,
452                1,
453                w3,
454                Some(b3),
455            )?);
456        }
457        Ok(Self {
458            lateral_convs,
459            smooth_convs,
460            num_levels: in_channels.len(),
461        })
462    }
463
464    /// Forward: `features` is a list of backbone feature maps from finest to coarsest.
465    /// Returns FPN outputs at the same spatial resolutions.
466    pub fn forward(&self, features: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
467        if features.len() != self.num_levels {
468            return Err(ModelError::InvalidInputShape {
469                expected_features: self.num_levels,
470                got: vec![features.len()],
471            });
472        }
473        let mut laterals: Vec<Tensor> = Vec::with_capacity(self.num_levels);
474        for (i, feat) in features.iter().enumerate() {
475            laterals.push(self.lateral_convs[i].forward_inference(feat)?);
476        }
477
478        for i in (0..self.num_levels - 1).rev() {
479            let up = upsample_nearest_2x_nhwc(&laterals[i + 1])?;
480            let shape_i = laterals[i].shape();
481            let shape_up = up.shape();
482            let min_h = shape_i[1].min(shape_up[1]);
483            let min_w = shape_i[2].min(shape_up[2]);
484            let cropped_lat = crop_nhwc(&laterals[i], min_h, min_w)?;
485            let cropped_up = crop_nhwc(&up, min_h, min_w)?;
486            laterals[i] = cropped_lat.add(&cropped_up)?;
487        }
488
489        let mut outputs = Vec::with_capacity(self.num_levels);
490        for (i, lat) in laterals.iter().enumerate() {
491            outputs.push(self.smooth_convs[i].forward_inference(lat)?);
492        }
493        Ok(outputs)
494    }
495}
496
/// Anchor-free detection head (FCOS-style, inference-mode, NHWC).
///
/// Per-pixel classification + centerness + bbox regression.
/// Operates on a single FPN level feature map.
pub struct AnchorFreeHead {
    /// Conv3x3 + BN tower feeding the classification (and centerness) outputs.
    cls_convs: Vec<(crate::Conv2dLayer, crate::BatchNorm2dLayer)>,
    /// Conv3x3 + BN tower feeding the box-regression output.
    reg_convs: Vec<(crate::Conv2dLayer, crate::BatchNorm2dLayer)>,
    /// 3x3 output conv -> `[N,H,W,num_classes]` logits.
    cls_out: crate::Conv2dLayer,
    /// 3x3 output conv -> `[N,H,W,4]` box offsets.
    reg_out: crate::Conv2dLayer,
    /// 3x3 output conv -> `[N,H,W,1]` centerness; fed from the cls tower (see `forward`).
    centerness_out: crate::Conv2dLayer,
}
508
509impl AnchorFreeHead {
510    pub fn new(
511        in_channels: usize,
512        num_classes: usize,
513        num_convs: usize,
514        epsilon: f32,
515    ) -> Result<Self, ModelError> {
516        let mut cls_convs = Vec::with_capacity(num_convs);
517        let mut reg_convs = Vec::with_capacity(num_convs);
518        let mut ch = in_channels;
519        for _ in 0..num_convs {
520            let wc =
521                Tensor::from_vec(vec![3, 3, ch, in_channels], vec![0.0; 9 * ch * in_channels])?;
522            let bc = Tensor::from_vec(vec![in_channels], vec![0.0; in_channels])?;
523            let bnc = crate::BatchNorm2dLayer::identity_init(in_channels, epsilon)?;
524            cls_convs.push((
525                crate::Conv2dLayer::new(ch, in_channels, 3, 3, 1, 1, wc, Some(bc))?,
526                bnc,
527            ));
528
529            let wr =
530                Tensor::from_vec(vec![3, 3, ch, in_channels], vec![0.0; 9 * ch * in_channels])?;
531            let br = Tensor::from_vec(vec![in_channels], vec![0.0; in_channels])?;
532            let bnr = crate::BatchNorm2dLayer::identity_init(in_channels, epsilon)?;
533            reg_convs.push((
534                crate::Conv2dLayer::new(ch, in_channels, 3, 3, 1, 1, wr, Some(br))?,
535                bnr,
536            ));
537            ch = in_channels;
538        }
539
540        let wco = Tensor::from_vec(
541            vec![3, 3, in_channels, num_classes],
542            vec![0.0; 9 * in_channels * num_classes],
543        )?;
544        let bco = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;
545        let cls_out =
546            crate::Conv2dLayer::new(in_channels, num_classes, 3, 3, 1, 1, wco, Some(bco))?;
547
548        let wro = Tensor::from_vec(vec![3, 3, in_channels, 4], vec![0.0; 9 * in_channels * 4])?;
549        let bro = Tensor::from_vec(vec![4], vec![0.0; 4])?;
550        let reg_out = crate::Conv2dLayer::new(in_channels, 4, 3, 3, 1, 1, wro, Some(bro))?;
551
552        let wcn = Tensor::from_vec(vec![3, 3, in_channels, 1], vec![0.0; 9 * in_channels])?;
553        let bcn = Tensor::from_vec(vec![1], vec![0.0; 1])?;
554        let centerness_out = crate::Conv2dLayer::new(in_channels, 1, 3, 3, 1, 1, wcn, Some(bcn))?;
555
556        Ok(Self {
557            cls_convs,
558            reg_convs,
559            cls_out,
560            reg_out,
561            centerness_out,
562        })
563    }
564
565    /// Forward on single feature map `[N, H, W, C]`.
566    /// Returns `(cls_logits [N,H,W,num_classes], bbox_pred [N,H,W,4], centerness [N,H,W,1])`.
567    pub fn forward(&self, input: &Tensor) -> Result<(Tensor, Tensor, Tensor), ModelError> {
568        let mut cls_feat = input.clone();
569        for (conv, bn) in &self.cls_convs {
570            cls_feat = conv.forward_inference(&cls_feat)?;
571            cls_feat = bn.forward_inference(&cls_feat)?;
572            cls_feat = relu_nhwc(&cls_feat)?;
573        }
574
575        let mut reg_feat = input.clone();
576        for (conv, bn) in &self.reg_convs {
577            reg_feat = conv.forward_inference(&reg_feat)?;
578            reg_feat = bn.forward_inference(&reg_feat)?;
579            reg_feat = relu_nhwc(&reg_feat)?;
580        }
581
582        let cls_logits = self.cls_out.forward_inference(&cls_feat)?;
583        let bbox_pred = self.reg_out.forward_inference(&reg_feat)?;
584        let centerness = self.centerness_out.forward_inference(&cls_feat)?;
585
586        Ok((cls_logits, bbox_pred, centerness))
587    }
588}
589
590fn relu_nhwc(t: &Tensor) -> Result<Tensor, ModelError> {
591    let data: Vec<f32> = t.data().iter().map(|&v| v.max(0.0)).collect();
592    Tensor::from_vec(t.shape().to_vec(), data).map_err(Into::into)
593}
594
595fn upsample_nearest_2x_nhwc(t: &Tensor) -> Result<Tensor, ModelError> {
596    let shape = t.shape();
597    let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
598    let new_h = h * 2;
599    let new_w = w * 2;
600    let data = t.data();
601    let mut out = vec![0.0f32; n * new_h * new_w * c];
602    for b in 0..n {
603        for y in 0..new_h {
604            for x in 0..new_w {
605                let sy = y / 2;
606                let sx = x / 2;
607                let src_off = ((b * h + sy) * w + sx) * c;
608                let dst_off = ((b * new_h + y) * new_w + x) * c;
609                out[dst_off..dst_off + c].copy_from_slice(&data[src_off..src_off + c]);
610            }
611        }
612    }
613    Tensor::from_vec(vec![n, new_h, new_w, c], out).map_err(Into::into)
614}
615
616fn cat_nhwc_channel(a: &Tensor, b: &Tensor) -> Result<Tensor, ModelError> {
617    let sa = a.shape();
618    let sb = b.shape();
619    let (n, h, w) = (sa[0], sa[1], sa[2]);
620    let ca = sa[3];
621    let cb = sb[3];
622    let da = a.data();
623    let db = b.data();
624    let mut out = vec![0.0f32; n * h * w * (ca + cb)];
625    for b_idx in 0..n {
626        for y in 0..h {
627            for x in 0..w {
628                let src_a = ((b_idx * h + y) * w + x) * ca;
629                let src_b = ((b_idx * h + y) * w + x) * cb;
630                let dst = ((b_idx * h + y) * w + x) * (ca + cb);
631                out[dst..dst + ca].copy_from_slice(&da[src_a..src_a + ca]);
632                out[dst + ca..dst + ca + cb].copy_from_slice(&db[src_b..src_b + cb]);
633            }
634        }
635    }
636    Tensor::from_vec(vec![n, h, w, ca + cb], out).map_err(Into::into)
637}
638
639fn crop_nhwc(t: &Tensor, target_h: usize, target_w: usize) -> Result<Tensor, ModelError> {
640    let shape = t.shape();
641    let (n, _h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
642    let data = t.data();
643    let mut out = vec![0.0f32; n * target_h * target_w * c];
644    for b in 0..n {
645        for y in 0..target_h {
646            let src_off = ((b * _h + y) * w) * c;
647            let dst_off = ((b * target_h + y) * target_w) * c;
648            for x in 0..target_w {
649                let so = src_off + x * c;
650                let do_ = dst_off + x * c;
651                out[do_..do_ + c].copy_from_slice(&data[so..so + c]);
652            }
653        }
654    }
655    Tensor::from_vec(vec![n, target_h, target_w, c], out).map_err(Into::into)
656}
657
658// ---------------------------------------------------------------------------
659// Vision Transformer (ViT) components
660// ---------------------------------------------------------------------------
661
/// Patch embedding layer for Vision Transformer.
///
/// Splits an NHWC image into non-overlapping patches and projects each patch
/// to an embedding vector via a linear projection (equivalent to Conv2d with
/// kernel_size=patch_size, stride=patch_size).
///
/// Input:  `[batch, H, W, C]` (NHWC)
/// Output: `[batch, num_patches, embed_dim]`
pub struct PatchEmbedding {
    /// Linear projection weight: `[patch_size * patch_size * in_channels, embed_dim]`
    pub projection_w: Tensor,
    /// Linear projection bias: `[embed_dim]`
    pub projection_b: Tensor,
    /// Side length the embedding was configured for (square images).
    pub image_size: usize,
    /// Side length of each square patch; must divide `image_size`.
    pub patch_size: usize,
    /// Channel count the projection weight was sized for.
    pub in_channels: usize,
    /// Output embedding width per patch.
    pub embed_dim: usize,
    /// `(image_size / patch_size)^2`, precomputed at construction.
    pub num_patches: usize,
}
681
682impl PatchEmbedding {
683    /// Creates a zero-initialized patch embedding.
684    pub fn new(
685        image_size: usize,
686        patch_size: usize,
687        in_channels: usize,
688        embed_dim: usize,
689    ) -> Result<Self, ModelError> {
690        if !image_size.is_multiple_of(patch_size) {
691            return Err(ModelError::InvalidParameterShape {
692                parameter: "image_size must be divisible by patch_size",
693                expected: vec![image_size, patch_size],
694                got: vec![image_size % patch_size],
695            });
696        }
697        let num_patches = (image_size / patch_size) * (image_size / patch_size);
698        let patch_dim = patch_size * patch_size * in_channels;
699        Ok(Self {
700            projection_w: Tensor::from_vec(
701                vec![patch_dim, embed_dim],
702                vec![0.0; patch_dim * embed_dim],
703            )?,
704            projection_b: Tensor::from_vec(vec![embed_dim], vec![0.0; embed_dim])?,
705            image_size,
706            patch_size,
707            in_channels,
708            embed_dim,
709            num_patches,
710        })
711    }
712
713    /// Forward: `[batch, H, W, C]` -> `[batch, num_patches, embed_dim]`.
714    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
715        let shape = input.shape();
716        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
717        let ps = self.patch_size;
718        let grid_h = h / ps;
719        let grid_w = w / ps;
720        let num_patches = grid_h * grid_w;
721        let patch_dim = ps * ps * c;
722        let data = input.data();
723
724        // Extract patches -> [batch, num_patches, patch_dim]
725        let mut patches = vec![0.0f32; batch * num_patches * patch_dim];
726        for b in 0..batch {
727            for gh in 0..grid_h {
728                for gw in 0..grid_w {
729                    let patch_idx = gh * grid_w + gw;
730                    let dst_base = (b * num_patches + patch_idx) * patch_dim;
731                    let mut offset = 0;
732                    for ph in 0..ps {
733                        for pw in 0..ps {
734                            let iy = gh * ps + ph;
735                            let ix = gw * ps + pw;
736                            let src = ((b * h + iy) * w + ix) * c;
737                            patches[dst_base + offset..dst_base + offset + c]
738                                .copy_from_slice(&data[src..src + c]);
739                            offset += c;
740                        }
741                    }
742                }
743            }
744        }
745
746        // Project: [batch * num_patches, patch_dim] @ [patch_dim, embed_dim] -> [batch * num_patches, embed_dim]
747        let patches_t = Tensor::from_vec(vec![batch * num_patches, patch_dim], patches)?;
748        let projected = matmul_2d(&patches_t, &self.projection_w)?;
749        let projected = projected.add(&self.projection_b.unsqueeze(0)?)?;
750
751        // Reshape to [batch, num_patches, embed_dim]
752        projected
753            .reshape(vec![batch, num_patches, self.embed_dim])
754            .map_err(Into::into)
755    }
756}
757
/// Vision Transformer (ViT) for image classification (inference-mode).
///
/// Architecture:
/// 1. Patch embedding
/// 2. Prepend learnable class token
/// 3. Add learnable position embeddings
/// 4. N transformer encoder blocks
/// 5. Layer norm on output
/// 6. Linear classification head on class token
///
/// Input:  `[batch, H, W, C]` (NHWC)
/// Output: `[batch, num_classes]`
pub struct VisionTransformer {
    /// Patch extraction + linear projection front-end.
    pub patch_embed: PatchEmbedding,
    /// Class token: `[1, embed_dim]`
    pub cls_token: Tensor,
    /// Position embeddings: `[1, num_patches + 1, embed_dim]`
    pub pos_embed: Tensor,
    /// Transformer encoder blocks
    pub encoder_blocks: Vec<TransformerEncoderBlock>,
    /// Final layer norm gamma: `[embed_dim]`
    pub ln_gamma: Tensor,
    /// Final layer norm beta: `[embed_dim]`
    pub ln_beta: Tensor,
    /// Classification head weight: `[embed_dim, num_classes]`
    pub head_w: Tensor,
    /// Classification head bias: `[num_classes]`
    pub head_b: Tensor,
    /// Token embedding width used throughout the encoder.
    pub embed_dim: usize,
    /// Number of output classes produced by the linear head.
    pub num_classes: usize,
}
789
790impl VisionTransformer {
791    /// Creates a zero-initialized Vision Transformer.
792    #[allow(clippy::too_many_arguments)]
793    pub fn new(
794        image_size: usize,
795        patch_size: usize,
796        in_channels: usize,
797        embed_dim: usize,
798        num_heads: usize,
799        num_layers: usize,
800        num_classes: usize,
801        mlp_ratio: f32,
802    ) -> Result<Self, ModelError> {
803        let patch_embed = PatchEmbedding::new(image_size, patch_size, in_channels, embed_dim)?;
804        let num_patches = patch_embed.num_patches;
805        let seq_len = num_patches + 1; // +1 for class token
806
807        let cls_token = Tensor::from_vec(vec![1, embed_dim], vec![0.0; embed_dim])?;
808        let pos_embed =
809            Tensor::from_vec(vec![1, seq_len, embed_dim], vec![0.0; seq_len * embed_dim])?;
810
811        let d_ff = (embed_dim as f32 * mlp_ratio) as usize;
812        let mut encoder_blocks = Vec::with_capacity(num_layers);
813        for _ in 0..num_layers {
814            encoder_blocks.push(TransformerEncoderBlock::new(embed_dim, num_heads, d_ff)?);
815        }
816
817        let ln_gamma = Tensor::from_vec(vec![embed_dim], vec![1.0; embed_dim])?;
818        let ln_beta = Tensor::from_vec(vec![embed_dim], vec![0.0; embed_dim])?;
819
820        let head_w = Tensor::from_vec(
821            vec![embed_dim, num_classes],
822            vec![0.0; embed_dim * num_classes],
823        )?;
824        let head_b = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;
825
826        Ok(Self {
827            patch_embed,
828            cls_token,
829            pos_embed,
830            encoder_blocks,
831            ln_gamma,
832            ln_beta,
833            head_w,
834            head_b,
835            embed_dim,
836            num_classes,
837        })
838    }
839
840    /// Forward inference: `[batch, H, W, C]` -> `[batch, num_classes]`.
841    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
842        let batch = input.shape()[0];
843
844        // 1. Patch embedding -> [batch, num_patches, embed_dim]
845        let patch_tokens = self.patch_embed.forward(input)?;
846        let num_patches = patch_tokens.shape()[1];
847
848        // 2. Prepend class token -> [batch, num_patches+1, embed_dim]
849        // Expand cls_token [1, embed_dim] -> [batch, 1, embed_dim]
850        let cls_expanded = self.cls_token.repeat(&[batch, 1])?; // [batch, embed_dim]
851        let cls_expanded = cls_expanded.reshape(vec![batch, 1, self.embed_dim])?;
852
853        // Concatenate along sequence dimension
854        let seq_len = num_patches + 1;
855        let patch_data = patch_tokens.data();
856        let cls_data = cls_expanded.data();
857        let mut combined = vec![0.0f32; batch * seq_len * self.embed_dim];
858        for b in 0..batch {
859            // Copy class token
860            let cls_src = b * self.embed_dim;
861            let dst_base = b * seq_len * self.embed_dim;
862            combined[dst_base..dst_base + self.embed_dim]
863                .copy_from_slice(&cls_data[cls_src..cls_src + self.embed_dim]);
864            // Copy patch tokens
865            let patch_src = b * num_patches * self.embed_dim;
866            let patch_dst = dst_base + self.embed_dim;
867            let patch_len = num_patches * self.embed_dim;
868            combined[patch_dst..patch_dst + patch_len]
869                .copy_from_slice(&patch_data[patch_src..patch_src + patch_len]);
870        }
871        let mut x = Tensor::from_vec(vec![batch, seq_len, self.embed_dim], combined)?;
872
873        // 3. Add position embeddings (broadcast over batch)
874        let pos = self.pos_embed.repeat(&[batch, 1, 1])?;
875        x = x.add(&pos)?;
876
877        // 4. Run through transformer encoder blocks
878        // TransformerEncoderBlock expects [seq_len, d_model], so process each batch item
879        let mut out_data = vec![0.0f32; batch * seq_len * self.embed_dim];
880        for b in 0..batch {
881            // Extract [seq_len, embed_dim] for this batch
882            let start = b * seq_len * self.embed_dim;
883            let end = start + seq_len * self.embed_dim;
884            let slice = &x.data()[start..end];
885            let mut seq = Tensor::from_vec(vec![seq_len, self.embed_dim], slice.to_vec())?;
886
887            for block in &self.encoder_blocks {
888                seq = block.forward(&seq)?;
889            }
890
891            let seq_data = seq.data();
892            out_data[start..end].copy_from_slice(seq_data);
893        }
894        let x = Tensor::from_vec(vec![batch, seq_len, self.embed_dim], out_data)?;
895
896        // 5. Layer norm on full output
897        let x_2d = x.reshape(vec![batch * seq_len, self.embed_dim])?;
898        let params = LayerNormLastDimParams {
899            gamma: &self.ln_gamma,
900            beta: &self.ln_beta,
901            epsilon: 1e-5,
902        };
903        let normed = layer_norm_last_dim(&x_2d, params)?;
904        let normed = normed.reshape(vec![batch, seq_len, self.embed_dim])?;
905
906        // 6. Extract class token (index 0 along seq dimension) -> [batch, embed_dim]
907        let normed_data = normed.data();
908        let mut cls_out = vec![0.0f32; batch * self.embed_dim];
909        for b in 0..batch {
910            let src = b * seq_len * self.embed_dim;
911            cls_out[b * self.embed_dim..(b + 1) * self.embed_dim]
912                .copy_from_slice(&normed_data[src..src + self.embed_dim]);
913        }
914        let cls_features = Tensor::from_vec(vec![batch, self.embed_dim], cls_out)?;
915
916        // 7. Classification head: linear projection
917        let logits = matmul_2d(&cls_features, &self.head_w)?;
918        let logits = logits.add(&self.head_b.unsqueeze(0)?)?;
919
920        Ok(logits)
921    }
922}