use std::collections::HashMap;
use std::path::PathBuf;

use serde::{Deserialize, Serialize};

use yscv_autograd::Graph;
use yscv_tensor::Tensor;

use crate::{
    ModelError, SequentialModel, add_bottleneck_block, add_residual_block,
    build_resnet_feature_extractor, load_weights, save_weights,
};

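/// The model architectures this module knows how to construct.
///
/// A minimal usage sketch; the `yscv_models` crate path is an assumption and
/// should be adjusted to the actual crate name:
///
/// ```ignore
/// use yscv_models::ModelArchitecture;
///
/// let arch = ModelArchitecture::ResNet18;
/// assert_eq!(arch.name(), "resnet18");
/// assert_eq!(arch.config().blocks_per_stage, vec![2, 2, 2, 2]);
/// ```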
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelArchitecture {
    ResNet18,
    ResNet34,
    ResNet50,
    ResNet101,
    VGG16,
    VGG19,
    MobileNetV2,
    EfficientNetB0,
    AlexNet,
    ViTTiny,
    ViTBase,
    ViTLarge,
    DeiTTiny,
}

impl ModelArchitecture {
    pub fn config(&self) -> ArchitectureConfig {
        match self {
            Self::ResNet18 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512],
                blocks_per_stage: vec![2, 2, 2, 2],
            },
            Self::ResNet34 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512],
                blocks_per_stage: vec![3, 4, 6, 3],
            },
            Self::ResNet50 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512],
                blocks_per_stage: vec![3, 4, 6, 3],
            },
            Self::ResNet101 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512],
                blocks_per_stage: vec![3, 4, 23, 3],
            },
            Self::VGG16 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512, 512],
                blocks_per_stage: vec![2, 2, 3, 3, 3],
            },
            Self::VGG19 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 128, 256, 512, 512],
                blocks_per_stage: vec![2, 2, 4, 4, 4],
            },
            Self::MobileNetV2 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![32, 16, 24, 32, 64, 96, 160, 320],
                blocks_per_stage: vec![1, 1, 2, 3, 4, 3, 3, 1],
            },
            Self::EfficientNetB0 => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![32, 16, 24, 40, 80, 112, 192, 320],
                blocks_per_stage: vec![1, 1, 2, 2, 3, 3, 4, 1],
            },
            Self::AlexNet => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![64, 192, 384, 256, 256],
                blocks_per_stage: vec![1, 1, 1, 1, 1],
            },
            Self::ViTTiny => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![192],
                blocks_per_stage: vec![12],
            },
            Self::ViTBase => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![768],
                blocks_per_stage: vec![12],
            },
            Self::ViTLarge => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![1024],
                blocks_per_stage: vec![24],
            },
            Self::DeiTTiny => ArchitectureConfig {
                input_channels: 3,
                num_classes: 1000,
                stage_channels: vec![192],
                blocks_per_stage: vec![12],
            },
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            Self::ResNet18 => "resnet18",
            Self::ResNet34 => "resnet34",
            Self::ResNet50 => "resnet50",
            Self::ResNet101 => "resnet101",
            Self::VGG16 => "vgg16",
            Self::VGG19 => "vgg19",
            Self::MobileNetV2 => "mobilenet_v2",
            Self::EfficientNetB0 => "efficientnet_b0",
            Self::AlexNet => "alexnet",
            Self::ViTTiny => "vit_tiny",
            Self::ViTBase => "vit_base",
            Self::ViTLarge => "vit_large",
            Self::DeiTTiny => "deit_tiny",
        }
    }

    pub fn all() -> &'static [ModelArchitecture] {
        &[
            Self::ResNet18,
            Self::ResNet34,
            Self::ResNet50,
            Self::ResNet101,
            Self::VGG16,
            Self::VGG19,
            Self::MobileNetV2,
            Self::EfficientNetB0,
            Self::AlexNet,
            Self::ViTTiny,
            Self::ViTBase,
            Self::ViTLarge,
            Self::DeiTTiny,
        ]
    }
}

impl std::fmt::Display for ModelArchitecture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}

/// Hyperparameters that describe one architecture at construction time.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArchitectureConfig {
    /// Number of channels in the input image (3 for RGB).
    pub input_channels: usize,
    /// Number of output classes produced by the final linear layer.
    pub num_classes: usize,
    /// Output channel count of each stage, in order.
    pub stage_channels: Vec<usize>,
    /// Number of blocks in each stage, parallel to `stage_channels`.
    pub blocks_per_stage: Vec<usize>,
}

impl ArchitectureConfig {
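    /// Returns a copy of this config with `num_classes` swapped out.
    ///
    /// A minimal sketch; the asserted values match the presets defined in this
    /// file:
    ///
    /// ```ignore
    /// let cfg = ModelArchitecture::ResNet50.config().with_num_classes(10);
    /// assert_eq!(cfg.num_classes, 10);
    /// assert_eq!(cfg.blocks_per_stage, vec![3, 4, 6, 3]);
    /// ```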
    pub fn with_num_classes(&self, num_classes: usize) -> Self {
        let mut cfg = self.clone();
        cfg.num_classes = num_classes;
        cfg
    }
}

/// Epsilon added to the batch-norm variance for numerical stability.
const BN_EPSILON: f32 = 1e-5;

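/// Builds a ResNet-style classifier: a ResNet feature extractor followed by a
/// single linear classification head.
///
/// A minimal, non-compiled sketch of constructing a model from a preset config:
///
/// ```ignore
/// let mut graph = Graph::new();
/// let config = ModelArchitecture::ResNet18.config();
/// let model = build_resnet(&mut graph, &config)?;
/// ```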
pub fn build_resnet(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let max_blocks = config.blocks_per_stage.iter().copied().max().unwrap_or(2);
    build_resnet_feature_extractor(
        &mut model,
        config.input_channels,
        &config.stage_channels,
        max_blocks,
        BN_EPSILON,
    )?;
    let final_ch = config.stage_channels.last().copied().unwrap_or(512);
    model.add_linear_zero(graph, final_ch, config.num_classes)?;
    Ok(model)
}

pub fn build_resnet_custom(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let initial_ch = config.stage_channels.first().copied().unwrap_or(64);

    model.add_conv2d_zero(config.input_channels, initial_ch, 7, 7, 2, 2, true)?;
    model.add_batch_norm2d_identity(initial_ch, BN_EPSILON)?;
    model.add_relu();
    model.add_max_pool2d(3, 3, 2, 2)?;

    let mut ch = initial_ch;
    for (stage_idx, &stage_ch) in config.stage_channels.iter().enumerate() {
        if stage_ch != ch {
            model.add_conv2d_zero(ch, stage_ch, 1, 1, 1, 1, false)?;
            model.add_batch_norm2d_identity(stage_ch, BN_EPSILON)?;
            model.add_relu();
        }
        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
        for _ in 0..blocks {
            add_residual_block(&mut model, stage_ch, BN_EPSILON)?;
        }
        ch = stage_ch;
    }

    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, ch, config.num_classes)?;
    Ok(model)
}

pub fn build_vgg(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let mut ch = config.input_channels;

    for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate() {
        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
        for b in 0..blocks {
            let in_ch = if b == 0 { ch } else { out_ch };
            model.add_conv2d_zero(in_ch, out_ch, 3, 3, 1, 1, true)?;
            model.add_batch_norm2d_identity(out_ch, BN_EPSILON)?;
            model.add_relu();
        }
        model.add_max_pool2d(2, 2, 2, 2)?;
        ch = out_ch;
    }

    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, ch, config.num_classes)?;
    Ok(model)
}

pub fn build_mobilenet_v2(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let stem_ch = config.stage_channels.first().copied().unwrap_or(32);
    model.add_conv2d_zero(config.input_channels, stem_ch, 3, 3, 2, 2, false)?;
    model.add_batch_norm2d_identity(stem_ch, BN_EPSILON)?;
    model.add_relu();

    let mut ch = stem_ch;
    for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate().skip(1) {
        let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(1);
        let expand_ratio = 6;
        for b in 0..blocks {
            let stride = if b == 0 && stage_idx > 1 { 2 } else { 1 };
            let expand_ch = ch * expand_ratio;
            add_bottleneck_block(&mut model, ch, expand_ch, out_ch, stride, BN_EPSILON)?;
            ch = out_ch;
        }
    }

    let last_ch = 1280;
    model.add_conv2d_zero(ch, last_ch, 1, 1, 1, 1, false)?;
    model.add_batch_norm2d_identity(last_ch, BN_EPSILON)?;
    model.add_relu();
    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, last_ch, config.num_classes)?;
    Ok(model)
}

pub fn build_alexnet(
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    let mut model = SequentialModel::new(graph);
    let channels = &config.stage_channels;

    let ch0 = channels.first().copied().unwrap_or(64);
    model.add_conv2d_zero(config.input_channels, ch0, 11, 11, 4, 4, true)?;
    model.add_relu();
    model.add_max_pool2d(3, 3, 2, 2)?;

    let ch1 = channels.get(1).copied().unwrap_or(192);
    model.add_conv2d_zero(ch0, ch1, 5, 5, 1, 1, true)?;
    model.add_relu();
    model.add_max_pool2d(3, 3, 2, 2)?;

    let ch2 = channels.get(2).copied().unwrap_or(384);
    model.add_conv2d_zero(ch1, ch2, 3, 3, 1, 1, true)?;
    model.add_relu();

    let ch3 = channels.get(3).copied().unwrap_or(256);
    model.add_conv2d_zero(ch2, ch3, 3, 3, 1, 1, true)?;
    model.add_relu();

    let ch4 = channels.get(4).copied().unwrap_or(256);
    model.add_conv2d_zero(ch3, ch4, 3, 3, 1, 1, true)?;
    model.add_relu();
    model.add_max_pool2d(3, 3, 2, 2)?;

    model.add_global_avg_pool2d();
    model.add_flatten();
    model.add_linear_zero(graph, ch4, config.num_classes)?;
    Ok(model)
}

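/// A simple on-disk registry of pretrained weights, keyed by architecture name.
///
/// A minimal round-trip sketch; the `weights/` directory is an arbitrary
/// example path:
///
/// ```ignore
/// let zoo = ModelZoo::new("weights/");
/// let mut graph = Graph::new();
///
/// // Save a freshly built model, then list and reload it.
/// let model = build_classifier(ModelArchitecture::ResNet18, &mut graph, 1000)?;
/// zoo.save_pretrained(ModelArchitecture::ResNet18, &model, &graph)?;
/// assert!(zoo.list_available().contains(&ModelArchitecture::ResNet18));
/// let reloaded = zoo.load_pretrained(ModelArchitecture::ResNet18, &mut graph)?;
/// ```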
pub struct ModelZoo {
    registry_dir: PathBuf,
}

impl ModelZoo {
    pub fn new(registry_dir: impl Into<PathBuf>) -> Self {
        Self {
            registry_dir: registry_dir.into(),
        }
    }

    fn weight_path(&self, arch: ModelArchitecture) -> PathBuf {
        self.registry_dir.join(format!("{}.bin", arch.name()))
    }

    pub fn load_pretrained(
        &self,
        arch: ModelArchitecture,
        graph: &mut Graph,
    ) -> Result<SequentialModel, ModelError> {
        let path = self.weight_path(arch);
        let weights = load_weights(&path)?;
        let config = arch.config();
        let mut model = build_architecture(arch, graph, &config)?;
        apply_weights(&mut model, graph, &weights)?;
        Ok(model)
    }

    pub fn list_available(&self) -> Vec<ModelArchitecture> {
        ModelArchitecture::all()
            .iter()
            .copied()
            .filter(|a| self.weight_path(*a).is_file())
            .collect()
    }

    pub fn save_pretrained(
        &self,
        arch: ModelArchitecture,
        model: &SequentialModel,
        graph: &Graph,
    ) -> Result<(), ModelError> {
        let path = self.weight_path(arch);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| ModelError::DatasetLoadIo {
                path: parent.display().to_string(),
                message: e.to_string(),
            })?;
        }
        let tensors = collect_model_tensors(model, graph)?;
        save_weights(&path, &tensors)
    }
}

/// Gathers the model's parameters into a flat map keyed by
/// "layer.{idx}.{kind}.{param}" names, reading linear weights out of the graph.
fn collect_model_tensors(
    model: &SequentialModel,
    graph: &Graph,
) -> Result<HashMap<String, Tensor>, ModelError> {
    let mut tensors = HashMap::new();
    for (idx, layer) in model.layers().iter().enumerate() {
        match layer {
            crate::ModelLayer::Conv2d(l) => {
                tensors.insert(format!("layer.{idx}.conv2d.weight"), l.weight().clone());
                if let Some(b) = l.bias() {
                    tensors.insert(format!("layer.{idx}.conv2d.bias"), b.clone());
                }
            }
            crate::ModelLayer::BatchNorm2d(l) => {
                tensors.insert(format!("layer.{idx}.bn.gamma"), l.gamma().clone());
                tensors.insert(format!("layer.{idx}.bn.beta"), l.beta().clone());
                tensors.insert(
                    format!("layer.{idx}.bn.running_mean"),
                    l.running_mean().clone(),
                );
                tensors.insert(
                    format!("layer.{idx}.bn.running_var"),
                    l.running_var().clone(),
                );
            }
            crate::ModelLayer::Linear(l) => {
                let w = graph
                    .value(l.weight_node().expect("linear layer has weight node"))?
                    .clone();
                let b = graph
                    .value(l.bias_node().expect("linear layer has bias node"))?
                    .clone();
                tensors.insert(format!("layer.{idx}.linear.weight"), w);
                tensors.insert(format!("layer.{idx}.linear.bias"), b);
            }
            _ => {}
        }
    }
    Ok(tensors)
}

/// Copies matching tensors from `weights` into the model's layers, keyed by the
/// same "layer.{idx}.*" names produced by `collect_model_tensors`; entries with
/// no matching layer parameter are ignored.
fn apply_weights(
    model: &mut SequentialModel,
    graph: &mut Graph,
    weights: &HashMap<String, Tensor>,
) -> Result<(), ModelError> {
    for (idx, layer) in model.layers_mut().iter_mut().enumerate() {
        match layer {
            crate::ModelLayer::Conv2d(l) => {
                if let Some(w) = weights.get(&format!("layer.{idx}.conv2d.weight")) {
                    *l.weight_mut() = w.clone();
                }
                if let Some(b) = weights.get(&format!("layer.{idx}.conv2d.bias"))
                    && let Some(bias) = l.bias_mut()
                {
                    *bias = b.clone();
                }
            }
            crate::ModelLayer::BatchNorm2d(l) => {
                if let Some(g) = weights.get(&format!("layer.{idx}.bn.gamma")) {
                    *l.gamma_mut() = g.clone();
                }
                if let Some(b) = weights.get(&format!("layer.{idx}.bn.beta")) {
                    *l.beta_mut() = b.clone();
                }
                if let Some(m) = weights.get(&format!("layer.{idx}.bn.running_mean")) {
                    *l.running_mean_mut() = m.clone();
                }
                if let Some(v) = weights.get(&format!("layer.{idx}.bn.running_var")) {
                    *l.running_var_mut() = v.clone();
                }
            }
            crate::ModelLayer::Linear(l) => {
                if let Some(w) = weights.get(&format!("layer.{idx}.linear.weight")) {
                    *graph.value_mut(l.weight_node().expect("linear layer has weight node"))? =
                        w.clone();
                }
                if let Some(b) = weights.get(&format!("layer.{idx}.linear.bias")) {
                    *graph.value_mut(l.bias_node().expect("linear layer has bias node"))? =
                        b.clone();
                }
            }
            _ => {}
        }
    }
    Ok(())
}

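/// Builds the backbone of `arch` without a classification head, ending in a
/// flattened feature vector.
///
/// A minimal, non-compiled sketch:
///
/// ```ignore
/// let mut graph = Graph::new();
/// let config = ModelArchitecture::VGG16.config();
/// let features = build_feature_extractor(ModelArchitecture::VGG16, &mut graph, &config)?;
/// ```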
pub fn build_feature_extractor(
    arch: ModelArchitecture,
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    match arch {
        ModelArchitecture::ResNet18
        | ModelArchitecture::ResNet34
        | ModelArchitecture::ResNet50
        | ModelArchitecture::ResNet101 => {
            let mut model = SequentialModel::new(graph);
            let max_blocks = config.blocks_per_stage.iter().copied().max().unwrap_or(2);
            build_resnet_feature_extractor(
                &mut model,
                config.input_channels,
                &config.stage_channels,
                max_blocks,
                BN_EPSILON,
            )?;
            Ok(model)
        }
        ModelArchitecture::VGG16 | ModelArchitecture::VGG19 => {
            let mut model = SequentialModel::new(graph);
            let mut ch = config.input_channels;
            for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate() {
                let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(2);
                for b in 0..blocks {
                    let in_ch = if b == 0 { ch } else { out_ch };
                    model.add_conv2d_zero(in_ch, out_ch, 3, 3, 1, 1, true)?;
                    model.add_batch_norm2d_identity(out_ch, BN_EPSILON)?;
                    model.add_relu();
                }
                model.add_max_pool2d(2, 2, 2, 2)?;
                ch = out_ch;
            }
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::MobileNetV2 | ModelArchitecture::EfficientNetB0 => {
            let mut model = SequentialModel::new(graph);
            let stem_ch = config.stage_channels.first().copied().unwrap_or(32);
            model.add_conv2d_zero(config.input_channels, stem_ch, 3, 3, 2, 2, false)?;
            model.add_batch_norm2d_identity(stem_ch, BN_EPSILON)?;
            model.add_relu();

            let mut ch = stem_ch;
            for (stage_idx, &out_ch) in config.stage_channels.iter().enumerate().skip(1) {
                let blocks = config.blocks_per_stage.get(stage_idx).copied().unwrap_or(1);
                for b in 0..blocks {
                    let stride = if b == 0 && stage_idx > 1 { 2 } else { 1 };
                    let expand_ch = ch * 6;
                    add_bottleneck_block(&mut model, ch, expand_ch, out_ch, stride, BN_EPSILON)?;
                    ch = out_ch;
                }
            }
            let last_ch = 1280;
            model.add_conv2d_zero(ch, last_ch, 1, 1, 1, 1, false)?;
            model.add_batch_norm2d_identity(last_ch, BN_EPSILON)?;
            model.add_relu();
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::AlexNet => {
            let mut model = SequentialModel::new(graph);
            let channels = &config.stage_channels;
            let ch0 = channels.first().copied().unwrap_or(64);
            model.add_conv2d_zero(config.input_channels, ch0, 11, 11, 4, 4, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;

            let ch1 = channels.get(1).copied().unwrap_or(192);
            model.add_conv2d_zero(ch0, ch1, 5, 5, 1, 1, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;

            let ch2 = channels.get(2).copied().unwrap_or(384);
            model.add_conv2d_zero(ch1, ch2, 3, 3, 1, 1, true)?;
            model.add_relu();

            let ch3 = channels.get(3).copied().unwrap_or(256);
            model.add_conv2d_zero(ch2, ch3, 3, 3, 1, 1, true)?;
            model.add_relu();

            let ch4 = channels.get(4).copied().unwrap_or(256);
            model.add_conv2d_zero(ch3, ch4, 3, 3, 1, 1, true)?;
            model.add_relu();
            model.add_max_pool2d(3, 3, 2, 2)?;
            model.add_global_avg_pool2d();
            model.add_flatten();
            Ok(model)
        }
        ModelArchitecture::ViTTiny
        | ModelArchitecture::ViTBase
        | ModelArchitecture::ViTLarge
        | ModelArchitecture::DeiTTiny => {
            let embed_dim = config.stage_channels.first().copied().unwrap_or(192);
            let mut model = SequentialModel::new(graph);
            model.add_conv2d_zero(config.input_channels, embed_dim, 16, 16, 16, 16, false)?;
            model.add_flatten();
            Ok(model)
        }
    }
}

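/// Builds a full classifier for `arch` with a `num_classes`-way output head.
///
/// A minimal, non-compiled sketch of building a 10-class model:
///
/// ```ignore
/// let mut graph = Graph::new();
/// let model = build_classifier(ModelArchitecture::MobileNetV2, &mut graph, 10)?;
/// ```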
pub fn build_classifier(
    arch: ModelArchitecture,
    graph: &mut Graph,
    num_classes: usize,
) -> Result<SequentialModel, ModelError> {
    let config = arch.config().with_num_classes(num_classes);
    build_architecture(arch, graph, &config)
}

fn build_architecture(
    arch: ModelArchitecture,
    graph: &mut Graph,
    config: &ArchitectureConfig,
) -> Result<SequentialModel, ModelError> {
    match arch {
        ModelArchitecture::ResNet18
        | ModelArchitecture::ResNet34
        | ModelArchitecture::ResNet50
        | ModelArchitecture::ResNet101 => build_resnet_custom(graph, config),
        ModelArchitecture::VGG16 | ModelArchitecture::VGG19 => build_vgg(graph, config),
        ModelArchitecture::MobileNetV2 | ModelArchitecture::EfficientNetB0 => {
            build_mobilenet_v2(graph, config)
        }
        ModelArchitecture::AlexNet => build_alexnet(graph, config),
        ModelArchitecture::ViTTiny
        | ModelArchitecture::ViTBase
        | ModelArchitecture::ViTLarge
        | ModelArchitecture::DeiTTiny => {
            let embed_dim = config.stage_channels.first().copied().unwrap_or(192);
            let mut model = SequentialModel::new(graph);
            model.add_conv2d_zero(config.input_channels, embed_dim, 16, 16, 16, 16, false)?;
            model.add_flatten();
            model.add_linear_zero(graph, embed_dim, config.num_classes)?;
            Ok(model)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_pretrained_applies_weights() -> Result<(), Box<dyn std::error::Error>> {
        let mut graph = Graph::new();
        let arch = ModelArchitecture::AlexNet;
        let config = arch.config();
        let model = build_architecture(arch, &mut graph, &config)?;

        let mut tensors = collect_model_tensors(&model, &graph)?;

        assert!(!tensors.is_empty(), "model should have named tensors");

        // Overwrite every tensor with a recognizable constant before saving.
        for t in tensors.values_mut() {
            let len = t.data().len();
            *t = yscv_tensor::Tensor::from_vec(t.shape().to_vec(), vec![0.42_f32; len])?;
        }

        let tmp_dir = std::env::temp_dir().join("yscv_test_zoo");
        let zoo = ModelZoo::new(&tmp_dir);
        let path = zoo.weight_path(arch);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).ok();
        }
        save_weights(&path, &tensors)?;

        // Reload into a fresh graph and check the constants round-tripped.
        let mut graph2 = Graph::new();
        let loaded_model = zoo.load_pretrained(arch, &mut graph2)?;

        let loaded_tensors = collect_model_tensors(&loaded_model, &graph2)?;

        for (name, original) in &tensors {
            let loaded = loaded_tensors
                .get(name)
                .ok_or_else(|| ModelError::WeightNotFound { name: name.clone() })?;
            assert_eq!(
                original.shape(),
                loaded.shape(),
                "shape mismatch for {name}"
            );
            for (i, (&orig, &load)) in original.data().iter().zip(loaded.data().iter()).enumerate()
            {
                assert!(
                    (orig - load).abs() < 1e-6,
                    "value mismatch for {name}[{i}]: expected {orig}, got {load}"
                );
            }
        }

        std::fs::remove_file(&path).ok();
        std::fs::remove_dir(&tmp_dir).ok();
        Ok(())
    }
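
    // A minimal additional sketch: building a classifier head with a custom
    // class count. It only asserts that construction succeeds; the exact layer
    // layout is an implementation detail.
    #[test]
    fn test_build_classifier_custom_classes() -> Result<(), Box<dyn std::error::Error>> {
        let mut graph = Graph::new();
        let model = build_classifier(ModelArchitecture::ResNet18, &mut graph, 10)?;
        assert!(!model.layers().is_empty());
        Ok(())
    }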
}