1use yscv_kernels::{LayerNormLastDimParams, layer_norm_last_dim, matmul_2d};
2use yscv_tensor::Tensor;
3
4use crate::{ModelError, SequentialModel, TransformerEncoderBlock};
5
6pub fn add_residual_block(
11 model: &mut SequentialModel,
12 channels: usize,
13 epsilon: f32,
14) -> Result<(), ModelError> {
15 let kh = 3;
16 let kw = 3;
17
18 model.add_conv2d_zero(channels, channels, kh, kw, 1, 1, true)?;
19 model.add_batch_norm2d_identity(channels, epsilon)?;
20 model.add_relu();
21 model.add_conv2d_zero(channels, channels, kh, kw, 1, 1, true)?;
22 model.add_batch_norm2d_identity(channels, epsilon)?;
23
24 Ok(())
25}
26
27pub fn add_bottleneck_block(
33 model: &mut SequentialModel,
34 in_channels: usize,
35 expand_channels: usize,
36 out_channels: usize,
37 stride: usize,
38 epsilon: f32,
39) -> Result<(), ModelError> {
40 model.add_conv2d_zero(in_channels, expand_channels, 1, 1, 1, 1, false)?;
42 model.add_batch_norm2d_identity(expand_channels, epsilon)?;
43 model.add_relu();
44
45 model.add_conv2d_zero(
47 expand_channels,
48 expand_channels,
49 3,
50 3,
51 stride,
52 stride,
53 false,
54 )?;
55 model.add_batch_norm2d_identity(expand_channels, epsilon)?;
56 model.add_relu();
57
58 model.add_conv2d_zero(expand_channels, out_channels, 1, 1, 1, 1, false)?;
60 model.add_batch_norm2d_identity(out_channels, epsilon)?;
61
62 Ok(())
63}
64
65pub fn build_simple_cnn_classifier(
70 model: &mut SequentialModel,
71 graph: &mut yscv_autograd::Graph,
72 input_channels: usize,
73 num_classes: usize,
74 stage_channels: &[usize],
75 epsilon: f32,
76) -> Result<(), ModelError> {
77 let mut ch = input_channels;
78 for &out_ch in stage_channels {
79 model.add_conv2d_zero(ch, out_ch, 3, 3, 1, 1, true)?;
80 model.add_batch_norm2d_identity(out_ch, epsilon)?;
81 model.add_relu();
82 model.add_max_pool2d(2, 2, 2, 2)?;
83 ch = out_ch;
84 }
85 model.add_global_avg_pool2d();
86 model.add_flatten();
87
88 let weight = Tensor::from_vec(vec![ch, num_classes], vec![0.0; ch * num_classes])?;
89 let bias = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;
90 model.add_linear(graph, ch, num_classes, weight, bias)?;
91
92 Ok(())
93}
94
95pub struct SqueezeExciteBlock {
100 pub fc_reduce_w: Tensor, pub fc_reduce_b: Tensor, pub fc_expand_w: Tensor, pub fc_expand_b: Tensor, pub channels: usize,
105 pub reduced: usize,
106}
107
108impl SqueezeExciteBlock {
109 pub fn new(channels: usize, reduction_ratio: usize) -> Result<Self, ModelError> {
110 let reduced = (channels / reduction_ratio).max(1);
111 Ok(Self {
112 fc_reduce_w: Tensor::from_vec(vec![channels, reduced], vec![0.0; channels * reduced])?,
113 fc_reduce_b: Tensor::from_vec(vec![reduced], vec![0.0; reduced])?,
114 fc_expand_w: Tensor::from_vec(vec![reduced, channels], vec![0.0; reduced * channels])?,
115 fc_expand_b: Tensor::from_vec(vec![channels], vec![0.0; channels])?,
116 channels,
117 reduced,
118 })
119 }
120
121 pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
123 let shape = input.shape();
124 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
125 let data = input.data();
126
127 let hw = (h * w) as f32;
129 let mut pooled = vec![0.0f32; n * c];
130 for b in 0..n {
131 for ch in 0..c {
132 let mut sum = 0.0f32;
133 for y in 0..h {
134 for x in 0..w {
135 sum += data[((b * h + y) * w + x) * c + ch];
136 }
137 }
138 pooled[b * c + ch] = sum / hw;
139 }
140 }
141 let pooled_t = Tensor::from_vec(vec![n, c], pooled)?;
142
143 let reduced = yscv_kernels::matmul_2d(&pooled_t, &self.fc_reduce_w)?;
145 let reduced = reduced.add(&self.fc_reduce_b.unsqueeze(0)?)?;
146 let reduced_data: Vec<f32> = reduced.data().iter().map(|&v| v.max(0.0)).collect();
147 let reduced = Tensor::from_vec(vec![n, self.reduced], reduced_data)?;
148
149 let expanded = yscv_kernels::matmul_2d(&reduced, &self.fc_expand_w)?;
151 let expanded = expanded.add(&self.fc_expand_b.unsqueeze(0)?)?;
152 let scale_data: Vec<f32> = expanded
153 .data()
154 .iter()
155 .map(|&v| 1.0 / (1.0 + (-v).exp()))
156 .collect();
157
158 let mut out = Vec::with_capacity(n * h * w * c);
160 for b in 0..n {
161 for y in 0..h {
162 for x in 0..w {
163 for ch in 0..c {
164 out.push(data[((b * h + y) * w + x) * c + ch] * scale_data[b * c + ch]);
165 }
166 }
167 }
168 }
169 Tensor::from_vec(shape.to_vec(), out).map_err(Into::into)
170 }
171}
172
173pub struct MbConvBlock {
178 pub expand_conv: Option<crate::Conv2dLayer>,
179 pub expand_bn: Option<crate::BatchNorm2dLayer>,
180 pub depthwise_w: Tensor, pub depthwise_bn: crate::BatchNorm2dLayer,
182 pub se: Option<SqueezeExciteBlock>,
183 pub project_conv: crate::Conv2dLayer,
184 pub project_bn: crate::BatchNorm2dLayer,
185 pub use_residual: bool,
186 pub expanded_ch: usize,
187}
188
189impl MbConvBlock {
190 #[allow(clippy::too_many_arguments)]
191 pub fn new(
192 in_channels: usize,
193 out_channels: usize,
194 expand_ratio: usize,
195 kernel_size: usize,
196 stride: usize,
197 se_ratio: Option<usize>,
198 epsilon: f32,
199 ) -> Result<Self, ModelError> {
200 let expanded_ch = in_channels * expand_ratio;
201 let use_residual = stride == 1 && in_channels == out_channels;
202
203 let (expand_conv, expand_bn) = if expand_ratio != 1 {
204 let w = Tensor::from_vec(
205 vec![1, 1, in_channels, expanded_ch],
206 vec![0.0; in_channels * expanded_ch],
207 )?;
208 let b = Tensor::from_vec(vec![expanded_ch], vec![0.0; expanded_ch])?;
209 (
210 Some(crate::Conv2dLayer::new(
211 in_channels,
212 expanded_ch,
213 1,
214 1,
215 1,
216 1,
217 w,
218 Some(b),
219 )?),
220 Some(crate::BatchNorm2dLayer::identity_init(
221 expanded_ch,
222 epsilon,
223 )?),
224 )
225 } else {
226 (None, None)
227 };
228
229 let depthwise_w = Tensor::from_vec(
230 vec![kernel_size, kernel_size, expanded_ch, 1],
231 vec![0.0; kernel_size * kernel_size * expanded_ch],
232 )?;
233 let depthwise_bn = crate::BatchNorm2dLayer::identity_init(expanded_ch, epsilon)?;
234
235 let se = se_ratio
236 .map(|r| SqueezeExciteBlock::new(expanded_ch, r))
237 .transpose()?;
238
239 let proj_w = Tensor::from_vec(
240 vec![1, 1, expanded_ch, out_channels],
241 vec![0.0; expanded_ch * out_channels],
242 )?;
243 let proj_b = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
244 let project_conv =
245 crate::Conv2dLayer::new(expanded_ch, out_channels, 1, 1, 1, 1, proj_w, Some(proj_b))?;
246 let project_bn = crate::BatchNorm2dLayer::identity_init(out_channels, epsilon)?;
247
248 Ok(Self {
249 expand_conv,
250 expand_bn,
251 depthwise_w,
252 depthwise_bn,
253 se,
254 project_conv,
255 project_bn,
256 use_residual,
257 expanded_ch,
258 })
259 }
260
261 pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
263 let mut x = input.clone();
264
265 if let (Some(conv), Some(bn)) = (&self.expand_conv, &self.expand_bn) {
267 x = conv.forward_inference(&x)?;
268 x = bn.forward_inference(&x)?;
269 let data: Vec<f32> = x.data().iter().map(|&v| v.clamp(0.0, 6.0)).collect();
270 x = Tensor::from_vec(x.shape().to_vec(), data)?;
271 }
272
273 x = yscv_kernels::depthwise_conv2d_nhwc(&x, &self.depthwise_w, None, 1, 1)?;
275 x = self.depthwise_bn.forward_inference(&x)?;
276 let data: Vec<f32> = x.data().iter().map(|&v| v.clamp(0.0, 6.0)).collect();
277 x = Tensor::from_vec(x.shape().to_vec(), data)?;
278
279 if let Some(se) = &self.se {
281 x = se.forward(&x)?;
282 }
283
284 x = self.project_conv.forward_inference(&x)?;
286 x = self.project_bn.forward_inference(&x)?;
287
288 if self.use_residual {
290 x = x.add(input)?;
291 }
292
293 Ok(x)
294 }
295}
296
297pub fn build_resnet_feature_extractor(
301 model: &mut SequentialModel,
302 input_channels: usize,
303 stage_channels: &[usize],
304 blocks_per_stage: usize,
305 epsilon: f32,
306) -> Result<(), ModelError> {
307 let initial_ch = stage_channels.first().copied().unwrap_or(64);
308
309 model.add_conv2d_zero(input_channels, initial_ch, 7, 7, 2, 2, true)?;
311 model.add_batch_norm2d_identity(initial_ch, epsilon)?;
312 model.add_relu();
313 model.add_max_pool2d(3, 3, 2, 2)?;
314
315 let mut ch = initial_ch;
316 for &stage_ch in stage_channels {
317 if stage_ch != ch {
318 model.add_conv2d_zero(ch, stage_ch, 1, 1, 1, 1, false)?;
320 model.add_batch_norm2d_identity(stage_ch, epsilon)?;
321 model.add_relu();
322 }
323 for _ in 0..blocks_per_stage {
324 add_residual_block(model, stage_ch, epsilon)?;
325 }
326 ch = stage_ch;
327 }
328
329 model.add_global_avg_pool2d();
330 model.add_flatten();
331
332 Ok(())
333}
334
/// One U-Net encoder stage: two (3x3 conv -> BN -> ReLU) units applied in
/// `forward`. Downsampling between stages is the caller's responsibility.
pub struct UNetEncoderStage {
    // 3x3 conv mapping in_ch -> out_ch.
    conv1: crate::Conv2dLayer,
    // Batch norm after conv1.
    bn1: crate::BatchNorm2dLayer,
    // 3x3 conv mapping out_ch -> out_ch.
    conv2: crate::Conv2dLayer,
    // Batch norm after conv2.
    bn2: crate::BatchNorm2dLayer,
}
345
346impl UNetEncoderStage {
347 pub fn new(in_ch: usize, out_ch: usize, epsilon: f32) -> Result<Self, ModelError> {
348 let w1 = Tensor::from_vec(vec![3, 3, in_ch, out_ch], vec![0.0; 9 * in_ch * out_ch])?;
349 let b1 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
350 let w2 = Tensor::from_vec(vec![3, 3, out_ch, out_ch], vec![0.0; 9 * out_ch * out_ch])?;
351 let b2 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
352 Ok(Self {
353 conv1: crate::Conv2dLayer::new(in_ch, out_ch, 3, 3, 1, 1, w1, Some(b1))?,
354 bn1: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
355 conv2: crate::Conv2dLayer::new(out_ch, out_ch, 3, 3, 1, 1, w2, Some(b2))?,
356 bn2: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
357 })
358 }
359
360 pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
361 let x = self.conv1.forward_inference(input)?;
362 let x = self.bn1.forward_inference(&x)?;
363 let x = relu_nhwc(&x)?;
364 let x = self.conv2.forward_inference(&x)?;
365 let x = self.bn2.forward_inference(&x)?;
366 relu_nhwc(&x)
367 }
368}
369
/// One U-Net decoder stage: consumes an upsampled feature map concatenated
/// with its skip connection and applies two (3x3 conv -> BN -> ReLU) units
/// (see `forward`).
pub struct UNetDecoderStage {
    // 3x3 conv mapping (in_ch + skip_ch) -> out_ch.
    conv1: crate::Conv2dLayer,
    // Batch norm after conv1.
    bn1: crate::BatchNorm2dLayer,
    // 3x3 conv mapping out_ch -> out_ch.
    conv2: crate::Conv2dLayer,
    // Batch norm after conv2.
    bn2: crate::BatchNorm2dLayer,
}
379
380impl UNetDecoderStage {
381 pub fn new(
382 in_ch: usize,
383 skip_ch: usize,
384 out_ch: usize,
385 epsilon: f32,
386 ) -> Result<Self, ModelError> {
387 let cat_ch = in_ch + skip_ch;
388 let w1 = Tensor::from_vec(vec![3, 3, cat_ch, out_ch], vec![0.0; 9 * cat_ch * out_ch])?;
389 let b1 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
390 let w2 = Tensor::from_vec(vec![3, 3, out_ch, out_ch], vec![0.0; 9 * out_ch * out_ch])?;
391 let b2 = Tensor::from_vec(vec![out_ch], vec![0.0; out_ch])?;
392 Ok(Self {
393 conv1: crate::Conv2dLayer::new(cat_ch, out_ch, 3, 3, 1, 1, w1, Some(b1))?,
394 bn1: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
395 conv2: crate::Conv2dLayer::new(out_ch, out_ch, 3, 3, 1, 1, w2, Some(b2))?,
396 bn2: crate::BatchNorm2dLayer::identity_init(out_ch, epsilon)?,
397 })
398 }
399
400 pub fn forward(&self, upsampled: &Tensor, skip: &Tensor) -> Result<Tensor, ModelError> {
402 let up = upsample_nearest_2x_nhwc(upsampled)?;
403 let cat = cat_nhwc_channel(&up, skip)?;
404 let x = self.conv1.forward_inference(&cat)?;
405 let x = self.bn1.forward_inference(&x)?;
406 let x = relu_nhwc(&x)?;
407 let x = self.conv2.forward_inference(&x)?;
408 let x = self.bn2.forward_inference(&x)?;
409 relu_nhwc(&x)
410 }
411}
412
/// Feature Pyramid Network neck: projects multi-scale backbone features to a
/// shared channel width and fuses them top-down (see `forward`).
pub struct FpnNeck {
    // One 1x1 conv per input level, projecting to the shared width.
    lateral_convs: Vec<crate::Conv2dLayer>,
    // One 3x3 conv per level, applied after top-down merging.
    smooth_convs: Vec<crate::Conv2dLayer>,
    // Number of pyramid levels (== lateral_convs.len()).
    num_levels: usize,
}
422
423impl FpnNeck {
424 pub fn new(in_channels: &[usize], out_channels: usize) -> Result<Self, ModelError> {
425 let mut lateral_convs = Vec::with_capacity(in_channels.len());
426 let mut smooth_convs = Vec::with_capacity(in_channels.len());
427 for &ch in in_channels {
428 let w = Tensor::from_vec(vec![1, 1, ch, out_channels], vec![0.0; ch * out_channels])?;
429 let b = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
430 lateral_convs.push(crate::Conv2dLayer::new(
431 ch,
432 out_channels,
433 1,
434 1,
435 1,
436 1,
437 w,
438 Some(b),
439 )?);
440
441 let w3 = Tensor::from_vec(
442 vec![3, 3, out_channels, out_channels],
443 vec![0.0; 9 * out_channels * out_channels],
444 )?;
445 let b3 = Tensor::from_vec(vec![out_channels], vec![0.0; out_channels])?;
446 smooth_convs.push(crate::Conv2dLayer::new(
447 out_channels,
448 out_channels,
449 3,
450 3,
451 1,
452 1,
453 w3,
454 Some(b3),
455 )?);
456 }
457 Ok(Self {
458 lateral_convs,
459 smooth_convs,
460 num_levels: in_channels.len(),
461 })
462 }
463
464 pub fn forward(&self, features: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
467 if features.len() != self.num_levels {
468 return Err(ModelError::InvalidInputShape {
469 expected_features: self.num_levels,
470 got: vec![features.len()],
471 });
472 }
473 let mut laterals: Vec<Tensor> = Vec::with_capacity(self.num_levels);
474 for (i, feat) in features.iter().enumerate() {
475 laterals.push(self.lateral_convs[i].forward_inference(feat)?);
476 }
477
478 for i in (0..self.num_levels - 1).rev() {
479 let up = upsample_nearest_2x_nhwc(&laterals[i + 1])?;
480 let shape_i = laterals[i].shape();
481 let shape_up = up.shape();
482 let min_h = shape_i[1].min(shape_up[1]);
483 let min_w = shape_i[2].min(shape_up[2]);
484 let cropped_lat = crop_nhwc(&laterals[i], min_h, min_w)?;
485 let cropped_up = crop_nhwc(&up, min_h, min_w)?;
486 laterals[i] = cropped_lat.add(&cropped_up)?;
487 }
488
489 let mut outputs = Vec::with_capacity(self.num_levels);
490 for (i, lat) in laterals.iter().enumerate() {
491 outputs.push(self.smooth_convs[i].forward_inference(lat)?);
492 }
493 Ok(outputs)
494 }
495}
496
/// Anchor-free (FCOS-style) detection head: separate classification and
/// regression conv towers plus three 3x3 output convs.
pub struct AnchorFreeHead {
    // Classification tower: (conv, BN) pairs; ReLU applied in `forward`.
    cls_convs: Vec<(crate::Conv2dLayer, crate::BatchNorm2dLayer)>,
    // Regression tower: (conv, BN) pairs; ReLU applied in `forward`.
    reg_convs: Vec<(crate::Conv2dLayer, crate::BatchNorm2dLayer)>,
    // Per-class logit output conv.
    cls_out: crate::Conv2dLayer,
    // 4-channel box regression output conv.
    reg_out: crate::Conv2dLayer,
    // 1-channel centerness output conv (fed from the cls tower in `forward`).
    centerness_out: crate::Conv2dLayer,
}
508
509impl AnchorFreeHead {
510 pub fn new(
511 in_channels: usize,
512 num_classes: usize,
513 num_convs: usize,
514 epsilon: f32,
515 ) -> Result<Self, ModelError> {
516 let mut cls_convs = Vec::with_capacity(num_convs);
517 let mut reg_convs = Vec::with_capacity(num_convs);
518 let mut ch = in_channels;
519 for _ in 0..num_convs {
520 let wc =
521 Tensor::from_vec(vec![3, 3, ch, in_channels], vec![0.0; 9 * ch * in_channels])?;
522 let bc = Tensor::from_vec(vec![in_channels], vec![0.0; in_channels])?;
523 let bnc = crate::BatchNorm2dLayer::identity_init(in_channels, epsilon)?;
524 cls_convs.push((
525 crate::Conv2dLayer::new(ch, in_channels, 3, 3, 1, 1, wc, Some(bc))?,
526 bnc,
527 ));
528
529 let wr =
530 Tensor::from_vec(vec![3, 3, ch, in_channels], vec![0.0; 9 * ch * in_channels])?;
531 let br = Tensor::from_vec(vec![in_channels], vec![0.0; in_channels])?;
532 let bnr = crate::BatchNorm2dLayer::identity_init(in_channels, epsilon)?;
533 reg_convs.push((
534 crate::Conv2dLayer::new(ch, in_channels, 3, 3, 1, 1, wr, Some(br))?,
535 bnr,
536 ));
537 ch = in_channels;
538 }
539
540 let wco = Tensor::from_vec(
541 vec![3, 3, in_channels, num_classes],
542 vec![0.0; 9 * in_channels * num_classes],
543 )?;
544 let bco = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;
545 let cls_out =
546 crate::Conv2dLayer::new(in_channels, num_classes, 3, 3, 1, 1, wco, Some(bco))?;
547
548 let wro = Tensor::from_vec(vec![3, 3, in_channels, 4], vec![0.0; 9 * in_channels * 4])?;
549 let bro = Tensor::from_vec(vec![4], vec![0.0; 4])?;
550 let reg_out = crate::Conv2dLayer::new(in_channels, 4, 3, 3, 1, 1, wro, Some(bro))?;
551
552 let wcn = Tensor::from_vec(vec![3, 3, in_channels, 1], vec![0.0; 9 * in_channels])?;
553 let bcn = Tensor::from_vec(vec![1], vec![0.0; 1])?;
554 let centerness_out = crate::Conv2dLayer::new(in_channels, 1, 3, 3, 1, 1, wcn, Some(bcn))?;
555
556 Ok(Self {
557 cls_convs,
558 reg_convs,
559 cls_out,
560 reg_out,
561 centerness_out,
562 })
563 }
564
565 pub fn forward(&self, input: &Tensor) -> Result<(Tensor, Tensor, Tensor), ModelError> {
568 let mut cls_feat = input.clone();
569 for (conv, bn) in &self.cls_convs {
570 cls_feat = conv.forward_inference(&cls_feat)?;
571 cls_feat = bn.forward_inference(&cls_feat)?;
572 cls_feat = relu_nhwc(&cls_feat)?;
573 }
574
575 let mut reg_feat = input.clone();
576 for (conv, bn) in &self.reg_convs {
577 reg_feat = conv.forward_inference(®_feat)?;
578 reg_feat = bn.forward_inference(®_feat)?;
579 reg_feat = relu_nhwc(®_feat)?;
580 }
581
582 let cls_logits = self.cls_out.forward_inference(&cls_feat)?;
583 let bbox_pred = self.reg_out.forward_inference(®_feat)?;
584 let centerness = self.centerness_out.forward_inference(&cls_feat)?;
585
586 Ok((cls_logits, bbox_pred, centerness))
587 }
588}
589
590fn relu_nhwc(t: &Tensor) -> Result<Tensor, ModelError> {
591 let data: Vec<f32> = t.data().iter().map(|&v| v.max(0.0)).collect();
592 Tensor::from_vec(t.shape().to_vec(), data).map_err(Into::into)
593}
594
595fn upsample_nearest_2x_nhwc(t: &Tensor) -> Result<Tensor, ModelError> {
596 let shape = t.shape();
597 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
598 let new_h = h * 2;
599 let new_w = w * 2;
600 let data = t.data();
601 let mut out = vec![0.0f32; n * new_h * new_w * c];
602 for b in 0..n {
603 for y in 0..new_h {
604 for x in 0..new_w {
605 let sy = y / 2;
606 let sx = x / 2;
607 let src_off = ((b * h + sy) * w + sx) * c;
608 let dst_off = ((b * new_h + y) * new_w + x) * c;
609 out[dst_off..dst_off + c].copy_from_slice(&data[src_off..src_off + c]);
610 }
611 }
612 }
613 Tensor::from_vec(vec![n, new_h, new_w, c], out).map_err(Into::into)
614}
615
616fn cat_nhwc_channel(a: &Tensor, b: &Tensor) -> Result<Tensor, ModelError> {
617 let sa = a.shape();
618 let sb = b.shape();
619 let (n, h, w) = (sa[0], sa[1], sa[2]);
620 let ca = sa[3];
621 let cb = sb[3];
622 let da = a.data();
623 let db = b.data();
624 let mut out = vec![0.0f32; n * h * w * (ca + cb)];
625 for b_idx in 0..n {
626 for y in 0..h {
627 for x in 0..w {
628 let src_a = ((b_idx * h + y) * w + x) * ca;
629 let src_b = ((b_idx * h + y) * w + x) * cb;
630 let dst = ((b_idx * h + y) * w + x) * (ca + cb);
631 out[dst..dst + ca].copy_from_slice(&da[src_a..src_a + ca]);
632 out[dst + ca..dst + ca + cb].copy_from_slice(&db[src_b..src_b + cb]);
633 }
634 }
635 }
636 Tensor::from_vec(vec![n, h, w, ca + cb], out).map_err(Into::into)
637}
638
639fn crop_nhwc(t: &Tensor, target_h: usize, target_w: usize) -> Result<Tensor, ModelError> {
640 let shape = t.shape();
641 let (n, _h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
642 let data = t.data();
643 let mut out = vec![0.0f32; n * target_h * target_w * c];
644 for b in 0..n {
645 for y in 0..target_h {
646 let src_off = ((b * _h + y) * w) * c;
647 let dst_off = ((b * target_h + y) * target_w) * c;
648 for x in 0..target_w {
649 let so = src_off + x * c;
650 let do_ = dst_off + x * c;
651 out[do_..do_ + c].copy_from_slice(&data[so..so + c]);
652 }
653 }
654 }
655 Tensor::from_vec(vec![n, target_h, target_w, c], out).map_err(Into::into)
656}
657
/// Splits an NHWC image into non-overlapping square patches and linearly
/// projects each flattened patch to an embedding vector.
pub struct PatchEmbedding {
    /// `[patch_size^2 * in_channels, embed_dim]` projection weights.
    pub projection_w: Tensor,
    /// `[embed_dim]` projection bias.
    pub projection_b: Tensor,
    /// Expected square input resolution; must be divisible by `patch_size`.
    pub image_size: usize,
    /// Side length of each square patch.
    pub patch_size: usize,
    /// Expected input channel count.
    pub in_channels: usize,
    /// Output embedding width.
    pub embed_dim: usize,
    /// `(image_size / patch_size)^2` — patches per image.
    pub num_patches: usize,
}
681
682impl PatchEmbedding {
683 pub fn new(
685 image_size: usize,
686 patch_size: usize,
687 in_channels: usize,
688 embed_dim: usize,
689 ) -> Result<Self, ModelError> {
690 if !image_size.is_multiple_of(patch_size) {
691 return Err(ModelError::InvalidParameterShape {
692 parameter: "image_size must be divisible by patch_size",
693 expected: vec![image_size, patch_size],
694 got: vec![image_size % patch_size],
695 });
696 }
697 let num_patches = (image_size / patch_size) * (image_size / patch_size);
698 let patch_dim = patch_size * patch_size * in_channels;
699 Ok(Self {
700 projection_w: Tensor::from_vec(
701 vec![patch_dim, embed_dim],
702 vec![0.0; patch_dim * embed_dim],
703 )?,
704 projection_b: Tensor::from_vec(vec![embed_dim], vec![0.0; embed_dim])?,
705 image_size,
706 patch_size,
707 in_channels,
708 embed_dim,
709 num_patches,
710 })
711 }
712
713 pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
715 let shape = input.shape();
716 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
717 let ps = self.patch_size;
718 let grid_h = h / ps;
719 let grid_w = w / ps;
720 let num_patches = grid_h * grid_w;
721 let patch_dim = ps * ps * c;
722 let data = input.data();
723
724 let mut patches = vec![0.0f32; batch * num_patches * patch_dim];
726 for b in 0..batch {
727 for gh in 0..grid_h {
728 for gw in 0..grid_w {
729 let patch_idx = gh * grid_w + gw;
730 let dst_base = (b * num_patches + patch_idx) * patch_dim;
731 let mut offset = 0;
732 for ph in 0..ps {
733 for pw in 0..ps {
734 let iy = gh * ps + ph;
735 let ix = gw * ps + pw;
736 let src = ((b * h + iy) * w + ix) * c;
737 patches[dst_base + offset..dst_base + offset + c]
738 .copy_from_slice(&data[src..src + c]);
739 offset += c;
740 }
741 }
742 }
743 }
744 }
745
746 let patches_t = Tensor::from_vec(vec![batch * num_patches, patch_dim], patches)?;
748 let projected = matmul_2d(&patches_t, &self.projection_w)?;
749 let projected = projected.add(&self.projection_b.unsqueeze(0)?)?;
750
751 projected
753 .reshape(vec![batch, num_patches, self.embed_dim])
754 .map_err(Into::into)
755 }
756}
757
/// Vision Transformer classifier; all weights are zero-initialized by `new`
/// (layer norm starts as identity).
pub struct VisionTransformer {
    /// Patch extraction + linear projection stage.
    pub patch_embed: PatchEmbedding,
    /// `[1, embed_dim]` class token, prepended to the patch tokens.
    pub cls_token: Tensor,
    /// `[1, num_patches + 1, embed_dim]` positional embedding.
    pub pos_embed: Tensor,
    /// Stacked transformer encoder blocks.
    pub encoder_blocks: Vec<TransformerEncoderBlock>,
    /// Final layer-norm scale, length `embed_dim`.
    pub ln_gamma: Tensor,
    /// Final layer-norm shift, length `embed_dim`.
    pub ln_beta: Tensor,
    /// `[embed_dim, num_classes]` classification head weights.
    pub head_w: Tensor,
    /// `[num_classes]` classification head bias.
    pub head_b: Tensor,
    /// Token embedding width.
    pub embed_dim: usize,
    /// Number of output classes.
    pub num_classes: usize,
}
789
impl VisionTransformer {
    /// Builds a ViT classifier with zero-initialized weights.
    ///
    /// `mlp_ratio` sets the encoder feed-forward width as
    /// `embed_dim * mlp_ratio`, truncated to an integer. The positional
    /// embedding covers `num_patches + 1` tokens (one slot for the class
    /// token).
    ///
    /// # Errors
    /// Propagates `PatchEmbedding::new`'s error when `image_size` is not a
    /// multiple of `patch_size`, plus any tensor-construction failures.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        embed_dim: usize,
        num_heads: usize,
        num_layers: usize,
        num_classes: usize,
        mlp_ratio: f32,
    ) -> Result<Self, ModelError> {
        let patch_embed = PatchEmbedding::new(image_size, patch_size, in_channels, embed_dim)?;
        let num_patches = patch_embed.num_patches;
        // Sequence layout is [CLS] followed by the patch tokens.
        let seq_len = num_patches + 1; let cls_token = Tensor::from_vec(vec![1, embed_dim], vec![0.0; embed_dim])?;
        let pos_embed =
            Tensor::from_vec(vec![1, seq_len, embed_dim], vec![0.0; seq_len * embed_dim])?;

        let d_ff = (embed_dim as f32 * mlp_ratio) as usize;
        let mut encoder_blocks = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            encoder_blocks.push(TransformerEncoderBlock::new(embed_dim, num_heads, d_ff)?);
        }

        // Final layer norm starts as identity (gamma = 1, beta = 0).
        let ln_gamma = Tensor::from_vec(vec![embed_dim], vec![1.0; embed_dim])?;
        let ln_beta = Tensor::from_vec(vec![embed_dim], vec![0.0; embed_dim])?;

        let head_w = Tensor::from_vec(
            vec![embed_dim, num_classes],
            vec![0.0; embed_dim * num_classes],
        )?;
        let head_b = Tensor::from_vec(vec![num_classes], vec![0.0; num_classes])?;

        Ok(Self {
            patch_embed,
            cls_token,
            pos_embed,
            encoder_blocks,
            ln_gamma,
            ln_beta,
            head_w,
            head_b,
            embed_dim,
            num_classes,
        })
    }

    /// Full classification pass: patch-embed the NHWC `input`, prepend the
    /// class token, add positional embeddings, run the encoder stack, apply
    /// the final layer norm, and project the [CLS] token to logits of shape
    /// `[batch, num_classes]`.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let batch = input.shape()[0];

        // [batch, num_patches, embed_dim] patch tokens.
        let patch_tokens = self.patch_embed.forward(input)?;
        let num_patches = patch_tokens.shape()[1];

        // Tile the class token across the batch.
        // NOTE(review): assumes `repeat(&[batch, 1])` tiles along dim 0 —
        // confirm against Tensor::repeat's semantics.
        let cls_expanded = self.cls_token.repeat(&[batch, 1])?; let cls_expanded = cls_expanded.reshape(vec![batch, 1, self.embed_dim])?;

        let seq_len = num_patches + 1;
        // Manually concatenate [CLS] + patch tokens per batch row into one
        // contiguous buffer.
        let patch_data = patch_tokens.data();
        let cls_data = cls_expanded.data();
        let mut combined = vec![0.0f32; batch * seq_len * self.embed_dim];
        for b in 0..batch {
            let cls_src = b * self.embed_dim;
            let dst_base = b * seq_len * self.embed_dim;
            // Token 0 is the class token...
            combined[dst_base..dst_base + self.embed_dim]
                .copy_from_slice(&cls_data[cls_src..cls_src + self.embed_dim]);
            let patch_src = b * num_patches * self.embed_dim;
            let patch_dst = dst_base + self.embed_dim;
            let patch_len = num_patches * self.embed_dim;
            // ...followed by this element's patch tokens.
            combined[patch_dst..patch_dst + patch_len]
                .copy_from_slice(&patch_data[patch_src..patch_src + patch_len]);
        }
        let mut x = Tensor::from_vec(vec![batch, seq_len, self.embed_dim], combined)?;

        // Broadcast positional embeddings over the batch and add.
        let pos = self.pos_embed.repeat(&[batch, 1, 1])?;
        x = x.add(&pos)?;

        // The encoder blocks operate on a single [seq_len, embed_dim]
        // sequence, so each batch element runs through the stack separately.
        let mut out_data = vec![0.0f32; batch * seq_len * self.embed_dim];
        for b in 0..batch {
            let start = b * seq_len * self.embed_dim;
            let end = start + seq_len * self.embed_dim;
            let slice = &x.data()[start..end];
            let mut seq = Tensor::from_vec(vec![seq_len, self.embed_dim], slice.to_vec())?;

            for block in &self.encoder_blocks {
                seq = block.forward(&seq)?;
            }

            let seq_data = seq.data();
            out_data[start..end].copy_from_slice(seq_data);
        }
        let x = Tensor::from_vec(vec![batch, seq_len, self.embed_dim], out_data)?;

        // Final layer norm over the embedding dim; the kernel wants 2-D input.
        let x_2d = x.reshape(vec![batch * seq_len, self.embed_dim])?;
        let params = LayerNormLastDimParams {
            gamma: &self.ln_gamma,
            beta: &self.ln_beta,
            epsilon: 1e-5,
        };
        let normed = layer_norm_last_dim(&x_2d, params)?;
        let normed = normed.reshape(vec![batch, seq_len, self.embed_dim])?;

        // Classification uses only the [CLS] token (position 0 of each row).
        let normed_data = normed.data();
        let mut cls_out = vec![0.0f32; batch * self.embed_dim];
        for b in 0..batch {
            let src = b * seq_len * self.embed_dim;
            cls_out[b * self.embed_dim..(b + 1) * self.embed_dim]
                .copy_from_slice(&normed_data[src..src + self.embed_dim]);
        }
        let cls_features = Tensor::from_vec(vec![batch, self.embed_dim], cls_out)?;

        // Linear head: logits = cls_features @ head_w + head_b.
        let logits = matmul_2d(&cls_features, &self.head_w)?;
        let logits = logits.add(&self.head_b.unsqueeze(0)?)?;

        Ok(logits)
    }
}