// yscv_model/zoo.rs
1//! Pretrained model zoo: architecture registry, builders, and weight management.
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7
8use yscv_autograd::Graph;
9use yscv_tensor::Tensor;
10
11use crate::{
12    ModelError, SequentialModel, add_bottleneck_block, add_residual_block,
13    build_resnet_feature_extractor, load_weights, save_weights,
14};
15
16// ---------------------------------------------------------------------------
17// Architecture registry
18// ---------------------------------------------------------------------------
19
/// Known model architectures in the zoo.
///
/// Each variant maps to a canonical `ArchitectureConfig` via
/// [`ModelArchitecture::config`] and to a filesystem-safe weight-file name
/// via [`ModelArchitecture::name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelArchitecture {
    ResNet18,
    ResNet34,
    ResNet50,
    ResNet101,
    VGG16,
    VGG19,
    MobileNetV2,
    EfficientNetB0,
    AlexNet,
    // Transformer variants. NOTE(review): the current builders construct
    // these as patch-embedding stubs only (see `build_architecture`), not
    // full transformer stacks.
    ViTTiny,
    ViTBase,
    ViTLarge,
    DeiTTiny,
}
37
38impl ModelArchitecture {
39    /// Returns the canonical configuration for this architecture.
40    pub fn config(&self) -> ArchitectureConfig {
41        match self {
42            Self::ResNet18 => ArchitectureConfig {
43                input_channels: 3,
44                num_classes: 1000,
45                stage_channels: vec![64, 128, 256, 512],
46                blocks_per_stage: vec![2, 2, 2, 2],
47            },
48            Self::ResNet34 => ArchitectureConfig {
49                input_channels: 3,
50                num_classes: 1000,
51                stage_channels: vec![64, 128, 256, 512],
52                blocks_per_stage: vec![3, 4, 6, 3],
53            },
54            Self::ResNet50 => ArchitectureConfig {
55                input_channels: 3,
56                num_classes: 1000,
57                stage_channels: vec![64, 128, 256, 512],
58                blocks_per_stage: vec![3, 4, 6, 3],
59            },
60            Self::ResNet101 => ArchitectureConfig {
61                input_channels: 3,
62                num_classes: 1000,
63                stage_channels: vec![64, 128, 256, 512],
64                blocks_per_stage: vec![3, 4, 23, 3],
65            },
66            Self::VGG16 => ArchitectureConfig {
67                input_channels: 3,
68                num_classes: 1000,
69                stage_channels: vec![64, 128, 256, 512, 512],
70                blocks_per_stage: vec![2, 2, 3, 3, 3],
71            },
72            Self::VGG19 => ArchitectureConfig {
73                input_channels: 3,
74                num_classes: 1000,
75                stage_channels: vec![64, 128, 256, 512, 512],
76                blocks_per_stage: vec![2, 2, 4, 4, 4],
77            },
78            Self::MobileNetV2 => ArchitectureConfig {
79                input_channels: 3,
80                num_classes: 1000,
81                stage_channels: vec![32, 16, 24, 32, 64, 96, 160, 320],
82                blocks_per_stage: vec![1, 1, 2, 3, 4, 3, 3, 1],
83            },
84            Self::EfficientNetB0 => ArchitectureConfig {
85                input_channels: 3,
86                num_classes: 1000,
87                stage_channels: vec![32, 16, 24, 40, 80, 112, 192, 320],
88                blocks_per_stage: vec![1, 1, 2, 2, 3, 3, 4, 1],
89            },
90            Self::AlexNet => ArchitectureConfig {
91                input_channels: 3,
92                num_classes: 1000,
93                stage_channels: vec![64, 192, 384, 256, 256],
94                blocks_per_stage: vec![1, 1, 1, 1, 1],
95            },
96            Self::ViTTiny => ArchitectureConfig {
97                input_channels: 3,
98                num_classes: 1000,
99                stage_channels: vec![192],  // embed_dim
100                blocks_per_stage: vec![12], // num_layers
101            },
102            Self::ViTBase => ArchitectureConfig {
103                input_channels: 3,
104                num_classes: 1000,
105                stage_channels: vec![768],
106                blocks_per_stage: vec![12],
107            },
108            Self::ViTLarge => ArchitectureConfig {
109                input_channels: 3,
110                num_classes: 1000,
111                stage_channels: vec![1024],
112                blocks_per_stage: vec![24],
113            },
114            Self::DeiTTiny => ArchitectureConfig {
115                input_channels: 3,
116                num_classes: 1000,
117                stage_channels: vec![192],
118                blocks_per_stage: vec![12],
119            },
120        }
121    }
122
123    /// Returns a filesystem-safe name for this architecture (used for weight files).
124    pub fn name(&self) -> &'static str {
125        match self {
126            Self::ResNet18 => "resnet18",
127            Self::ResNet34 => "resnet34",
128            Self::ResNet50 => "resnet50",
129            Self::ResNet101 => "resnet101",
130            Self::VGG16 => "vgg16",
131            Self::VGG19 => "vgg19",
132            Self::MobileNetV2 => "mobilenet_v2",
133            Self::EfficientNetB0 => "efficientnet_b0",
134            Self::AlexNet => "alexnet",
135            Self::ViTTiny => "vit_tiny",
136            Self::ViTBase => "vit_base",
137            Self::ViTLarge => "vit_large",
138            Self::DeiTTiny => "deit_tiny",
139        }
140    }
141
142    /// All known architectures.
143    pub fn all() -> &'static [ModelArchitecture] {
144        &[
145            Self::ResNet18,
146            Self::ResNet34,
147            Self::ResNet50,
148            Self::ResNet101,
149            Self::VGG16,
150            Self::VGG19,
151            Self::MobileNetV2,
152            Self::EfficientNetB0,
153            Self::AlexNet,
154            Self::ViTTiny,
155            Self::ViTBase,
156            Self::ViTLarge,
157            Self::DeiTTiny,
158        ]
159    }
160}
161
162impl std::fmt::Display for ModelArchitecture {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.write_str(self.name())
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Architecture config
170// ---------------------------------------------------------------------------
171
/// Describes the shape of a model architecture.
///
/// For conv nets, `stage_channels[i]` and `blocks_per_stage[i]` describe the
/// width and depth of stage `i`. For ViT/DeiT variants the single-element
/// vectors hold the embed dim and transformer depth instead (see
/// `ModelArchitecture::config`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArchitectureConfig {
    /// Number of input channels (typically 3 for RGB).
    pub input_channels: usize,
    /// Number of output classes (default 1000 for ImageNet).
    pub num_classes: usize,
    /// Channel widths per stage.
    pub stage_channels: Vec<usize>,
    /// Block counts per stage.
    pub blocks_per_stage: Vec<usize>,
}
184
185impl ArchitectureConfig {
186    /// Returns a copy with a different number of output classes.
187    pub fn with_num_classes(&self, num_classes: usize) -> Self {
188        let mut cfg = self.clone();
189        cfg.num_classes = num_classes;
190        cfg
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Architecture builders
196// ---------------------------------------------------------------------------
197
/// Numerical-stability epsilon shared by every BatchNorm layer built in this module.
const BN_EPSILON: f32 = 1e-5;
199
/// Builds a ResNet-family model: stem + residual stages + global-avg-pool + linear head.
///
/// NOTE(review): `build_resnet_feature_extractor` accepts only a single
/// block count, so this uses the *maximum* of `blocks_per_stage` for every
/// stage. For ResNet-34/50/101 (e.g. [3, 4, 6, 3]) that over-builds the
/// shallower stages — use `build_resnet_custom` for faithful per-stage
/// depths. Confirm this approximation is intended.
pub fn build_resnet(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    // Single block count applied to all stages (see NOTE above).
    let max_blocks = config.blocks_per_stage.iter().copied().max().unwrap_or(2);
    build_resnet_feature_extractor(
        &mut model,
        config.input_channels,
        &config.stage_channels,
        max_blocks,
        BN_EPSILON,
    )?;
    // Classifier head: width of the last stage (defaults to 512 if the
    // config's stage list is empty) into `num_classes` logits.
    let final_ch = config.stage_channels.last().copied().unwrap_or(512);
    model.add_linear_zero(graph, final_ch, config.num_classes)?;
    Ok(model)
}
218
/// Builds a ResNet with per-stage block counts (bypasses the single-count helper).
///
/// Layer order here determines the `layer.{idx}.*` names produced by
/// `collect_model_tensors`, so reordering these calls would invalidate
/// previously saved weight files.
pub fn build_resnet_custom(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let initial_ch = config.stage_channels.first().copied().unwrap_or(64);

    // Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
    model.add_conv2d_zero(config.input_channels, initial_ch, 7, 7, 2, 2, true)?;
    model.add_batch_norm2d_identity(initial_ch, BN_EPSILON)?;
    model.add_relu();
    model.add_max_pool2d(3, 3, 2, 2)?;

    let mut ch = initial_ch;
    for (stage_idx, &stage_ch) in config.stage_channels.iter().enumerate() {
        if stage_ch != ch {
            // Width change: 1x1 projection into the new stage width.
            // NOTE(review): this projection uses stride 1, so unlike
            // reference ResNets there is no spatial downsampling between
            // stages — confirm this is intended.
            model.add_conv2d_zero(ch, stage_ch, 1, 1, 1, 1, false)?;
            model.add_batch_norm2d_identity(stage_ch, BN_EPSILON)?;
            model.add_relu();
        }
        // Missing entries in `blocks_per_stage` default to 2 blocks.
        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
        for _ in 0..blocks {
            add_residual_block(&mut model, stage_ch, BN_EPSILON)?;
        }
        ch = stage_ch;
    }

    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, ch, config.num_classes)?;
    Ok(model)
}
251
252/// Builds a VGG-style sequential conv network.
253///
254/// Each stage: `blocks_per_stage[i]` x (Conv3x3 -> BN -> ReLU), then MaxPool2x2.
255pub fn build_vgg(
256    graph: &mut Graph,
257    config: &ArchitectureConfig,
258) -> Result<SequentialModel, ModelError> {
259    let mut model = SequentialModel::new(graph);
260    let mut ch = config.input_channels;
261
262    for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate() {
263        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
264        for b in 0..blocks {
265            let in_ch = if b == 0 { ch } else { out_ch };
266            model.add_conv2d_zero(in_ch, out_ch, 3, 3, 1, 1, true)?;
267            model.add_batch_norm2d_identity(out_ch, BN_EPSILON)?;
268            model.add_relu();
269        }
270        model.add_max_pool2d(2, 2, 2, 2)?;
271        ch = out_ch;
272    }
273
274    model.add_global_avg_pool2d();
275    model.add_flatten();
276    model.add_linear_zero(graph, ch, config.num_classes)?;
277    Ok(model)
278}
279
/// Builds a MobileNetV2-style model using inverted bottleneck blocks.
///
/// NOTE(review): this is a simplified variant of the reference network.
/// Every block uses expansion ratio 6 (the reference uses 1 for the first
/// bottleneck stage), and the stride rule below (stride 2 on the first
/// block of every stage after the first) does not reproduce the reference
/// per-stage strides exactly — confirm against the intended topology.
pub fn build_mobilenet_v2(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    // stage_channels[0] is the stem width; the remaining entries are the
    // bottleneck stage widths (hence `.skip(1)` below).
    let stem_ch = config.stage_channels.first().copied().unwrap_or(32);
    model.add_conv2d_zero(config.input_channels, stem_ch, 3, 3, 2, 2, false)?;
    model.add_batch_norm2d_identity(stem_ch, BN_EPSILON)?;
    model.add_relu();

    let mut ch = stem_ch;
    for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate().skip(1) {
        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(1);
        let expand_ratio = 6;
        for b in 0..blocks {
            // Downsample on the first block of each stage after the first
            // bottleneck stage (see NOTE above).
            let stride = if b == 0 && stage_idx > 1 { 2 } else { 1 };
            let expand_ch = ch * expand_ratio;
            add_bottleneck_block(&mut model, ch, expand_ch, out_ch, stride, BN_EPSILON)?;
            ch = out_ch;
        }
    }

    // Final 1x1 expansion to 1280 channels before the classifier head.
    let last_ch = 1280;
    model.add_conv2d_zero(ch, last_ch, 1, 1, 1, 1, false)?;
    model.add_batch_norm2d_identity(last_ch, BN_EPSILON)?;
    model.add_relu();
    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, last_ch, config.num_classes)?;
    Ok(model)
}
312
313/// Builds a simple AlexNet-style conv stack.
314pub fn build_alexnet(
315    graph: &mut Graph,
316    config: &ArchitectureConfig,
317) -> Result<SequentialModel, ModelError> {
318    let mut model = SequentialModel::new(graph);
319    let channels = &config.stage_channels;
320
321    let ch0 = channels.first().copied().unwrap_or(64);
322    model.add_conv2d_zero(config.input_channels, ch0, 11, 11, 4, 4, true)?;
323    model.add_relu();
324    model.add_max_pool2d(3, 3, 2, 2)?;
325
326    let ch1 = channels.get(1).copied().unwrap_or(192);
327    model.add_conv2d_zero(ch0, ch1, 5, 5, 1, 1, true)?;
328    model.add_relu();
329    model.add_max_pool2d(3, 3, 2, 2)?;
330
331    let ch2 = channels.get(2).copied().unwrap_or(384);
332    model.add_conv2d_zero(ch1, ch2, 3, 3, 1, 1, true)?;
333    model.add_relu();
334
335    let ch3 = channels.get(3).copied().unwrap_or(256);
336    model.add_conv2d_zero(ch2, ch3, 3, 3, 1, 1, true)?;
337    model.add_relu();
338
339    let ch4 = channels.get(4).copied().unwrap_or(256);
340    model.add_conv2d_zero(ch3, ch4, 3, 3, 1, 1, true)?;
341    model.add_relu();
342    model.add_max_pool2d(3, 3, 2, 2)?;
343
344    model.add_global_avg_pool2d();
345    model.add_flatten();
346    model.add_linear_zero(graph, ch4, config.num_classes)?;
347    Ok(model)
348}
349
350// ---------------------------------------------------------------------------
351// Model Zoo (weight loading / saving)
352// ---------------------------------------------------------------------------
353
/// File-based pretrained model registry.
///
/// Points to a local directory that stores `{arch_name}.bin` weight files
/// in the same format as `save_weights` / `load_weights`.
pub struct ModelZoo {
    /// Root directory holding one `{arch_name}.bin` file per architecture.
    registry_dir: PathBuf,
}
361
impl ModelZoo {
    /// Creates a new zoo pointing at `registry_dir`.
    ///
    /// The directory is neither created nor validated here; it is touched
    /// lazily by `save_pretrained` / `load_pretrained`.
    pub fn new(registry_dir: impl Into<PathBuf>) -> Self {
        Self {
            registry_dir: registry_dir.into(),
        }
    }

