1use crate::error::{NeuralError, Result};
9use crate::layers::conv::PaddingMode;
10use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
13use scirs2_core::random::SeedableRng;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
/// Kind of residual block a [`ResNetConfig`] is built around.
///
/// In `ResNet::new` the variant decides the width of the classifier input:
/// the last stage's `channels` for `Basic`, and `channels * 4` for `Bottleneck`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResNetBlock {
    /// Basic block variant (used by the ResNet-18 and ResNet-34 presets).
    Basic,
    /// Bottleneck block variant (used by the ResNet-50, -101 and -152 presets).
    Bottleneck,
}
26
/// Description of one stage of a ResNet configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResNetLayer {
    /// Number of blocks in this stage.
    pub blocks: usize,
    /// Channel width of this stage (the presets use 64, 128, 256 and 512).
    pub channels: usize,
    /// Stride of this stage (the presets use 1 for the first stage and 2 afterwards).
    pub stride: usize,
}
37
/// Full configuration of a ResNet model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResNetConfig {
    /// Residual block variant used by the model.
    pub block: ResNetBlock,
    /// Per-stage settings, in network order.
    pub layers: Vec<ResNetLayer>,
    /// Number of channels of the input fed to the first convolution.
    pub input_channels: usize,
    /// Number of outputs of the final fully connected layer.
    pub num_classes: usize,
    /// Dropout rate; `ResNet::new` only creates a dropout layer when this is
    /// greater than `0.0`.
    pub dropout_rate: f64,
}
52
53impl ResNetConfig {
54 pub fn resnet18(input_channels: usize, num_classes: usize) -> Self {
56 Self {
57 block: ResNetBlock::Basic,
58 layers: vec![
59 ResNetLayer {
60 blocks: 2,
61 channels: 64,
62 stride: 1,
63 },
64 ResNetLayer {
65 blocks: 2,
66 channels: 128,
67 stride: 2,
68 },
69 ResNetLayer {
70 blocks: 2,
71 channels: 256,
72 stride: 2,
73 },
74 ResNetLayer {
75 blocks: 2,
76 channels: 512,
77 stride: 2,
78 },
79 ],
80 input_channels,
81 num_classes,
82 dropout_rate: 0.0,
83 }
84 }
85
86 pub fn resnet34(input_channels: usize, num_classes: usize) -> Self {
88 Self {
89 block: ResNetBlock::Basic,
90 layers: vec![
91 ResNetLayer {
92 blocks: 3,
93 channels: 64,
94 stride: 1,
95 },
96 ResNetLayer {
97 blocks: 4,
98 channels: 128,
99 stride: 2,
100 },
101 ResNetLayer {
102 blocks: 6,
103 channels: 256,
104 stride: 2,
105 },
106 ResNetLayer {
107 blocks: 3,
108 channels: 512,
109 stride: 2,
110 },
111 ],
112 input_channels,
113 num_classes,
114 dropout_rate: 0.0,
115 }
116 }
117
118 pub fn resnet50(input_channels: usize, num_classes: usize) -> Self {
120 Self {
121 block: ResNetBlock::Bottleneck,
122 layers: vec![
123 ResNetLayer {
124 blocks: 3,
125 channels: 64,
126 stride: 1,
127 },
128 ResNetLayer {
129 blocks: 4,
130 channels: 128,
131 stride: 2,
132 },
133 ResNetLayer {
134 blocks: 6,
135 channels: 256,
136 stride: 2,
137 },
138 ResNetLayer {
139 blocks: 3,
140 channels: 512,
141 stride: 2,
142 },
143 ],
144 input_channels,
145 num_classes,
146 dropout_rate: 0.0,
147 }
148 }
149
150 pub fn resnet101(input_channels: usize, num_classes: usize) -> Self {
152 Self {
153 block: ResNetBlock::Bottleneck,
154 layers: vec![
155 ResNetLayer {
156 blocks: 3,
157 channels: 64,
158 stride: 1,
159 },
160 ResNetLayer {
161 blocks: 4,
162 channels: 128,
163 stride: 2,
164 },
165 ResNetLayer {
166 blocks: 23,
167 channels: 256,
168 stride: 2,
169 },
170 ResNetLayer {
171 blocks: 3,
172 channels: 512,
173 stride: 2,
174 },
175 ],
176 input_channels,
177 num_classes,
178 dropout_rate: 0.0,
179 }
180 }
181
182 pub fn resnet152(input_channels: usize, num_classes: usize) -> Self {
184 Self {
185 block: ResNetBlock::Bottleneck,
186 layers: vec![
187 ResNetLayer {
188 blocks: 3,
189 channels: 64,
190 stride: 1,
191 },
192 ResNetLayer {
193 blocks: 8,
194 channels: 128,
195 stride: 2,
196 },
197 ResNetLayer {
198 blocks: 36,
199 channels: 256,
200 stride: 2,
201 },
202 ResNetLayer {
203 blocks: 3,
204 channels: 512,
205 stride: 2,
206 },
207 ],
208 input_channels,
209 num_classes,
210 dropout_rate: 0.0,
211 }
212 }
213
214 pub fn with_dropout(mut self, rate: f64) -> Self {
216 self.dropout_rate = rate;
217 self
218 }
219}
220
/// Residual block made of two convolution + batch-norm pairs and an optional
/// projection on the shortcut branch.
struct BasicBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
    /// First convolution (built by `new` as 3x3 with the block stride).
    conv1: Conv2D<F>,
    /// Batch normalization applied after `conv1`.
    bn1: BatchNorm<F>,
    /// Second convolution (built by `new` as 3x3 with stride 1).
    conv2: Conv2D<F>,
    /// Batch normalization applied after `conv2`.
    bn2: BatchNorm<F>,
    /// Optional (convolution, batch-norm) pair applied to the block input to
    /// form the shortcut; when `None` the input itself is used as the shortcut.
    downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
}
234
235impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone for BasicBlock<F> {
236 fn clone(&self) -> Self {
237 Self {
238 conv1: self.conv1.clone(),
239 bn1: self.bn1.clone(),
240 conv2: self.conv2.clone(),
241 bn2: self.bn2.clone(),
242 downsample: self.downsample.clone(),
243 }
244 }
245}
246
247impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BasicBlock<F> {
248 pub fn new(
250 in_channels: usize,
251 out_channels: usize,
252 stride: usize,
253 downsample: bool,
254 ) -> Result<Self> {
255 let stride_tuple = (stride, stride);
256
257 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
258 let conv1 = Conv2D::new(in_channels, out_channels, (3, 3), stride_tuple, None)?
259 .with_padding(PaddingMode::Same);
260
261 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
262 let bn1 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng2)?;
263
264 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
265 let conv2 = Conv2D::new(out_channels, out_channels, (3, 3), (1, 1), None)?
266 .with_padding(PaddingMode::Same);
267
268 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
269 let bn2 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng4)?;
270
271 let downsample = if downsample {
272 let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
273 let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
274 .with_padding(PaddingMode::Valid);
275
276 let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
277 let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng6)?;
278 Some((ds_conv, ds_bn))
279 } else {
280 None
281 };
282
283 Ok(Self {
284 conv1,
285 bn1,
286 conv2,
287 bn2,
288 downsample,
289 })
290 }
291}
292
293impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
294 for BasicBlock<F>
295{
296 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
297 let mut x = self.conv1.forward(input)?;
299 x = self.bn1.forward(&x)?;
300 x = x.mapv(|v: F| v.max(F::zero())); x = self.conv2.forward(&x)?;
304 x = self.bn2.forward(&x)?;
305
306 let identity = if let Some((ref conv, ref bn)) = self.downsample {
308 let ds = conv.forward(input)?;
309 bn.forward(&ds)?
310 } else {
311 input.clone()
312 };
313
314 let x = &x + &identity;
316
317 let x = x.mapv(|v: F| v.max(F::zero()));
319
320 Ok(x)
321 }
322
323 fn backward(
324 &self,
325 _input: &Array<F, IxDyn>,
326 grad_output: &Array<F, IxDyn>,
327 ) -> Result<Array<F, IxDyn>> {
328 Ok(grad_output.clone())
329 }
330
331 fn update(&mut self, learning_rate: F) -> Result<()> {
332 self.conv1.update(learning_rate)?;
333 self.bn1.update(learning_rate)?;
334 self.conv2.update(learning_rate)?;
335 self.bn2.update(learning_rate)?;
336 if let Some((ref mut conv, ref mut bn)) = self.downsample {
337 conv.update(learning_rate)?;
338 bn.update(learning_rate)?;
339 }
340 Ok(())
341 }
342
343 fn as_any(&self) -> &dyn std::any::Any {
344 self
345 }
346
347 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
348 self
349 }
350}
351
352impl<
353 F: Float
354 + Debug
355 + ScalarOperand
356 + Send
357 + Sync
358 + NumAssign
359 + ToPrimitive
360 + FromPrimitive
361 + 'static,
362 > BasicBlock<F>
363{
364 pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
366 let mut result = Vec::new();
367 for (i, p) in self.conv1.params().iter().enumerate() {
369 let suffix = if i == 0 { "weight" } else { "bias" };
370 result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
371 }
372 for (i, p) in self.bn1.params().iter().enumerate() {
374 let suffix = if i == 0 { "weight" } else { "bias" };
375 result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
376 }
377 for (i, p) in self.conv2.params().iter().enumerate() {
379 let suffix = if i == 0 { "weight" } else { "bias" };
380 result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
381 }
382 for (i, p) in self.bn2.params().iter().enumerate() {
384 let suffix = if i == 0 { "weight" } else { "bias" };
385 result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
386 }
387 if let Some((ref conv, ref bn)) = self.downsample {
389 for (i, p) in conv.params().iter().enumerate() {
390 let suffix = if i == 0 { "weight" } else { "bias" };
391 result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
392 }
393 for (i, p) in bn.params().iter().enumerate() {
394 let suffix = if i == 0 { "weight" } else { "bias" };
395 result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
396 }
397 }
398 result
399 }
400
401 pub(crate) fn load_named_params(
403 &mut self,
404 prefix: &str,
405 params_map: &HashMap<String, Array<F, IxDyn>>,
406 ) -> Result<()> {
407 if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
409 let mut ps = vec![w.clone()];
410 if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
411 ps.push(b.clone());
412 }
413 self.conv1.set_params(&ps)?;
414 }
415 if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
417 let mut ps = vec![w.clone()];
418 if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
419 ps.push(b.clone());
420 }
421 self.bn1.set_params(&ps)?;
422 }
423 if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
425 let mut ps = vec![w.clone()];
426 if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
427 ps.push(b.clone());
428 }
429 self.conv2.set_params(&ps)?;
430 }
431 if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
433 let mut ps = vec![w.clone()];
434 if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
435 ps.push(b.clone());
436 }
437 self.bn2.set_params(&ps)?;
438 }
439 if let Some((ref mut conv, ref mut bn)) = self.downsample {
441 if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
442 let mut ps = vec![w.clone()];
443 if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
444 ps.push(b.clone());
445 }
446 conv.set_params(&ps)?;
447 }
448 if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
449 let mut ps = vec![w.clone()];
450 if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
451 ps.push(b.clone());
452 }
453 bn.set_params(&ps)?;
454 }
455 }
456 Ok(())
457 }
458}
459
/// Residual block made of three convolution + batch-norm pairs and an optional
/// projection on the shortcut branch.
struct BottleneckBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
    /// First convolution (built by `new` as 1x1, stride 1).
    conv1: Conv2D<F>,
    /// Batch normalization applied after `conv1`.
    bn1: BatchNorm<F>,
    /// Second convolution (built by `new` as 3x3 with the block stride).
    conv2: Conv2D<F>,
    /// Batch normalization applied after `conv2`.
    bn2: BatchNorm<F>,
    /// Third convolution (built by `new` as 1x1, stride 1, producing the
    /// block's output channels).
    conv3: Conv2D<F>,
    /// Batch normalization applied after `conv3`.
    bn3: BatchNorm<F>,
    /// Optional (convolution, batch-norm) pair applied to the block input to
    /// form the shortcut; when `None` the input itself is used as the shortcut.
    downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
    // Set from `Self::EXPANSION` in `new` and copied by `clone`; nothing else
    // in this file reads it, hence the lint allowance.
    #[allow(dead_code)]
    /// Ratio between the block's output channels and its internal width.
    expansion: usize,
}
480
481impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone
482 for BottleneckBlock<F>
483{
484 fn clone(&self) -> Self {
485 Self {
486 conv1: self.conv1.clone(),
487 bn1: self.bn1.clone(),
488 conv2: self.conv2.clone(),
489 bn2: self.bn2.clone(),
490 conv3: self.conv3.clone(),
491 bn3: self.bn3.clone(),
492 downsample: self.downsample.clone(),
493 expansion: self.expansion,
494 }
495 }
496}
497
498impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BottleneckBlock<F> {
499 const EXPANSION: usize = 4;
501
502 pub fn new(
504 in_channels: usize,
505 out_channels: usize,
506 stride: usize,
507 downsample: bool,
508 ) -> Result<Self> {
509 let bottleneck_channels = out_channels / Self::EXPANSION;
510 let stride_tuple = (stride, stride);
511
512 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
514 let conv1 = Conv2D::new(in_channels, bottleneck_channels, (1, 1), (1, 1), None)?
515 .with_padding(PaddingMode::Valid);
516
517 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
518 let bn1 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng2)?;
519
520 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
522 let conv2 = Conv2D::new(
523 bottleneck_channels,
524 bottleneck_channels,
525 (3, 3),
526 stride_tuple,
527 None,
528 )?
529 .with_padding(PaddingMode::Same);
530
531 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
532 let bn2 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng4)?;
533
534 let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
536 let conv3 = Conv2D::new(bottleneck_channels, out_channels, (1, 1), (1, 1), None)?
537 .with_padding(PaddingMode::Valid);
538
539 let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
540 let bn3 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng6)?;
541
542 let downsample = if downsample {
544 let mut rng7 = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
545 let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
546 .with_padding(PaddingMode::Valid);
547
548 let mut rng8 = scirs2_core::random::rngs::SmallRng::from_seed([55; 32]);
549 let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng8)?;
550 Some((ds_conv, ds_bn))
551 } else {
552 None
553 };
554
555 Ok(Self {
556 conv1,
557 bn1,
558 conv2,
559 bn2,
560 conv3,
561 bn3,
562 downsample,
563 expansion: Self::EXPANSION,
564 })
565 }
566}
567
568impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
569 for BottleneckBlock<F>
570{
571 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
572 let mut x = self.conv1.forward(input)?;
574 x = self.bn1.forward(&x)?;
575 x = x.mapv(|v: F| v.max(F::zero())); x = self.conv2.forward(&x)?;
579 x = self.bn2.forward(&x)?;
580 x = x.mapv(|v: F| v.max(F::zero())); x = self.conv3.forward(&x)?;
584 x = self.bn3.forward(&x)?;
585
586 let identity = if let Some((ref conv, ref bn)) = self.downsample {
588 let ds = conv.forward(input)?;
589 bn.forward(&ds)?
590 } else {
591 input.clone()
592 };
593
594 let x = &x + &identity;
596
597 let x = x.mapv(|v: F| v.max(F::zero()));
599
600 Ok(x)
601 }
602
603 fn backward(
604 &self,
605 _input: &Array<F, IxDyn>,
606 grad_output: &Array<F, IxDyn>,
607 ) -> Result<Array<F, IxDyn>> {
608 Ok(grad_output.clone())
609 }
610
611 fn update(&mut self, learning_rate: F) -> Result<()> {
612 self.conv1.update(learning_rate)?;
613 self.bn1.update(learning_rate)?;
614 self.conv2.update(learning_rate)?;
615 self.bn2.update(learning_rate)?;
616 self.conv3.update(learning_rate)?;
617 self.bn3.update(learning_rate)?;
618 if let Some((ref mut conv, ref mut bn)) = self.downsample {
619 conv.update(learning_rate)?;
620 bn.update(learning_rate)?;
621 }
622 Ok(())
623 }
624
625 fn as_any(&self) -> &dyn std::any::Any {
626 self
627 }
628
629 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
630 self
631 }
632}
633
634impl<
635 F: Float
636 + Debug
637 + ScalarOperand
638 + Send
639 + Sync
640 + NumAssign
641 + ToPrimitive
642 + FromPrimitive
643 + 'static,
644 > BottleneckBlock<F>
645{
646 pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
648 let mut result = Vec::new();
649 for (i, p) in self.conv1.params().iter().enumerate() {
650 let suffix = if i == 0 { "weight" } else { "bias" };
651 result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
652 }
653 for (i, p) in self.bn1.params().iter().enumerate() {
654 let suffix = if i == 0 { "weight" } else { "bias" };
655 result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
656 }
657 for (i, p) in self.conv2.params().iter().enumerate() {
658 let suffix = if i == 0 { "weight" } else { "bias" };
659 result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
660 }
661 for (i, p) in self.bn2.params().iter().enumerate() {
662 let suffix = if i == 0 { "weight" } else { "bias" };
663 result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
664 }
665 for (i, p) in self.conv3.params().iter().enumerate() {
666 let suffix = if i == 0 { "weight" } else { "bias" };
667 result.push((format!("{prefix}.conv3.{suffix}"), p.clone()));
668 }
669 for (i, p) in self.bn3.params().iter().enumerate() {
670 let suffix = if i == 0 { "weight" } else { "bias" };
671 result.push((format!("{prefix}.bn3.{suffix}"), p.clone()));
672 }
673 if let Some((ref conv, ref bn)) = self.downsample {
674 for (i, p) in conv.params().iter().enumerate() {
675 let suffix = if i == 0 { "weight" } else { "bias" };
676 result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
677 }
678 for (i, p) in bn.params().iter().enumerate() {
679 let suffix = if i == 0 { "weight" } else { "bias" };
680 result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
681 }
682 }
683 result
684 }
685
686 pub(crate) fn load_named_params(
688 &mut self,
689 prefix: &str,
690 params_map: &HashMap<String, Array<F, IxDyn>>,
691 ) -> Result<()> {
692 if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
693 let mut ps = vec![w.clone()];
694 if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
695 ps.push(b.clone());
696 }
697 self.conv1.set_params(&ps)?;
698 }
699 if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
700 let mut ps = vec![w.clone()];
701 if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
702 ps.push(b.clone());
703 }
704 self.bn1.set_params(&ps)?;
705 }
706 if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
707 let mut ps = vec![w.clone()];
708 if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
709 ps.push(b.clone());
710 }
711 self.conv2.set_params(&ps)?;
712 }
713 if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
714 let mut ps = vec![w.clone()];
715 if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
716 ps.push(b.clone());
717 }
718 self.bn2.set_params(&ps)?;
719 }
720 if let Some(w) = params_map.get(&format!("{prefix}.conv3.weight")) {
721 let mut ps = vec![w.clone()];
722 if let Some(b) = params_map.get(&format!("{prefix}.conv3.bias")) {
723 ps.push(b.clone());
724 }
725 self.conv3.set_params(&ps)?;
726 }
727 if let Some(w) = params_map.get(&format!("{prefix}.bn3.weight")) {
728 let mut ps = vec![w.clone()];
729 if let Some(b) = params_map.get(&format!("{prefix}.bn3.bias")) {
730 ps.push(b.clone());
731 }
732 self.bn3.set_params(&ps)?;
733 }
734 if let Some((ref mut conv, ref mut bn)) = self.downsample {
735 if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
736 let mut ps = vec![w.clone()];
737 if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
738 ps.push(b.clone());
739 }
740 conv.set_params(&ps)?;
741 }
742 if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
743 let mut ps = vec![w.clone()];
744 if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
745 ps.push(b.clone());
746 }
747 bn.set_params(&ps)?;
748 }
749 }
750 Ok(())
751 }
752}
753
/// ResNet model: a convolutional stem, a sequence of residual blocks, global
/// average pooling, optional dropout and a fully connected classifier.
pub struct ResNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
    /// Stem convolution (built by `new` as 7x7, stride 2, 64 output channels).
    conv1: Conv2D<F>,
    /// Batch normalization applied after `conv1`.
    bn1: BatchNorm<F>,
    /// Basic residual blocks, applied in order by `forward`.
    layer1: Vec<BasicBlock<F>>,
    /// Bottleneck residual blocks, applied in order by `forward` after `layer1`.
    layer1_bottleneck: Vec<BottleneckBlock<F>>,
    /// Final fully connected layer producing `num_classes` outputs.
    fc: Dense<F>,
    /// Dropout applied to the pooled features; only present when the configured
    /// rate is greater than `0.0`.
    dropout: Option<Dropout<F>>,
    /// Configuration the model was created from.
    config: ResNetConfig,
}
771
772impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ResNet<F> {
773 pub fn new(config: ResNetConfig) -> Result<Self> {
775 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([56; 32]);
777 let conv1 = Conv2D::new(config.input_channels, 64, (7, 7), (2, 2), None)?
778 .with_padding(PaddingMode::Same);
779
780 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([57; 32]);
781 let bn1 = BatchNorm::new(64, 1e-5, 0.1, &mut rng2)?;
782
783 let layer1 = Vec::new();
785 let layer1_bottleneck = Vec::new();
786
787 let fc_in_features = match config.block {
789 ResNetBlock::Basic => config.layers.last().map(|l| l.channels).unwrap_or(512),
790 ResNetBlock::Bottleneck => config.layers.last().map(|l| l.channels * 4).unwrap_or(2048),
791 };
792
793 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([58; 32]);
794 let fc = Dense::new(fc_in_features, config.num_classes, None, &mut rng3)?;
795
796 let dropout = if config.dropout_rate > 0.0 {
798 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([59; 32]);
799 Some(Dropout::new(config.dropout_rate, &mut rng4)?)
800 } else {
801 None
802 };
803
804 Ok(Self {
805 conv1,
806 bn1,
807 layer1,
808 layer1_bottleneck,
809 fc,
810 dropout,
811 config,
812 })
813 }
814
815 pub fn resnet18(input_channels: usize, num_classes: usize) -> Result<Self> {
817 let config = ResNetConfig::resnet18(input_channels, num_classes);
818 Self::new(config)
819 }
820
821 pub fn resnet34(input_channels: usize, num_classes: usize) -> Result<Self> {
823 let config = ResNetConfig::resnet34(input_channels, num_classes);
824 Self::new(config)
825 }
826
827 pub fn resnet50(input_channels: usize, num_classes: usize) -> Result<Self> {
829 let config = ResNetConfig::resnet50(input_channels, num_classes);
830 Self::new(config)
831 }
832
833 pub fn resnet101(input_channels: usize, num_classes: usize) -> Result<Self> {
835 let config = ResNetConfig::resnet101(input_channels, num_classes);
836 Self::new(config)
837 }
838
839 pub fn resnet152(input_channels: usize, num_classes: usize) -> Result<Self> {
841 let config = ResNetConfig::resnet152(input_channels, num_classes);
842 Self::new(config)
843 }
844
845 pub fn config(&self) -> &ResNetConfig {
847 &self.config
848 }
849
850 fn global_avg_pool(x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
852 let shape = x.shape();
853 if shape.len() != 4 {
854 return Err(NeuralError::InferenceError(format!(
855 "Expected 4D input for average pooling, got shape {:?}",
856 shape
857 )));
858 }
859
860 let batch_size = shape[0];
861 let channels = shape[1];
862 let height = shape[2];
863 let width = shape[3];
864
865 let mut output = Array::zeros(IxDyn(&[batch_size, channels]));
866 let count = F::from(height * width).expect("Failed to convert to float");
867
868 for b in 0..batch_size {
869 for c in 0..channels {
870 let mut sum = F::zero();
871 for h in 0..height {
872 for w in 0..width {
873 sum += x[[b, c, h, w]];
874 }
875 }
876 output[[b, c]] = sum / count;
877 }
878 }
879
880 Ok(output)
881 }
882}
883
884impl<
885 F: Float
886 + Debug
887 + ScalarOperand
888 + Send
889 + Sync
890 + NumAssign
891 + ToPrimitive
892 + FromPrimitive
893 + 'static,
894 > ResNet<F>
895{
896 pub fn extract_named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
904 let mut result = Vec::new();
905
906 for (i, p) in self.conv1.params().iter().enumerate() {
908 let suffix = if i == 0 { "weight" } else { "bias" };
909 result.push((format!("conv1.{suffix}"), p.clone()));
910 }
911 for (i, p) in self.bn1.params().iter().enumerate() {
912 let suffix = if i == 0 { "weight" } else { "bias" };
913 result.push((format!("bn1.{suffix}"), p.clone()));
914 }
915
916 for (idx, block) in self.layer1.iter().enumerate() {
918 let block_params = block.extract_named_params(&format!("layer1.{idx}"));
919 result.extend(block_params);
920 }
921 for (idx, block) in self.layer1_bottleneck.iter().enumerate() {
922 let block_params = block.extract_named_params(&format!("layer1.{idx}"));
923 result.extend(block_params);
924 }
925
926 for (i, p) in self.fc.params().iter().enumerate() {
928 let suffix = if i == 0 { "weight" } else { "bias" };
929 result.push((format!("fc.{suffix}"), p.clone()));
930 }
931
932 Ok(result)
933 }
934
935 pub fn load_named_params(
940 &mut self,
941 params_map: &HashMap<String, Array<F, IxDyn>>,
942 ) -> Result<()> {
943 if let Some(w) = params_map.get("conv1.weight") {
945 let mut ps = vec![w.clone()];
946 if let Some(b) = params_map.get("conv1.bias") {
947 ps.push(b.clone());
948 }
949 self.conv1.set_params(&ps)?;
950 }
951 if let Some(w) = params_map.get("bn1.weight") {
953 let mut ps = vec![w.clone()];
954 if let Some(b) = params_map.get("bn1.bias") {
955 ps.push(b.clone());
956 }
957 self.bn1.set_params(&ps)?;
958 }
959
960 for (idx, block) in self.layer1.iter_mut().enumerate() {
962 block.load_named_params(&format!("layer1.{idx}"), params_map)?;
963 }
964 for (idx, block) in self.layer1_bottleneck.iter_mut().enumerate() {
966 block.load_named_params(&format!("layer1.{idx}"), params_map)?;
967 }
968
969 if let Some(w) = params_map.get("fc.weight") {
971 let mut ps = vec![w.clone()];
972 if let Some(b) = params_map.get("fc.bias") {
973 ps.push(b.clone());
974 }
975 self.fc.set_params(&ps)?;
976 }
977
978 Ok(())
979 }
980}
981
982impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ResNet<F> {
983 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
984 let mut x = self.conv1.forward(input)?;
986 x = self.bn1.forward(&x)?;
987 x = x.mapv(|v: F| v.max(F::zero())); for block in &self.layer1 {
991 x = block.forward(&x)?;
992 }
993
994 for block in &self.layer1_bottleneck {
996 x = block.forward(&x)?;
997 }
998
999 x = Self::global_avg_pool(&x)?;
1001
1002 if let Some(ref dropout) = self.dropout {
1004 x = dropout.forward(&x)?;
1005 }
1006
1007 x = self.fc.forward(&x)?;
1009
1010 Ok(x)
1011 }
1012
1013 fn backward(
1014 &self,
1015 _input: &Array<F, IxDyn>,
1016 grad_output: &Array<F, IxDyn>,
1017 ) -> Result<Array<F, IxDyn>> {
1018 Ok(grad_output.clone())
1019 }
1020
1021 fn update(&mut self, learning_rate: F) -> Result<()> {
1022 self.conv1.update(learning_rate)?;
1023 self.bn1.update(learning_rate)?;
1024
1025 for block in &mut self.layer1 {
1026 block.update(learning_rate)?;
1027 }
1028
1029 for block in &mut self.layer1_bottleneck {
1030 block.update(learning_rate)?;
1031 }
1032
1033 self.fc.update(learning_rate)?;
1034
1035 Ok(())
1036 }
1037
1038 fn as_any(&self) -> &dyn std::any::Any {
1039 self
1040 }
1041
1042 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1043 self
1044 }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;

    /// The ResNet-18 preset keeps the requested I/O sizes, has four stages and
    /// uses basic blocks.
    #[test]
    fn test_resnet_config_18() {
        let ResNetConfig {
            block,
            layers,
            input_channels,
            num_classes,
            ..
        } = ResNetConfig::resnet18(3, 1000);

        assert_eq!(input_channels, 3);
        assert_eq!(num_classes, 1000);
        assert_eq!(layers.len(), 4);
        assert!(matches!(block, ResNetBlock::Basic));
    }

    /// The ResNet-50 preset uses bottleneck blocks and has four stages.
    #[test]
    fn test_resnet_config_50() {
        let ResNetConfig { block, layers, .. } = ResNetConfig::resnet50(3, 1000);

        assert!(matches!(block, ResNetBlock::Bottleneck));
        assert_eq!(layers.len(), 4);
    }

    /// `with_dropout` stores the given rate.
    #[test]
    fn test_resnet_config_with_dropout() {
        let rate = ResNetConfig::resnet18(3, 100)
            .with_dropout(0.5)
            .dropout_rate;
        assert_eq!(rate, 0.5);
    }

    /// Spot-checks the per-stage block counts of the deeper presets.
    #[test]
    fn test_resnet_config_variants() {
        let stages34 = ResNetConfig::resnet34(3, 1000).layers;
        assert_eq!(stages34[0].blocks, 3);
        assert_eq!(stages34[1].blocks, 4);

        let stages101 = ResNetConfig::resnet101(3, 1000).layers;
        assert_eq!(stages101[2].blocks, 23);

        let stages152 = ResNetConfig::resnet152(3, 1000).layers;
        assert_eq!(stages152[2].blocks, 36);
    }
}