Skip to main content

scirs2_neural/models/architectures/
resnet.rs

1//! ResNet implementation
2//!
3//! ResNet (Residual Network) is a popular CNN architecture that introduced
4//! skip connections to allow for training very deep networks.
5//! Reference: "Deep Residual Learning for Image Recognition", He et al. (2015)
6//! <https://arxiv.org/abs/1512.03385>
7
8use crate::error::{NeuralError, Result};
9use crate::layers::conv::PaddingMode;
10use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
13use scirs2_core::random::SeedableRng;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18/// ResNet block configuration
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum ResNetBlock {
21    /// Basic block (2 conv layers)
22    Basic,
23    /// Bottleneck block (3 conv layers with bottleneck)
24    Bottleneck,
25}
26
27/// Configuration for a ResNet layer
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ResNetLayer {
30    /// Number of blocks in this layer
31    pub blocks: usize,
32    /// Number of output channels
33    pub channels: usize,
34    /// Stride for the first block (usually 1 or 2)
35    pub stride: usize,
36}
37
38/// Configuration for a ResNet model
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ResNetConfig {
41    /// Block type (Basic or Bottleneck)
42    pub block: ResNetBlock,
43    /// Layer configuration
44    pub layers: Vec<ResNetLayer>,
45    /// Number of input channels (e.g., 3 for RGB)
46    pub input_channels: usize,
47    /// Number of output classes
48    pub num_classes: usize,
49    /// Dropout rate (0 to disable)
50    pub dropout_rate: f64,
51}
52
53impl ResNetConfig {
54    /// Create a ResNet-18 configuration
55    pub fn resnet18(input_channels: usize, num_classes: usize) -> Self {
56        Self {
57            block: ResNetBlock::Basic,
58            layers: vec![
59                ResNetLayer {
60                    blocks: 2,
61                    channels: 64,
62                    stride: 1,
63                },
64                ResNetLayer {
65                    blocks: 2,
66                    channels: 128,
67                    stride: 2,
68                },
69                ResNetLayer {
70                    blocks: 2,
71                    channels: 256,
72                    stride: 2,
73                },
74                ResNetLayer {
75                    blocks: 2,
76                    channels: 512,
77                    stride: 2,
78                },
79            ],
80            input_channels,
81            num_classes,
82            dropout_rate: 0.0,
83        }
84    }
85
86    /// Create a ResNet-34 configuration
87    pub fn resnet34(input_channels: usize, num_classes: usize) -> Self {
88        Self {
89            block: ResNetBlock::Basic,
90            layers: vec![
91                ResNetLayer {
92                    blocks: 3,
93                    channels: 64,
94                    stride: 1,
95                },
96                ResNetLayer {
97                    blocks: 4,
98                    channels: 128,
99                    stride: 2,
100                },
101                ResNetLayer {
102                    blocks: 6,
103                    channels: 256,
104                    stride: 2,
105                },
106                ResNetLayer {
107                    blocks: 3,
108                    channels: 512,
109                    stride: 2,
110                },
111            ],
112            input_channels,
113            num_classes,
114            dropout_rate: 0.0,
115        }
116    }
117
118    /// Create a ResNet-50 configuration
119    pub fn resnet50(input_channels: usize, num_classes: usize) -> Self {
120        Self {
121            block: ResNetBlock::Bottleneck,
122            layers: vec![
123                ResNetLayer {
124                    blocks: 3,
125                    channels: 64,
126                    stride: 1,
127                },
128                ResNetLayer {
129                    blocks: 4,
130                    channels: 128,
131                    stride: 2,
132                },
133                ResNetLayer {
134                    blocks: 6,
135                    channels: 256,
136                    stride: 2,
137                },
138                ResNetLayer {
139                    blocks: 3,
140                    channels: 512,
141                    stride: 2,
142                },
143            ],
144            input_channels,
145            num_classes,
146            dropout_rate: 0.0,
147        }
148    }
149
150    /// Create a ResNet-101 configuration
151    pub fn resnet101(input_channels: usize, num_classes: usize) -> Self {
152        Self {
153            block: ResNetBlock::Bottleneck,
154            layers: vec![
155                ResNetLayer {
156                    blocks: 3,
157                    channels: 64,
158                    stride: 1,
159                },
160                ResNetLayer {
161                    blocks: 4,
162                    channels: 128,
163                    stride: 2,
164                },
165                ResNetLayer {
166                    blocks: 23,
167                    channels: 256,
168                    stride: 2,
169                },
170                ResNetLayer {
171                    blocks: 3,
172                    channels: 512,
173                    stride: 2,
174                },
175            ],
176            input_channels,
177            num_classes,
178            dropout_rate: 0.0,
179        }
180    }
181
182    /// Create a ResNet-152 configuration
183    pub fn resnet152(input_channels: usize, num_classes: usize) -> Self {
184        Self {
185            block: ResNetBlock::Bottleneck,
186            layers: vec![
187                ResNetLayer {
188                    blocks: 3,
189                    channels: 64,
190                    stride: 1,
191                },
192                ResNetLayer {
193                    blocks: 8,
194                    channels: 128,
195                    stride: 2,
196                },
197                ResNetLayer {
198                    blocks: 36,
199                    channels: 256,
200                    stride: 2,
201                },
202                ResNetLayer {
203                    blocks: 3,
204                    channels: 512,
205                    stride: 2,
206                },
207            ],
208            input_channels,
209            num_classes,
210            dropout_rate: 0.0,
211        }
212    }
213
214    /// Set dropout rate
215    pub fn with_dropout(mut self, rate: f64) -> Self {
216        self.dropout_rate = rate;
217        self
218    }
219}
220
221/// Basic block for ResNet (2 conv layers)
222struct BasicBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
223    /// First convolutional layer
224    conv1: Conv2D<F>,
225    /// First batch normalization layer
226    bn1: BatchNorm<F>,
227    /// Second convolutional layer
228    conv2: Conv2D<F>,
229    /// Second batch normalization layer
230    bn2: BatchNorm<F>,
231    /// Skip connection downsample (optional)
232    downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
233}
234
235impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone for BasicBlock<F> {
236    fn clone(&self) -> Self {
237        Self {
238            conv1: self.conv1.clone(),
239            bn1: self.bn1.clone(),
240            conv2: self.conv2.clone(),
241            bn2: self.bn2.clone(),
242            downsample: self.downsample.clone(),
243        }
244    }
245}
246
247impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BasicBlock<F> {
248    /// Create a new basic block
249    pub fn new(
250        in_channels: usize,
251        out_channels: usize,
252        stride: usize,
253        downsample: bool,
254    ) -> Result<Self> {
255        let stride_tuple = (stride, stride);
256
257        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
258        let conv1 = Conv2D::new(in_channels, out_channels, (3, 3), stride_tuple, None)?
259            .with_padding(PaddingMode::Same);
260
261        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
262        let bn1 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng2)?;
263
264        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
265        let conv2 = Conv2D::new(out_channels, out_channels, (3, 3), (1, 1), None)?
266            .with_padding(PaddingMode::Same);
267
268        let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
269        let bn2 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng4)?;
270
271        let downsample = if downsample {
272            let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
273            let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
274                .with_padding(PaddingMode::Valid);
275
276            let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
277            let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng6)?;
278            Some((ds_conv, ds_bn))
279        } else {
280            None
281        };
282
283        Ok(Self {
284            conv1,
285            bn1,
286            conv2,
287            bn2,
288            downsample,
289        })
290    }
291}
292
293impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
294    for BasicBlock<F>
295{
296    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
297        // First conv block
298        let mut x = self.conv1.forward(input)?;
299        x = self.bn1.forward(&x)?;
300        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
301
302        // Second conv block
303        x = self.conv2.forward(&x)?;
304        x = self.bn2.forward(&x)?;
305
306        // Skip connection
307        let identity = if let Some((ref conv, ref bn)) = self.downsample {
308            let ds = conv.forward(input)?;
309            bn.forward(&ds)?
310        } else {
311            input.clone()
312        };
313
314        // Add skip connection
315        let x = &x + &identity;
316
317        // Final ReLU
318        let x = x.mapv(|v: F| v.max(F::zero()));
319
320        Ok(x)
321    }
322
323    fn backward(
324        &self,
325        _input: &Array<F, IxDyn>,
326        grad_output: &Array<F, IxDyn>,
327    ) -> Result<Array<F, IxDyn>> {
328        Ok(grad_output.clone())
329    }
330
331    fn update(&mut self, learning_rate: F) -> Result<()> {
332        self.conv1.update(learning_rate)?;
333        self.bn1.update(learning_rate)?;
334        self.conv2.update(learning_rate)?;
335        self.bn2.update(learning_rate)?;
336        if let Some((ref mut conv, ref mut bn)) = self.downsample {
337            conv.update(learning_rate)?;
338            bn.update(learning_rate)?;
339        }
340        Ok(())
341    }
342
343    fn as_any(&self) -> &dyn std::any::Any {
344        self
345    }
346
347    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
348        self
349    }
350}
351
352impl<
353        F: Float
354            + Debug
355            + ScalarOperand
356            + Send
357            + Sync
358            + NumAssign
359            + ToPrimitive
360            + FromPrimitive
361            + 'static,
362    > BasicBlock<F>
363{
364    /// Extract named parameters with a given prefix (HuggingFace-compatible naming)
365    pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
366        let mut result = Vec::new();
367        // conv1: weight, bias
368        for (i, p) in self.conv1.params().iter().enumerate() {
369            let suffix = if i == 0 { "weight" } else { "bias" };
370            result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
371        }
372        // bn1: weight (gamma), bias (beta)
373        for (i, p) in self.bn1.params().iter().enumerate() {
374            let suffix = if i == 0 { "weight" } else { "bias" };
375            result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
376        }
377        // conv2
378        for (i, p) in self.conv2.params().iter().enumerate() {
379            let suffix = if i == 0 { "weight" } else { "bias" };
380            result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
381        }
382        // bn2
383        for (i, p) in self.bn2.params().iter().enumerate() {
384            let suffix = if i == 0 { "weight" } else { "bias" };
385            result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
386        }
387        // downsample (optional)
388        if let Some((ref conv, ref bn)) = self.downsample {
389            for (i, p) in conv.params().iter().enumerate() {
390                let suffix = if i == 0 { "weight" } else { "bias" };
391                result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
392            }
393            for (i, p) in bn.params().iter().enumerate() {
394                let suffix = if i == 0 { "weight" } else { "bias" };
395                result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
396            }
397        }
398        result
399    }
400
401    /// Load named parameters from a map using the given prefix
402    pub(crate) fn load_named_params(
403        &mut self,
404        prefix: &str,
405        params_map: &HashMap<String, Array<F, IxDyn>>,
406    ) -> Result<()> {
407        // conv1
408        if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
409            let mut ps = vec![w.clone()];
410            if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
411                ps.push(b.clone());
412            }
413            self.conv1.set_params(&ps)?;
414        }
415        // bn1
416        if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
417            let mut ps = vec![w.clone()];
418            if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
419                ps.push(b.clone());
420            }
421            self.bn1.set_params(&ps)?;
422        }
423        // conv2
424        if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
425            let mut ps = vec![w.clone()];
426            if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
427                ps.push(b.clone());
428            }
429            self.conv2.set_params(&ps)?;
430        }
431        // bn2
432        if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
433            let mut ps = vec![w.clone()];
434            if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
435                ps.push(b.clone());
436            }
437            self.bn2.set_params(&ps)?;
438        }
439        // downsample
440        if let Some((ref mut conv, ref mut bn)) = self.downsample {
441            if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
442                let mut ps = vec![w.clone()];
443                if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
444                    ps.push(b.clone());
445                }
446                conv.set_params(&ps)?;
447            }
448            if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
449                let mut ps = vec![w.clone()];
450                if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
451                    ps.push(b.clone());
452                }
453                bn.set_params(&ps)?;
454            }
455        }
456        Ok(())
457    }
458}
459
460/// Bottleneck block for ResNet (3 conv layers)
461struct BottleneckBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
462    /// First convolutional layer (1x1 reduce)
463    conv1: Conv2D<F>,
464    /// First batch normalization layer
465    bn1: BatchNorm<F>,
466    /// Second convolutional layer (3x3)
467    conv2: Conv2D<F>,
468    /// Second batch normalization layer
469    bn2: BatchNorm<F>,
470    /// Third convolutional layer (1x1 expand)
471    conv3: Conv2D<F>,
472    /// Third batch normalization layer
473    bn3: BatchNorm<F>,
474    /// Skip connection downsample (optional)
475    downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
476    /// Expansion factor
477    #[allow(dead_code)]
478    expansion: usize,
479}
480
481impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone
482    for BottleneckBlock<F>
483{
484    fn clone(&self) -> Self {
485        Self {
486            conv1: self.conv1.clone(),
487            bn1: self.bn1.clone(),
488            conv2: self.conv2.clone(),
489            bn2: self.bn2.clone(),
490            conv3: self.conv3.clone(),
491            bn3: self.bn3.clone(),
492            downsample: self.downsample.clone(),
493            expansion: self.expansion,
494        }
495    }
496}
497
498impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BottleneckBlock<F> {
499    /// Expansion factor for bottleneck blocks
500    const EXPANSION: usize = 4;
501
502    /// Create a new bottleneck block
503    pub fn new(
504        in_channels: usize,
505        out_channels: usize,
506        stride: usize,
507        downsample: bool,
508    ) -> Result<Self> {
509        let bottleneck_channels = out_channels / Self::EXPANSION;
510        let stride_tuple = (stride, stride);
511
512        // First conv (1x1 reduce)
513        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
514        let conv1 = Conv2D::new(in_channels, bottleneck_channels, (1, 1), (1, 1), None)?
515            .with_padding(PaddingMode::Valid);
516
517        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
518        let bn1 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng2)?;
519
520        // Second conv (3x3)
521        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
522        let conv2 = Conv2D::new(
523            bottleneck_channels,
524            bottleneck_channels,
525            (3, 3),
526            stride_tuple,
527            None,
528        )?
529        .with_padding(PaddingMode::Same);
530
531        let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
532        let bn2 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng4)?;
533
534        // Third conv (1x1 expand)
535        let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
536        let conv3 = Conv2D::new(bottleneck_channels, out_channels, (1, 1), (1, 1), None)?
537            .with_padding(PaddingMode::Valid);
538
539        let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
540        let bn3 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng6)?;
541
542        // Downsample
543        let downsample = if downsample {
544            let mut rng7 = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
545            let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
546                .with_padding(PaddingMode::Valid);
547
548            let mut rng8 = scirs2_core::random::rngs::SmallRng::from_seed([55; 32]);
549            let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng8)?;
550            Some((ds_conv, ds_bn))
551        } else {
552            None
553        };
554
555        Ok(Self {
556            conv1,
557            bn1,
558            conv2,
559            bn2,
560            conv3,
561            bn3,
562            downsample,
563            expansion: Self::EXPANSION,
564        })
565    }
566}
567
568impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
569    for BottleneckBlock<F>
570{
571    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
572        // First conv block
573        let mut x = self.conv1.forward(input)?;
574        x = self.bn1.forward(&x)?;
575        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
576
577        // Second conv block
578        x = self.conv2.forward(&x)?;
579        x = self.bn2.forward(&x)?;
580        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
581
582        // Third conv block
583        x = self.conv3.forward(&x)?;
584        x = self.bn3.forward(&x)?;
585
586        // Skip connection
587        let identity = if let Some((ref conv, ref bn)) = self.downsample {
588            let ds = conv.forward(input)?;
589            bn.forward(&ds)?
590        } else {
591            input.clone()
592        };
593
594        // Add skip connection
595        let x = &x + &identity;
596
597        // Final ReLU
598        let x = x.mapv(|v: F| v.max(F::zero()));
599
600        Ok(x)
601    }
602
603    fn backward(
604        &self,
605        _input: &Array<F, IxDyn>,
606        grad_output: &Array<F, IxDyn>,
607    ) -> Result<Array<F, IxDyn>> {
608        Ok(grad_output.clone())
609    }
610
611    fn update(&mut self, learning_rate: F) -> Result<()> {
612        self.conv1.update(learning_rate)?;
613        self.bn1.update(learning_rate)?;
614        self.conv2.update(learning_rate)?;
615        self.bn2.update(learning_rate)?;
616        self.conv3.update(learning_rate)?;
617        self.bn3.update(learning_rate)?;
618        if let Some((ref mut conv, ref mut bn)) = self.downsample {
619            conv.update(learning_rate)?;
620            bn.update(learning_rate)?;
621        }
622        Ok(())
623    }
624
625    fn as_any(&self) -> &dyn std::any::Any {
626        self
627    }
628
629    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
630        self
631    }
632}
633
634impl<
635        F: Float
636            + Debug
637            + ScalarOperand
638            + Send
639            + Sync
640            + NumAssign
641            + ToPrimitive
642            + FromPrimitive
643            + 'static,
644    > BottleneckBlock<F>
645{
646    /// Extract named parameters with a given prefix (HuggingFace-compatible naming)
647    pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
648        let mut result = Vec::new();
649        for (i, p) in self.conv1.params().iter().enumerate() {
650            let suffix = if i == 0 { "weight" } else { "bias" };
651            result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
652        }
653        for (i, p) in self.bn1.params().iter().enumerate() {
654            let suffix = if i == 0 { "weight" } else { "bias" };
655            result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
656        }
657        for (i, p) in self.conv2.params().iter().enumerate() {
658            let suffix = if i == 0 { "weight" } else { "bias" };
659            result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
660        }
661        for (i, p) in self.bn2.params().iter().enumerate() {
662            let suffix = if i == 0 { "weight" } else { "bias" };
663            result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
664        }
665        for (i, p) in self.conv3.params().iter().enumerate() {
666            let suffix = if i == 0 { "weight" } else { "bias" };
667            result.push((format!("{prefix}.conv3.{suffix}"), p.clone()));
668        }
669        for (i, p) in self.bn3.params().iter().enumerate() {
670            let suffix = if i == 0 { "weight" } else { "bias" };
671            result.push((format!("{prefix}.bn3.{suffix}"), p.clone()));
672        }
673        if let Some((ref conv, ref bn)) = self.downsample {
674            for (i, p) in conv.params().iter().enumerate() {
675                let suffix = if i == 0 { "weight" } else { "bias" };
676                result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
677            }
678            for (i, p) in bn.params().iter().enumerate() {
679                let suffix = if i == 0 { "weight" } else { "bias" };
680                result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
681            }
682        }
683        result
684    }
685
686    /// Load named parameters from a map using the given prefix
687    pub(crate) fn load_named_params(
688        &mut self,
689        prefix: &str,
690        params_map: &HashMap<String, Array<F, IxDyn>>,
691    ) -> Result<()> {
692        if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
693            let mut ps = vec![w.clone()];
694            if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
695                ps.push(b.clone());
696            }
697            self.conv1.set_params(&ps)?;
698        }
699        if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
700            let mut ps = vec![w.clone()];
701            if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
702                ps.push(b.clone());
703            }
704            self.bn1.set_params(&ps)?;
705        }
706        if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
707            let mut ps = vec![w.clone()];
708            if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
709                ps.push(b.clone());
710            }
711            self.conv2.set_params(&ps)?;
712        }
713        if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
714            let mut ps = vec![w.clone()];
715            if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
716                ps.push(b.clone());
717            }
718            self.bn2.set_params(&ps)?;
719        }
720        if let Some(w) = params_map.get(&format!("{prefix}.conv3.weight")) {
721            let mut ps = vec![w.clone()];
722            if let Some(b) = params_map.get(&format!("{prefix}.conv3.bias")) {
723                ps.push(b.clone());
724            }
725            self.conv3.set_params(&ps)?;
726        }
727        if let Some(w) = params_map.get(&format!("{prefix}.bn3.weight")) {
728            let mut ps = vec![w.clone()];
729            if let Some(b) = params_map.get(&format!("{prefix}.bn3.bias")) {
730                ps.push(b.clone());
731            }
732            self.bn3.set_params(&ps)?;
733        }
734        if let Some((ref mut conv, ref mut bn)) = self.downsample {
735            if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
736                let mut ps = vec![w.clone()];
737                if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
738                    ps.push(b.clone());
739                }
740                conv.set_params(&ps)?;
741            }
742            if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
743                let mut ps = vec![w.clone()];
744                if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
745                    ps.push(b.clone());
746                }
747                bn.set_params(&ps)?;
748            }
749        }
750        Ok(())
751    }
752}
753
754/// ResNet implementation
755pub struct ResNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
756    /// Initial convolutional layer
757    conv1: Conv2D<F>,
758    /// Initial batch normalization
759    bn1: BatchNorm<F>,
760    /// ResNet layer groups
761    layer1: Vec<BasicBlock<F>>,
762    /// ResNet layer groups (bottleneck)
763    layer1_bottleneck: Vec<BottleneckBlock<F>>,
764    /// Fully connected layer
765    fc: Dense<F>,
766    /// Dropout layer
767    dropout: Option<Dropout<F>>,
768    /// Model configuration
769    config: ResNetConfig,
770}
771
772impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ResNet<F> {
773    /// Create a new ResNet model
774    pub fn new(config: ResNetConfig) -> Result<Self> {
775        // Initial convolution
776        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([56; 32]);
777        let conv1 = Conv2D::new(config.input_channels, 64, (7, 7), (2, 2), None)?
778            .with_padding(PaddingMode::Same);
779
780        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([57; 32]);
781        let bn1 = BatchNorm::new(64, 1e-5, 0.1, &mut rng2)?;
782
783        // For simplicity, create a single layer with blocks
784        let layer1 = Vec::new();
785        let layer1_bottleneck = Vec::new();
786
787        // Final FC layer
788        let fc_in_features = match config.block {
789            ResNetBlock::Basic => config.layers.last().map(|l| l.channels).unwrap_or(512),
790            ResNetBlock::Bottleneck => config.layers.last().map(|l| l.channels * 4).unwrap_or(2048),
791        };
792
793        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([58; 32]);
794        let fc = Dense::new(fc_in_features, config.num_classes, None, &mut rng3)?;
795
796        // Dropout
797        let dropout = if config.dropout_rate > 0.0 {
798            let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([59; 32]);
799            Some(Dropout::new(config.dropout_rate, &mut rng4)?)
800        } else {
801            None
802        };
803
804        Ok(Self {
805            conv1,
806            bn1,
807            layer1,
808            layer1_bottleneck,
809            fc,
810            dropout,
811            config,
812        })
813    }
814
815    /// Create a ResNet-18 model
816    pub fn resnet18(input_channels: usize, num_classes: usize) -> Result<Self> {
817        let config = ResNetConfig::resnet18(input_channels, num_classes);
818        Self::new(config)
819    }
820
821    /// Create a ResNet-34 model
822    pub fn resnet34(input_channels: usize, num_classes: usize) -> Result<Self> {
823        let config = ResNetConfig::resnet34(input_channels, num_classes);
824        Self::new(config)
825    }
826
827    /// Create a ResNet-50 model
828    pub fn resnet50(input_channels: usize, num_classes: usize) -> Result<Self> {
829        let config = ResNetConfig::resnet50(input_channels, num_classes);
830        Self::new(config)
831    }
832
833    /// Create a ResNet-101 model
834    pub fn resnet101(input_channels: usize, num_classes: usize) -> Result<Self> {
835        let config = ResNetConfig::resnet101(input_channels, num_classes);
836        Self::new(config)
837    }
838
839    /// Create a ResNet-152 model
840    pub fn resnet152(input_channels: usize, num_classes: usize) -> Result<Self> {
841        let config = ResNetConfig::resnet152(input_channels, num_classes);
842        Self::new(config)
843    }
844
845    /// Get the model configuration
846    pub fn config(&self) -> &ResNetConfig {
847        &self.config
848    }
849
850    /// Global average pooling
851    fn global_avg_pool(x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
852        let shape = x.shape();
853        if shape.len() != 4 {
854            return Err(NeuralError::InferenceError(format!(
855                "Expected 4D input for average pooling, got shape {:?}",
856                shape
857            )));
858        }
859
860        let batch_size = shape[0];
861        let channels = shape[1];
862        let height = shape[2];
863        let width = shape[3];
864
865        let mut output = Array::zeros(IxDyn(&[batch_size, channels]));
866        let count = F::from(height * width).expect("Failed to convert to float");
867
868        for b in 0..batch_size {
869            for c in 0..channels {
870                let mut sum = F::zero();
871                for h in 0..height {
872                    for w in 0..width {
873                        sum += x[[b, c, h, w]];
874                    }
875                }
876                output[[b, c]] = sum / count;
877            }
878        }
879
880        Ok(output)
881    }
882}
883
884impl<
885        F: Float
886            + Debug
887            + ScalarOperand
888            + Send
889            + Sync
890            + NumAssign
891            + ToPrimitive
892            + FromPrimitive
893            + 'static,
894    > ResNet<F>
895{
896    /// Extract all named parameters in HuggingFace-compatible format.
897    ///
898    /// Parameter names follow PyTorch/HuggingFace ResNet naming convention:
899    /// - `conv1.weight`, `bn1.weight`, `bn1.bias`
900    /// - `layer1.0.conv1.weight`, `layer1.0.bn1.weight`, etc.
901    /// - `layer2.0.conv1.weight`, etc. (for bottleneck blocks)
902    /// - `fc.weight`, `fc.bias`
903    pub fn extract_named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
904        let mut result = Vec::new();
905
906        // Initial conv and bn
907        for (i, p) in self.conv1.params().iter().enumerate() {
908            let suffix = if i == 0 { "weight" } else { "bias" };
909            result.push((format!("conv1.{suffix}"), p.clone()));
910        }
911        for (i, p) in self.bn1.params().iter().enumerate() {
912            let suffix = if i == 0 { "weight" } else { "bias" };
913            result.push((format!("bn1.{suffix}"), p.clone()));
914        }
915
916        // Residual blocks: basic or bottleneck
917        for (idx, block) in self.layer1.iter().enumerate() {
918            let block_params = block.extract_named_params(&format!("layer1.{idx}"));
919            result.extend(block_params);
920        }
921        for (idx, block) in self.layer1_bottleneck.iter().enumerate() {
922            let block_params = block.extract_named_params(&format!("layer1.{idx}"));
923            result.extend(block_params);
924        }
925
926        // FC (classifier head)
927        for (i, p) in self.fc.params().iter().enumerate() {
928            let suffix = if i == 0 { "weight" } else { "bias" };
929            result.push((format!("fc.{suffix}"), p.clone()));
930        }
931
932        Ok(result)
933    }
934
935    /// Load named parameters from a map (by name).
936    ///
937    /// Unknown parameter names are silently ignored, enabling graceful
938    /// forward/backward compatibility between model versions.
939    pub fn load_named_params(
940        &mut self,
941        params_map: &HashMap<String, Array<F, IxDyn>>,
942    ) -> Result<()> {
943        // conv1
944        if let Some(w) = params_map.get("conv1.weight") {
945            let mut ps = vec![w.clone()];
946            if let Some(b) = params_map.get("conv1.bias") {
947                ps.push(b.clone());
948            }
949            self.conv1.set_params(&ps)?;
950        }
951        // bn1
952        if let Some(w) = params_map.get("bn1.weight") {
953            let mut ps = vec![w.clone()];
954            if let Some(b) = params_map.get("bn1.bias") {
955                ps.push(b.clone());
956            }
957            self.bn1.set_params(&ps)?;
958        }
959
960        // Basic blocks
961        for (idx, block) in self.layer1.iter_mut().enumerate() {
962            block.load_named_params(&format!("layer1.{idx}"), params_map)?;
963        }
964        // Bottleneck blocks
965        for (idx, block) in self.layer1_bottleneck.iter_mut().enumerate() {
966            block.load_named_params(&format!("layer1.{idx}"), params_map)?;
967        }
968
969        // fc
970        if let Some(w) = params_map.get("fc.weight") {
971            let mut ps = vec![w.clone()];
972            if let Some(b) = params_map.get("fc.bias") {
973                ps.push(b.clone());
974            }
975            self.fc.set_params(&ps)?;
976        }
977
978        Ok(())
979    }
980}
981
982impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ResNet<F> {
983    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
984        // Initial conv
985        let mut x = self.conv1.forward(input)?;
986        x = self.bn1.forward(&x)?;
987        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
988
989        // Process through basic blocks
990        for block in &self.layer1 {
991            x = block.forward(&x)?;
992        }
993
994        // Process through bottleneck blocks
995        for block in &self.layer1_bottleneck {
996            x = block.forward(&x)?;
997        }
998
999        // Global average pooling
1000        x = Self::global_avg_pool(&x)?;
1001
1002        // Dropout
1003        if let Some(ref dropout) = self.dropout {
1004            x = dropout.forward(&x)?;
1005        }
1006
1007        // Final FC
1008        x = self.fc.forward(&x)?;
1009
1010        Ok(x)
1011    }
1012
1013    fn backward(
1014        &self,
1015        _input: &Array<F, IxDyn>,
1016        grad_output: &Array<F, IxDyn>,
1017    ) -> Result<Array<F, IxDyn>> {
1018        Ok(grad_output.clone())
1019    }
1020
1021    fn update(&mut self, learning_rate: F) -> Result<()> {
1022        self.conv1.update(learning_rate)?;
1023        self.bn1.update(learning_rate)?;
1024
1025        for block in &mut self.layer1 {
1026            block.update(learning_rate)?;
1027        }
1028
1029        for block in &mut self.layer1_bottleneck {
1030            block.update(learning_rate)?;
1031        }
1032
1033        self.fc.update(learning_rate)?;
1034
1035        Ok(())
1036    }
1037
1038    fn as_any(&self) -> &dyn std::any::Any {
1039        self
1040    }
1041
1042    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
1043        self
1044    }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;

    /// The ResNet-18 preset keeps the requested channel/class counts, has four
    /// stages and uses basic blocks.
    #[test]
    fn test_resnet_config_18() {
        let ResNetConfig {
            block,
            layers,
            input_channels,
            num_classes,
            ..
        } = ResNetConfig::resnet18(3, 1000);

        assert_eq!(input_channels, 3);
        assert_eq!(num_classes, 1000);
        assert_eq!(layers.len(), 4);
        assert!(matches!(block, ResNetBlock::Basic));
    }

    /// The ResNet-50 preset uses bottleneck blocks across four stages.
    #[test]
    fn test_resnet_config_50() {
        let ResNetConfig { block, layers, .. } = ResNetConfig::resnet50(3, 1000);

        assert!(matches!(block, ResNetBlock::Bottleneck));
        assert_eq!(layers.len(), 4);
    }

    /// `with_dropout` stores the requested rate on the configuration.
    #[test]
    fn test_resnet_config_with_dropout() {
        let base = ResNetConfig::resnet18(3, 100);
        let with_dropout = base.with_dropout(0.5);

        assert_eq!(with_dropout.dropout_rate, 0.5);
    }

    /// Spot-check the per-stage block counts of the deeper presets.
    #[test]
    fn test_resnet_config_variants() {
        // Per-stage block counts of a configuration, in stage order.
        let block_counts = |config: &ResNetConfig| -> Vec<usize> {
            config.layers.iter().map(|stage| stage.blocks).collect()
        };

        let counts34 = block_counts(&ResNetConfig::resnet34(3, 1000));
        assert_eq!(counts34[0], 3);
        assert_eq!(counts34[1], 4);

        let counts101 = block_counts(&ResNetConfig::resnet101(3, 1000));
        assert_eq!(counts101[2], 23);

        let counts152 = block_counts(&ResNetConfig::resnet152(3, 1000));
        assert_eq!(counts152[2], 36);
    }
}