    /// Weight-file path for `arch`: `{registry_dir}/{arch_name}.bin`.
    fn weight_path(&self, arch: ModelArchitecture) -> PathBuf {
        self.registry_dir.join(format!("{}.bin", arch.name()))
    }

    /// Builds the architecture and loads pretrained weights from
    /// `{registry_dir}/{arch_name}.bin`.
    ///
    /// The weight file is read *before* the model is built, so an I/O
    /// failure returns early without constructing any graph nodes. Tensors
    /// whose names are absent from the file are silently left at their
    /// initialized values (see `apply_weights`).
    pub fn load_pretrained(
        &self,
        arch: ModelArchitecture,
        graph: &mut Graph,
    ) -> Result<SequentialModel, ModelError> {
        let path = self.weight_path(arch);
        let weights = load_weights(&path)?;
        let config = arch.config();
        let mut model = build_architecture(arch, graph, &config)?;
        apply_weights(&mut model, graph, &weights)?;
        Ok(model)
    }

    /// Lists architectures for which a `.bin` weight file exists in the registry.
    pub fn list_available(&self) -> Vec<ModelArchitecture> {
        ModelArchitecture::all()
            .iter()
            .copied()
            .filter(|a| self.weight_path(*a).is_file())
            .collect()
    }

    /// Saves a model's layer weights to `{registry_dir}/{arch_name}.bin`,
    /// creating the registry directory (and parents) if needed.
    pub fn save_pretrained(
        &self,
        arch: ModelArchitecture,
        model: &SequentialModel,
        graph: &Graph,
    ) -> Result<(), ModelError> {
        let path = self.weight_path(arch);
        if let Some(parent) = path.parent() {
            // NOTE(review): reuses the `DatasetLoadIo` variant to report a
            // weight *save* failure — consider a dedicated I/O variant if
            // `ModelError` has one.
            std::fs::create_dir_all(parent).map_err(|e| ModelError::DatasetLoadIo {
                path: parent.display().to_string(),
                message: e.to_string(),
            })?;
        }
        let tensors = collect_model_tensors(model, graph)?;
        save_weights(&path, &tensors)
    }
}
416
417/// Collects all named tensors from a `SequentialModel` for serialization.
418fn collect_model_tensors(
419    model: &SequentialModel,
420    graph: &Graph,
421) -> Result<HashMap<String, yscv_tensor::Tensor>, ModelError> {
422    let mut tensors = HashMap::new();
423    for (idx, layer) in model.layers().iter().enumerate() {
424        match layer {
425            crate::ModelLayer::Conv2d(l) => {
426                tensors.insert(format!("layer.{idx}.conv2d.weight"), l.weight().clone());
427                if let Some(b) = l.bias() {
428                    tensors.insert(format!("layer.{idx}.conv2d.bias"), b.clone());
429                }
430            }
431            crate::ModelLayer::BatchNorm2d(l) => {
432                tensors.insert(format!("layer.{idx}.bn.gamma"), l.gamma().clone());
433                tensors.insert(format!("layer.{idx}.bn.beta"), l.beta().clone());
434                tensors.insert(
435                    format!("layer.{idx}.bn.running_mean"),
436                    l.running_mean().clone(),
437                );
438                tensors.insert(
439                    format!("layer.{idx}.bn.running_var"),
440                    l.running_var().clone(),
441                );
442            }
443            crate::ModelLayer::Linear(l) => {
444                let w = graph
445                    .value(l.weight_node().expect("linear layer has weight node"))?
446                    .clone();
447                let b = graph
448                    .value(l.bias_node().expect("linear layer has bias node"))?
449                    .clone();
450                tensors.insert(format!("layer.{idx}.linear.weight"), w);
451                tensors.insert(format!("layer.{idx}.linear.bias"), b);
452            }
453            _ => {}
454        }
455    }
456    Ok(tensors)
457}
458
459/// Apply named weight tensors to a SequentialModel's layers.
460/// Uses the same naming convention as `collect_model_tensors`:
461/// - `layer.{idx}.conv2d.weight`, `layer.{idx}.conv2d.bias`
462/// - `layer.{idx}.bn.gamma`, `layer.{idx}.bn.beta`, `layer.{idx}.bn.running_mean`, `layer.{idx}.bn.running_var`
463/// - `layer.{idx}.linear.weight`, `layer.{idx}.linear.bias`
464fn apply_weights(
465    model: &mut SequentialModel,
466    graph: &mut Graph,
467    weights: &HashMap<String, Tensor>,
468) -> Result<(), ModelError> {
469    for (idx, layer) in model.layers_mut().iter_mut().enumerate() {
470        match layer {
471            crate::ModelLayer::Conv2d(l) => {
472                if let Some(w) = weights.get(&format!("layer.{idx}.conv2d.weight")) {
473                    *l.weight_mut() = w.clone();
474                }
475                if let Some(b) = weights.get(&format!("layer.{idx}.conv2d.bias"))
476                    && let Some(bias) = l.bias_mut()
477                {
478                    *bias = b.clone();
479                }
480            }
481            crate::ModelLayer::BatchNorm2d(l) => {
482                if let Some(g) = weights.get(&format!("layer.{idx}.bn.gamma")) {
483                    *l.gamma_mut() = g.clone();
484                }
485                if let Some(b) = weights.get(&format!("layer.{idx}.bn.beta")) {
486                    *l.beta_mut() = b.clone();
487                }
488                if let Some(m) = weights.get(&format!("layer.{idx}.bn.running_mean")) {
489                    *l.running_mean_mut() = m.clone();
490                }
491                if let Some(v) = weights.get(&format!("layer.{idx}.bn.running_var")) {
492                    *l.running_var_mut() = v.clone();
493                }
494            }
495            crate::ModelLayer::Linear(l) => {
496                if let Some(w) = weights.get(&format!("layer.{idx}.linear.weight")) {
497                    *graph.value_mut(l.weight_node().expect("linear layer has weight node"))? =
498                        w.clone();
499                }
500                if let Some(b) = weights.get(&format!("layer.{idx}.linear.bias")) {
501                    *graph.value_mut(l.bias_node().expect("linear layer has bias node"))? =
502                        b.clone();
503                }
504            }
505            _ => {}
506        }
507    }
508    Ok(())
509}
510
511// ---------------------------------------------------------------------------
512// Feature extraction API
513// ---------------------------------------------------------------------------
514
/// Builds a backbone (feature extractor) without the final classifier head.
///
/// NOTE(review): the VGG / MobileNet / AlexNet arms duplicate the layer
/// sequences of `build_vgg` / `build_mobilenet_v2` / `build_alexnet` minus
/// the linear head. Keep them in sync with those builders — the serialized
/// weight names depend on layer order.
pub fn build_feature_extractor(
    arch: ModelArchitecture,
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    match arch {
        ModelArchitecture::ResNet18
        | ModelArchitecture::ResNet34
        | ModelArchitecture::ResNet50
        | ModelArchitecture::ResNet101 => {
            let mut model = SequentialModel::new(graph);
            // Same single-block-count approximation as `build_resnet`.
            let max_blocks = config.blocks_per_stage.iter().copied().max().unwrap_or(2);
            build_resnet_feature_extractor(
                &mut model,
                config.input_channels,
                &config.stage_channels,
                max_blocks,
                BN_EPSILON,
            )?;
            Ok(model)
        }
        ModelArchitecture::VGG16 | ModelArchitecture::VGG19 => {
            let mut model = SequentialModel::new(graph);
            let mut ch = config.input_channels;
            for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate() {
                let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
                for b in 0..blocks {
                    // Only the first conv of a stage changes channel width.
                    let in_ch = if b == 0 { ch } else { out_ch };
                    model.add_conv2d_zero(in_ch, out_ch, 3, 3, 1, 1, true)?;
                    model.add_batch_norm2d_identity(out_ch, BN_EPSILON)?;
                    model.add_relu();
                }
                model.add_max_pool2d(2, 2, 2, 2)?;
                ch = out_ch;
            }
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::MobileNetV2 | ModelArchitecture::EfficientNetB0 => {
            let mut model = SequentialModel::new(graph);
            // stage_channels[0] is the stem width; remaining entries are
            // bottleneck stage widths (hence `.skip(1)` below).
            let stem_ch = config.stage_channels.first().copied().unwrap_or(32);
            model.add_conv2d_zero(config.input_channels, stem_ch, 3, 3, 2, 2, false)?;
            model.add_batch_norm2d_identity(stem_ch, BN_EPSILON)?;
            model.add_relu();

            let mut ch = stem_ch;
            for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate().skip(1) {
                let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(1);
                for b in 0..blocks {
                    // Downsample on the first block of each stage after the
                    // first bottleneck stage (same rule as the full builder).
                    let stride = if b == 0 && stage_idx > 1 { 2 } else { 1 };
                    let expand_ch = ch * 6;
                    add_bottleneck_block(&mut model, ch, expand_ch, out_ch, stride, BN_EPSILON)?;
                    ch = out_ch;
                }
            }
            // Final 1x1 expansion to 1280 channels, as in the full builder.
            let last_ch = 1280;
            model.add_conv2d_zero(ch, last_ch, 1, 1, 1, 1, false)?;
            model.add_batch_norm2d_identity(last_ch, BN_EPSILON)?;
            model.add_relu();
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::AlexNet => {
            let mut model = SequentialModel::new(graph);
            let channels = &config.stage_channels;
            // Conv 1: 11x11 stride-4, then pool.
            let ch0 = channels.first().copied().unwrap_or(64);
            model.add_conv2d_zero(config.input_channels, ch0, 11, 11, 4, 4, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;

            // Conv 2: 5x5, then pool.
            let ch1 = channels.get(1).copied().unwrap_or(192);
            model.add_conv2d_zero(ch0, ch1, 5, 5, 1, 1, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;

            // Convs 3-5: 3x3, pooling only after the last.
            let ch2 = channels.get(2).copied().unwrap_or(384);
            model.add_conv2d_zero(ch1, ch2, 3, 3, 1, 1, true)?;
            model.add_relu();

            let ch3 = channels.get(3).copied().unwrap_or(256);
            model.add_conv2d_zero(ch2, ch3, 3, 3, 1, 1, true)?;
            model.add_relu();

            let ch4 = channels.get(4).copied().unwrap_or(256);
            model.add_conv2d_zero(ch3, ch4, 3, 3, 1, 1, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::ViTTiny
        | ModelArchitecture::ViTBase
        | ModelArchitecture::ViTLarge
        | ModelArchitecture::DeiTTiny => {
            // Patch-embedding stub: a 16x16/stride-16 conv produces
            // non-overlapping patch embeddings; no transformer blocks yet.
            // NOTE(review): unlike the other arms there is no pooling before
            // the flatten, so the feature width is embed_dim * num_patches —
            // confirm downstream consumers expect that.
            let embed_dim = config.stage_channels.first().copied().unwrap_or(192);
            let mut model = SequentialModel::new(graph);
            model.add_conv2d_zero(config.input_channels, embed_dim, 16, 16, 16, 16, false)?;
            model.add_flatten();
            Ok(model)
        }
    }
}
621
622/// Builds a full classifier with a custom number of output classes.
623pub fn build_classifier(
624    arch: ModelArchitecture,
625    graph: &mut Graph,
626    num_classes: usize,
627) -> Result<SequentialModel, ModelError> {
628    let config = arch.config().with_num_classes(num_classes);
629    build_architecture(arch, graph, &config)
630}
631
632/// Internal dispatcher: builds a complete model for any architecture.
633fn build_architecture(
634    arch: ModelArchitecture,
635    graph: &mut Graph,
636    config: &ArchitectureConfig,
637) -> Result<SequentialModel, ModelError> {
638    match arch {
639        ModelArchitecture::ResNet18
640        | ModelArchitecture::ResNet34
641        | ModelArchitecture::ResNet50
642        | ModelArchitecture::ResNet101 => build_resnet_custom(graph, config),
643        ModelArchitecture::VGG16 | ModelArchitecture::VGG19 => build_vgg(graph, config),
644        ModelArchitecture::MobileNetV2 | ModelArchitecture::EfficientNetB0 => {
645            build_mobilenet_v2(graph, config)
646        }
647        ModelArchitecture::AlexNet => build_alexnet(graph, config),
648        ModelArchitecture::ViTTiny
649        | ModelArchitecture::ViTBase
650        | ModelArchitecture::ViTLarge
651        | ModelArchitecture::DeiTTiny => {
652            let embed_dim = config.stage_channels.first().copied().unwrap_or(192);
653            let mut model = SequentialModel::new(graph);
654            model.add_conv2d_zero(config.input_channels, embed_dim, 16, 16, 16, 16, false)?;
655            model.add_flatten();
656            model.add_linear_zero(graph, embed_dim, config.num_classes)?;
657            Ok(model)
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip check: build -> fill with a sentinel -> save -> load via
    /// the zoo -> compare every named tensor against the saved values.
    #[test]
    fn test_load_pretrained_applies_weights() -> Result<(), Box<dyn std::error::Error>> {
        let mut graph = Graph::new();
        let arch = ModelArchitecture::AlexNet;
        let config = arch.config();
        let model = build_architecture(arch, &mut graph, &config)?;

        // Collect the freshly initialized tensors so we can overwrite them.
        let mut tensors = collect_model_tensors(&model, &graph)?;

        // Guard against a silent no-op: the model must expose named tensors.
        assert!(!tensors.is_empty(), "model should have named tensors");

        // Fill every tensor with the sentinel 0.42 (distinct from the
        // zero-init default) so applied weights are detectable.
        for t in tensors.values_mut() {
            let len = t.data().len();
            *t = yscv_tensor::Tensor::from_vec(t.shape().to_vec(), vec![0.42_f32; len])?;
        }

        // Save the sentinel weights into a temp-dir registry.
        // NOTE(review): the fixed "yscv_test_zoo" directory could collide if
        // tests run concurrently — consider a per-test unique dir.
        let tmp_dir = std::env::temp_dir().join("yscv_test_zoo");
        let zoo = ModelZoo::new(&tmp_dir);
        let path = zoo.weight_path(arch);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).ok();
        }
        save_weights(&path, &tensors)?;

        // Build a fresh (zero-init) model in a new graph and load the
        // sentinel weights through the zoo's public entry point.
        let mut graph2 = Graph::new();
        let loaded_model = zoo.load_pretrained(arch, &mut graph2)?;

        // Re-collect from the loaded model; values should now match the
        // saved sentinel tensors, proving `apply_weights` reached them.
        let loaded_tensors = collect_model_tensors(&loaded_model, &graph2)?;

        for (name, original) in &tensors {
            let loaded = loaded_tensors
                .get(name)
                .ok_or_else(|| ModelError::WeightNotFound { name: name.clone() })?;
            assert_eq!(
                original.shape(),
                loaded.shape(),
                "shape mismatch for {name}"
            );
            // Element-wise: loaded values must equal the saved 0.42 sentinel
            // (equality with the saved tensor, within float tolerance).
            for (i, (&orig, &load)) in original.data().iter().zip(loaded.data().iter()).enumerate()
            {
                assert!(
                    (orig - load).abs() < 1e-6,
                    "value mismatch for {name}[{i}]: expected {orig}, got {load}"
                );
            }
        }

        // Best-effort cleanup; failures here should not fail the test.
        std::fs::remove_file(&path).ok();
        std::fs::remove_dir(&tmp_dir).ok();
        Ok(())
    }
}
725}