Skip to main content

scirs2_neural/serialization/
architecture.rs

1//! Architecture-specific serialization implementations
2//!
3//! This module implements `ModelSerialize`, `ModelDeserialize`, and `ExtractParameters`
4//! for all supported neural network architectures including:
5//! - Sequential (existing, now uses the new trait interface)
6//! - ResNet (all variants)
7//! - BERT
8//! - GPT
9//! - EfficientNet
10//! - MobileNet
11//! - Mamba
12//!
13//! Each architecture's serialization handles nested layers, attention heads,
14//! normalization parameters, and architecture-specific configuration.
15
16use crate::error::{NeuralError, Result};
17use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer, LayerNorm, LSTM};
18use crate::models::architectures::{
19    BertConfig, BertModel, EfficientNet, EfficientNetConfig, GPTConfig, GPTModel, Mamba,
20    MambaConfig, MobileNet, MobileNetConfig, MobileNetVersion, ResNet, ResNetBlock, ResNetConfig,
21    ResNetLayer,
22};
23use crate::models::sequential::Sequential;
24use crate::serialization::safetensors::{SafeTensorsReader, SafeTensorsWriter};
25use crate::serialization::traits::{
26    ExtractParameters, ModelDeserialize, ModelFormat, ModelMetadata, ModelSerialize,
27    NamedParameters,
28};
29use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
30use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
31use scirs2_core::random::SeedableRng;
32use scirs2_core::simd_ops::SimdUnifiedOps;
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::fmt::{Debug, Display};
36use std::fs;
37use std::path::Path;
38
39// ============================================================================
40// Architecture configuration types for JSON serialization
41// ============================================================================
42
43/// Serialized architecture configuration envelope
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ArchitectureConfig {
46    /// Architecture type identifier
47    pub architecture: String,
48    /// Version of the serialization format
49    pub format_version: String,
50    /// Architecture-specific configuration as JSON value
51    pub config: serde_json::Value,
52}
53
54/// Serializable ResNet configuration
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SerializableResNetConfig {
57    /// Block type ("Basic" or "Bottleneck")
58    pub block: String,
59    /// Layer definitions
60    pub layers: Vec<SerializableResNetLayer>,
61    /// Number of input channels
62    pub input_channels: usize,
63    /// Number of output classes
64    pub num_classes: usize,
65    /// Dropout rate
66    pub dropout_rate: f64,
67}
68
69/// Serializable ResNet layer
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SerializableResNetLayer {
72    /// Number of blocks
73    pub blocks: usize,
74    /// Number of channels
75    pub channels: usize,
76    /// Stride
77    pub stride: usize,
78}
79
80impl From<&ResNetConfig> for SerializableResNetConfig {
81    fn from(config: &ResNetConfig) -> Self {
82        Self {
83            block: match config.block {
84                ResNetBlock::Basic => "Basic".to_string(),
85                ResNetBlock::Bottleneck => "Bottleneck".to_string(),
86            },
87            layers: config
88                .layers
89                .iter()
90                .map(|l| SerializableResNetLayer {
91                    blocks: l.blocks,
92                    channels: l.channels,
93                    stride: l.stride,
94                })
95                .collect(),
96            input_channels: config.input_channels,
97            num_classes: config.num_classes,
98            dropout_rate: config.dropout_rate,
99        }
100    }
101}
102
103impl SerializableResNetConfig {
104    /// Convert back to a ResNetConfig
105    pub fn to_resnet_config(&self) -> Result<ResNetConfig> {
106        let block = match self.block.as_str() {
107            "Basic" => ResNetBlock::Basic,
108            "Bottleneck" => ResNetBlock::Bottleneck,
109            other => {
110                return Err(NeuralError::DeserializationError(format!(
111                    "Unknown ResNet block type: {other}"
112                )))
113            }
114        };
115
116        Ok(ResNetConfig {
117            block,
118            layers: self
119                .layers
120                .iter()
121                .map(|l| ResNetLayer {
122                    blocks: l.blocks,
123                    channels: l.channels,
124                    stride: l.stride,
125                })
126                .collect(),
127            input_channels: self.input_channels,
128            num_classes: self.num_classes,
129            dropout_rate: self.dropout_rate,
130        })
131    }
132}
133
134/// Serializable BERT configuration
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct SerializableBertConfig {
137    pub vocab_size: usize,
138    pub max_position_embeddings: usize,
139    pub hidden_size: usize,
140    pub num_hidden_layers: usize,
141    pub num_attention_heads: usize,
142    pub intermediate_size: usize,
143    pub hidden_act: String,
144    pub hidden_dropout_prob: f64,
145    pub attention_probs_dropout_prob: f64,
146    pub type_vocab_size: usize,
147    pub layer_norm_eps: f64,
148    pub initializer_range: f64,
149}
150
151impl From<&BertConfig> for SerializableBertConfig {
152    fn from(config: &BertConfig) -> Self {
153        Self {
154            vocab_size: config.vocab_size,
155            max_position_embeddings: config.max_position_embeddings,
156            hidden_size: config.hidden_size,
157            num_hidden_layers: config.num_hidden_layers,
158            num_attention_heads: config.num_attention_heads,
159            intermediate_size: config.intermediate_size,
160            hidden_act: config.hidden_act.clone(),
161            hidden_dropout_prob: config.hidden_dropout_prob,
162            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
163            type_vocab_size: config.type_vocab_size,
164            layer_norm_eps: config.layer_norm_eps,
165            initializer_range: config.initializer_range,
166        }
167    }
168}
169
170impl SerializableBertConfig {
171    /// Convert to a BertConfig
172    pub fn to_bert_config(&self) -> BertConfig {
173        BertConfig {
174            vocab_size: self.vocab_size,
175            max_position_embeddings: self.max_position_embeddings,
176            hidden_size: self.hidden_size,
177            num_hidden_layers: self.num_hidden_layers,
178            num_attention_heads: self.num_attention_heads,
179            intermediate_size: self.intermediate_size,
180            hidden_act: self.hidden_act.clone(),
181            hidden_dropout_prob: self.hidden_dropout_prob,
182            attention_probs_dropout_prob: self.attention_probs_dropout_prob,
183            type_vocab_size: self.type_vocab_size,
184            layer_norm_eps: self.layer_norm_eps,
185            initializer_range: self.initializer_range,
186        }
187    }
188}
189
190/// Serializable GPT configuration
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct SerializableGPTConfig {
193    pub vocab_size: usize,
194    pub max_position_embeddings: usize,
195    pub hidden_size: usize,
196    pub num_hidden_layers: usize,
197    pub num_attention_heads: usize,
198    pub intermediate_size: usize,
199    pub hidden_act: String,
200    pub hidden_dropout_prob: f64,
201    pub attention_probs_dropout_prob: f64,
202    pub layer_norm_eps: f64,
203    pub initializer_range: f64,
204}
205
206impl From<&GPTConfig> for SerializableGPTConfig {
207    fn from(config: &GPTConfig) -> Self {
208        Self {
209            vocab_size: config.vocab_size,
210            max_position_embeddings: config.max_position_embeddings,
211            hidden_size: config.hidden_size,
212            num_hidden_layers: config.num_hidden_layers,
213            num_attention_heads: config.num_attention_heads,
214            intermediate_size: config.intermediate_size,
215            hidden_act: config.hidden_act.clone(),
216            hidden_dropout_prob: config.hidden_dropout_prob,
217            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
218            layer_norm_eps: config.layer_norm_eps,
219            initializer_range: config.initializer_range,
220        }
221    }
222}
223
224impl SerializableGPTConfig {
225    /// Convert to a GPTConfig
226    pub fn to_gpt_config(&self) -> GPTConfig {
227        GPTConfig {
228            vocab_size: self.vocab_size,
229            max_position_embeddings: self.max_position_embeddings,
230            hidden_size: self.hidden_size,
231            num_hidden_layers: self.num_hidden_layers,
232            num_attention_heads: self.num_attention_heads,
233            intermediate_size: self.intermediate_size,
234            hidden_act: self.hidden_act.clone(),
235            hidden_dropout_prob: self.hidden_dropout_prob,
236            attention_probs_dropout_prob: self.attention_probs_dropout_prob,
237            layer_norm_eps: self.layer_norm_eps,
238            initializer_range: self.initializer_range,
239        }
240    }
241}
242
243/// Serializable Mamba configuration
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct SerializableMambaConfig {
246    pub d_model: usize,
247    pub d_state: usize,
248    pub d_conv: usize,
249    pub expand: usize,
250    pub n_layers: usize,
251    pub dropout_prob: f64,
252    pub vocab_size: Option<usize>,
253    pub num_classes: Option<usize>,
254    pub dt_rank: Option<usize>,
255    pub bias: bool,
256    pub dt_min: f64,
257    pub dt_max: f64,
258}
259
260impl From<&MambaConfig> for SerializableMambaConfig {
261    fn from(config: &MambaConfig) -> Self {
262        Self {
263            d_model: config.d_model,
264            d_state: config.d_state,
265            d_conv: config.d_conv,
266            expand: config.expand,
267            n_layers: config.n_layers,
268            dropout_prob: config.dropout_prob,
269            vocab_size: config.vocab_size,
270            num_classes: config.num_classes,
271            dt_rank: config.dt_rank,
272            bias: config.bias,
273            dt_min: config.dt_min,
274            dt_max: config.dt_max,
275        }
276    }
277}
278
279impl SerializableMambaConfig {
280    /// Convert to a MambaConfig
281    pub fn to_mamba_config(&self) -> MambaConfig {
282        MambaConfig {
283            d_model: self.d_model,
284            d_state: self.d_state,
285            d_conv: self.d_conv,
286            expand: self.expand,
287            n_layers: self.n_layers,
288            dropout_prob: self.dropout_prob,
289            vocab_size: self.vocab_size,
290            num_classes: self.num_classes,
291            dt_rank: self.dt_rank,
292            bias: self.bias,
293            dt_min: self.dt_min,
294            dt_max: self.dt_max,
295        }
296    }
297}
298
299/// Serializable EfficientNet configuration (simplified)
300#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct SerializableEfficientNetConfig {
302    pub width_coefficient: f64,
303    pub depth_coefficient: f64,
304    pub resolution: usize,
305    pub dropout_rate: f64,
306    pub input_channels: usize,
307    pub num_classes: usize,
308}
309
310/// Serializable MobileNet configuration (simplified)
311#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct SerializableMobileNetConfig {
313    pub version: String,
314    pub width_multiplier: f64,
315    pub resolution_multiplier: f64,
316    pub dropout_rate: f64,
317    pub input_channels: usize,
318    pub num_classes: usize,
319}
320
321// ============================================================================
322// Helper: Extract parameters from a Layer trait object
323// ============================================================================
324
325/// Extract named parameters from a single Layer, given a prefix
326fn extract_layer_params<F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive>(
327    layer: &dyn Layer<F>,
328    prefix: &str,
329) -> Result<NamedParameters> {
330    let mut named = NamedParameters::new();
331    let params = layer.params();
332
333    if params.is_empty() {
334        return Ok(named);
335    }
336
337    // Dense layers typically have [weights, bias]
338    // Conv2D layers typically have [weights, bias]
339    // BatchNorm layers have [gamma, beta, running_mean, running_var]
340    // LayerNorm layers have [gamma, beta]
341    for (i, param) in params.iter().enumerate() {
342        let param_name = match i {
343            0 => format!("{prefix}.weight"),
344            1 => format!("{prefix}.bias"),
345            2 => format!("{prefix}.running_mean"),
346            3 => format!("{prefix}.running_var"),
347            n => format!("{prefix}.param_{n}"),
348        };
349
350        let shape: Vec<usize> = param.shape().to_vec();
351        let values: Vec<f64> = param
352            .iter()
353            .map(|&x| {
354                x.to_f64().ok_or_else(|| {
355                    NeuralError::SerializationError("Cannot convert parameter to f64".to_string())
356                })
357            })
358            .collect::<Result<Vec<f64>>>()?;
359
360        named.add(&param_name, values, shape);
361    }
362
363    Ok(named)
364}
365
366// ============================================================================
367// Sequential model serialization
368// ============================================================================
369
370/// Serializable Sequential layer config
371#[derive(Debug, Clone, Serialize, Deserialize)]
372pub struct SerializableSequentialConfig {
373    /// Layer descriptions in order
374    pub layers: Vec<SerializableLayerInfo>,
375}
376
377/// Info about a single layer in a Sequential model
378#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct SerializableLayerInfo {
380    /// Layer type name
381    pub layer_type: String,
382    /// Layer index
383    pub index: usize,
384    /// Additional config as JSON
385    #[serde(default)]
386    pub config: serde_json::Value,
387}
388
389impl<F> ExtractParameters for Sequential<F>
390where
391    F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + ToPrimitive + 'static,
392{
393    fn extract_named_parameters(&self) -> Result<NamedParameters> {
394        let mut all_params = NamedParameters::new();
395
396        for (i, layer) in self.layers().iter().enumerate() {
397            let prefix = format!("layers.{i}");
398            let layer_params = extract_layer_params(layer.as_ref(), &prefix)?;
399            for (name, values, shape) in layer_params.parameters {
400                all_params.add(&name, values, shape);
401            }
402        }
403
404        Ok(all_params)
405    }
406
407    fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
408        // We need mutable access to layers to set parameters
409        // For each layer, collect the parameters that match its prefix
410        let num_layers = self.layers().len();
411
412        for i in 0..num_layers {
413            let prefix = format!("layers.{i}");
414            let mut layer_param_arrays: Vec<Array<F, IxDyn>> = Vec::new();
415
416            // Collect parameters for this layer in order (weight first, then bias, etc.)
417            let mut matching: Vec<&(String, Vec<f64>, Vec<usize>)> = params
418                .parameters
419                .iter()
420                .filter(|(name, _, _)| name.starts_with(&prefix))
421                .collect();
422            matching.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
423
424            for (_, values, shape) in &matching {
425                let f_vec: Vec<F> = values
426                    .iter()
427                    .map(|&x| {
428                        F::from(x).ok_or_else(|| {
429                            NeuralError::DeserializationError(format!(
430                                "Cannot convert {x} to target type"
431                            ))
432                        })
433                    })
434                    .collect::<Result<Vec<F>>>()?;
435                let arr = Array::from_shape_vec(IxDyn(shape), f_vec)?;
436                layer_param_arrays.push(arr);
437            }
438
439            if !layer_param_arrays.is_empty() {
440                // Use set_params from the Layer trait
441                self.layers_mut()[i].set_params(&layer_param_arrays)?;
442            }
443        }
444
445        Ok(())
446    }
447}
448
449impl<F> ModelSerialize for Sequential<F>
450where
451    F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + ToPrimitive + 'static,
452{
453    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
454        let bytes = self.to_bytes(format)?;
455        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
456        Ok(())
457    }
458
459    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
460        // Build architecture config
461        let mut layers_info = Vec::new();
462        for (i, layer) in self.layers().iter().enumerate() {
463            layers_info.push(SerializableLayerInfo {
464                layer_type: layer.layer_type().to_string(),
465                index: i,
466                config: serde_json::Value::Object(serde_json::Map::new()),
467            });
468        }
469
470        let seq_config = SerializableSequentialConfig {
471            layers: layers_info,
472        };
473
474        let arch_config = ArchitectureConfig {
475            architecture: "Sequential".to_string(),
476            format_version: "1.0".to_string(),
477            config: serde_json::to_value(&seq_config)
478                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
479        };
480
481        let params = self.extract_named_parameters()?;
482
483        match format {
484            ModelFormat::Json => {
485                let mut result = HashMap::new();
486                result.insert(
487                    "architecture",
488                    serde_json::to_value(&arch_config)
489                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
490                );
491
492                // For JSON, embed parameters as arrays
493                let params_value: Vec<serde_json::Value> = params
494                    .parameters
495                    .iter()
496                    .map(|(name, values, shape)| {
497                        serde_json::json!({
498                            "name": name,
499                            "shape": shape,
500                            "data": values,
501                        })
502                    })
503                    .collect();
504                result.insert("parameters", serde_json::Value::Array(params_value));
505
506                serde_json::to_vec_pretty(&result)
507                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
508            }
509            ModelFormat::SafeTensors => {
510                let metadata = ModelMetadata::new("Sequential", "f64", params.total_parameters())
511                    .with_extra(
512                        "architecture_config",
513                        &serde_json::to_string(&arch_config)
514                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
515                    );
516
517                let mut writer = SafeTensorsWriter::new();
518                writer.add_model_metadata(&metadata);
519                writer.add_named_parameters(&params)?;
520                writer.to_bytes()
521            }
522            ModelFormat::Cbor | ModelFormat::MessagePack => {
523                // Fall back to JSON for unsupported formats in this trait impl
524                self.to_bytes(ModelFormat::Json)
525            }
526        }
527    }
528
529    fn architecture_name(&self) -> &str {
530        "Sequential"
531    }
532}
533
534// ============================================================================
535// ResNet serialization
536// ============================================================================
537
538impl<F> ExtractParameters for ResNet<F>
539where
540    F: Float
541        + Debug
542        + ScalarOperand
543        + NumAssign
544        + ToPrimitive
545        + FromPrimitive
546        + Send
547        + Sync
548        + 'static,
549{
550    fn extract_named_parameters(&self) -> Result<NamedParameters> {
551        let mut all_params = NamedParameters::new();
552        let named = self.extract_named_params()?;
553
554        for (name, param) in named {
555            let shape: Vec<usize> = param.shape().to_vec();
556            let values: Vec<f64> = param
557                .iter()
558                .map(|&x| {
559                    x.to_f64().ok_or_else(|| {
560                        NeuralError::SerializationError(
561                            "Cannot convert parameter to f64".to_string(),
562                        )
563                    })
564                })
565                .collect::<Result<Vec<f64>>>()?;
566            all_params.add(&name, values, shape);
567        }
568
569        Ok(all_params)
570    }
571
572    fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
573        let mut params_map = HashMap::new();
574        for (name, values, shape) in &params.parameters {
575            let f_values: Vec<F> = values
576                .iter()
577                .map(|&x| {
578                    F::from(x).ok_or_else(|| {
579                        NeuralError::DeserializationError(format!(
580                            "Cannot convert {x} to target type"
581                        ))
582                    })
583                })
584                .collect::<Result<Vec<F>>>()?;
585            let arr = Array::from_shape_vec(IxDyn(shape), f_values)?;
586            params_map.insert(name.clone(), arr);
587        }
588        self.load_named_params(&params_map)
589    }
590}
591
592impl<F> ModelSerialize for ResNet<F>
593where
594    F: Float
595        + Debug
596        + ScalarOperand
597        + NumAssign
598        + ToPrimitive
599        + FromPrimitive
600        + Send
601        + Sync
602        + 'static,
603{
604    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
605        let bytes = self.to_bytes(format)?;
606        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
607        Ok(())
608    }
609
610    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
611        let config = self.config();
612        let ser_config = SerializableResNetConfig::from(config);
613
614        let arch_config = ArchitectureConfig {
615            architecture: "ResNet".to_string(),
616            format_version: "1.0".to_string(),
617            config: serde_json::to_value(&ser_config)
618                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
619        };
620
621        let params = self.extract_named_parameters()?;
622
623        match format {
624            ModelFormat::SafeTensors => {
625                let metadata = ModelMetadata::new("ResNet", "f64", params.total_parameters())
626                    .with_extra(
627                        "architecture_config",
628                        &serde_json::to_string(&arch_config)
629                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
630                    );
631
632                let mut writer = SafeTensorsWriter::new();
633                writer.add_model_metadata(&metadata);
634                writer.add_named_parameters(&params)?;
635                writer.to_bytes()
636            }
637            ModelFormat::Json | ModelFormat::Cbor | ModelFormat::MessagePack => {
638                let mut result = HashMap::new();
639                result.insert(
640                    "architecture",
641                    serde_json::to_value(&arch_config)
642                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
643                );
644
645                let params_value: Vec<serde_json::Value> = params
646                    .parameters
647                    .iter()
648                    .map(|(name, values, shape)| {
649                        serde_json::json!({
650                            "name": name,
651                            "shape": shape,
652                            "data": values,
653                        })
654                    })
655                    .collect();
656                result.insert("parameters", serde_json::Value::Array(params_value));
657
658                serde_json::to_vec_pretty(&result)
659                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
660            }
661        }
662    }
663
664    fn architecture_name(&self) -> &str {
665        "ResNet"
666    }
667}
668
669// ============================================================================
670// BERT serialization
671// ============================================================================
672
673impl<F> ExtractParameters for BertModel<F>
674where
675    F: Float
676        + Debug
677        + ScalarOperand
678        + NumAssign
679        + ToPrimitive
680        + FromPrimitive
681        + Send
682        + Sync
683        + SimdUnifiedOps
684        + 'static,
685{
686    fn extract_named_parameters(&self) -> Result<NamedParameters> {
687        let mut all_params = NamedParameters::new();
688        let named = self.extract_named_params()?;
689
690        for (name, param) in named {
691            let shape: Vec<usize> = param.shape().to_vec();
692            let values: Vec<f64> = param
693                .iter()
694                .map(|&x| {
695                    x.to_f64().ok_or_else(|| {
696                        NeuralError::SerializationError(
697                            "Cannot convert parameter to f64".to_string(),
698                        )
699                    })
700                })
701                .collect::<Result<Vec<f64>>>()?;
702            all_params.add(&name, values, shape);
703        }
704
705        Ok(all_params)
706    }
707
708    fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
709        let mut params_map = HashMap::new();
710        for (name, values, shape) in &params.parameters {
711            let f_values: Vec<F> = values
712                .iter()
713                .map(|&x| {
714                    F::from(x).ok_or_else(|| {
715                        NeuralError::DeserializationError(format!(
716                            "Cannot convert {x} to target type"
717                        ))
718                    })
719                })
720                .collect::<Result<Vec<F>>>()?;
721            let arr = Array::from_shape_vec(IxDyn(shape), f_values)?;
722            params_map.insert(name.clone(), arr);
723        }
724        self.load_named_params(&params_map)
725    }
726}
727
728impl<F> ModelSerialize for BertModel<F>
729where
730    F: Float
731        + Debug
732        + ScalarOperand
733        + NumAssign
734        + ToPrimitive
735        + FromPrimitive
736        + Send
737        + Sync
738        + SimdUnifiedOps
739        + 'static,
740{
741    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
742        let bytes = self.to_bytes(format)?;
743        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
744        Ok(())
745    }
746
747    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
748        let config = self.config();
749        let ser_config = SerializableBertConfig::from(config);
750
751        let arch_config = ArchitectureConfig {
752            architecture: "BERT".to_string(),
753            format_version: "1.0".to_string(),
754            config: serde_json::to_value(&ser_config)
755                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
756        };
757
758        let params = self.extract_named_parameters()?;
759
760        match format {
761            ModelFormat::SafeTensors => {
762                let metadata = ModelMetadata::new("BERT", "f64", params.total_parameters())
763                    .with_extra(
764                        "architecture_config",
765                        &serde_json::to_string(&arch_config)
766                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
767                    );
768
769                let mut writer = SafeTensorsWriter::new();
770                writer.add_model_metadata(&metadata);
771                writer.add_named_parameters(&params)?;
772                writer.to_bytes()
773            }
774            _ => {
775                let mut result = HashMap::new();
776                result.insert(
777                    "architecture",
778                    serde_json::to_value(&arch_config)
779                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
780                );
781
782                let params_value: Vec<serde_json::Value> = params
783                    .parameters
784                    .iter()
785                    .map(|(name, values, shape)| {
786                        serde_json::json!({
787                            "name": name,
788                            "shape": shape,
789                            "data": values,
790                        })
791                    })
792                    .collect();
793                result.insert("parameters", serde_json::Value::Array(params_value));
794
795                serde_json::to_vec_pretty(&result)
796                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
797            }
798        }
799    }
800
801    fn architecture_name(&self) -> &str {
802        "BERT"
803    }
804}
805
806impl<F> ModelDeserialize for BertModel<F>
807where
808    F: Float
809        + Debug
810        + ScalarOperand
811        + NumAssign
812        + ToPrimitive
813        + FromPrimitive
814        + Send
815        + Sync
816        + SimdUnifiedOps
817        + 'static,
818{
819    fn load(path: &Path, format: ModelFormat) -> Result<Self> {
820        let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
821        Self::from_bytes(&bytes, format)
822    }
823
824    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
825        match format {
826            ModelFormat::SafeTensors => {
827                let reader = SafeTensorsReader::from_bytes(bytes)?;
828                let meta = reader.metadata();
829                let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
830                    NeuralError::DeserializationError(
831                        "Missing architecture_config in SafeTensors metadata".to_string(),
832                    )
833                })?;
834                let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
835                    .map_err(|e| {
836                        NeuralError::DeserializationError(format!(
837                            "Invalid architecture config: {e}"
838                        ))
839                    })?;
840
841                let ser_config: SerializableBertConfig = serde_json::from_value(arch_config.config)
842                    .map_err(|e| {
843                        NeuralError::DeserializationError(format!("Invalid BERT config: {e}"))
844                    })?;
845
846                let bert_config = ser_config.to_bert_config();
847                let mut model = BertModel::new(bert_config)?;
848
849                let params = reader.to_named_parameters()?;
850                model.load_named_parameters(&params)?;
851
852                Ok(model)
853            }
854            _ => {
855                let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
856                    .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
857
858                let arch_value = raw.get("architecture").ok_or_else(|| {
859                    NeuralError::DeserializationError(
860                        "Missing 'architecture' key in JSON".to_string(),
861                    )
862                })?;
863
864                let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
865                    .map_err(|e| {
866                        NeuralError::DeserializationError(format!(
867                            "Invalid architecture config: {e}"
868                        ))
869                    })?;
870
871                let ser_config: SerializableBertConfig = serde_json::from_value(arch_config.config)
872                    .map_err(|e| {
873                        NeuralError::DeserializationError(format!("Invalid BERT config: {e}"))
874                    })?;
875
876                let bert_config = ser_config.to_bert_config();
877                BertModel::new(bert_config)
878            }
879        }
880    }
881}
882
883// ============================================================================
884// GPT serialization
885// ============================================================================
886
887impl<F> ExtractParameters for GPTModel<F>
888where
889    F: Float
890        + Debug
891        + ScalarOperand
892        + NumAssign
893        + ToPrimitive
894        + Send
895        + Sync
896        + SimdUnifiedOps
897        + 'static,
898{
899    fn extract_named_parameters(&self) -> Result<NamedParameters> {
900        let mut all_params = NamedParameters::new();
901
902        let layer_ref: &dyn Layer<F> = self;
903        let params = layer_ref.params();
904
905        for (i, param) in params.iter().enumerate() {
906            let name = format!("gpt.param_{i}");
907            let shape: Vec<usize> = param.shape().to_vec();
908            let values: Vec<f64> = param
909                .iter()
910                .map(|&x| {
911                    x.to_f64().ok_or_else(|| {
912                        NeuralError::SerializationError(
913                            "Cannot convert parameter to f64".to_string(),
914                        )
915                    })
916                })
917                .collect::<Result<Vec<f64>>>()?;
918            all_params.add(&name, values, shape);
919        }
920
921        Ok(all_params)
922    }
923
924    fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
925        Ok(())
926    }
927}
928
929impl<F> ModelSerialize for GPTModel<F>
930where
931    F: Float
932        + Debug
933        + ScalarOperand
934        + NumAssign
935        + ToPrimitive
936        + Send
937        + Sync
938        + SimdUnifiedOps
939        + 'static,
940{
941    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
942        let bytes = self.to_bytes(format)?;
943        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
944        Ok(())
945    }
946
947    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
948        let config = self.config();
949        let ser_config = SerializableGPTConfig::from(config);
950
951        let arch_config = ArchitectureConfig {
952            architecture: "GPT".to_string(),
953            format_version: "1.0".to_string(),
954            config: serde_json::to_value(&ser_config)
955                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
956        };
957
958        let params = self.extract_named_parameters()?;
959
960        match format {
961            ModelFormat::SafeTensors => {
962                let metadata = ModelMetadata::new("GPT", "f64", params.total_parameters())
963                    .with_extra(
964                        "architecture_config",
965                        &serde_json::to_string(&arch_config)
966                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
967                    );
968
969                let mut writer = SafeTensorsWriter::new();
970                writer.add_model_metadata(&metadata);
971                writer.add_named_parameters(&params)?;
972                writer.to_bytes()
973            }
974            _ => {
975                let mut result = HashMap::new();
976                result.insert(
977                    "architecture",
978                    serde_json::to_value(&arch_config)
979                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
980                );
981
982                let params_value: Vec<serde_json::Value> = params
983                    .parameters
984                    .iter()
985                    .map(|(name, values, shape)| {
986                        serde_json::json!({
987                            "name": name,
988                            "shape": shape,
989                            "data": values,
990                        })
991                    })
992                    .collect();
993                result.insert("parameters", serde_json::Value::Array(params_value));
994
995                serde_json::to_vec_pretty(&result)
996                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
997            }
998        }
999    }
1000
1001    fn architecture_name(&self) -> &str {
1002        "GPT"
1003    }
1004}
1005
1006impl<F> ModelDeserialize for GPTModel<F>
1007where
1008    F: Float
1009        + Debug
1010        + ScalarOperand
1011        + NumAssign
1012        + ToPrimitive
1013        + Send
1014        + Sync
1015        + SimdUnifiedOps
1016        + 'static,
1017{
1018    fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1019        let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1020        Self::from_bytes(&bytes, format)
1021    }
1022
1023    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1024        match format {
1025            ModelFormat::SafeTensors => {
1026                let reader = SafeTensorsReader::from_bytes(bytes)?;
1027                let meta = reader.metadata();
1028                let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1029                    NeuralError::DeserializationError(
1030                        "Missing architecture_config in SafeTensors metadata".to_string(),
1031                    )
1032                })?;
1033                let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1034                    .map_err(|e| {
1035                        NeuralError::DeserializationError(format!(
1036                            "Invalid architecture config: {e}"
1037                        ))
1038                    })?;
1039
1040                let ser_config: SerializableGPTConfig = serde_json::from_value(arch_config.config)
1041                    .map_err(|e| {
1042                        NeuralError::DeserializationError(format!("Invalid GPT config: {e}"))
1043                    })?;
1044
1045                let gpt_config = ser_config.to_gpt_config();
1046                let mut model = GPTModel::new(gpt_config)?;
1047
1048                let params = reader.to_named_parameters()?;
1049                model.load_named_parameters(&params)?;
1050
1051                Ok(model)
1052            }
1053            _ => {
1054                let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1055                    .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1056
1057                let arch_value = raw.get("architecture").ok_or_else(|| {
1058                    NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1059                })?;
1060
1061                let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1062                    .map_err(|e| {
1063                        NeuralError::DeserializationError(format!(
1064                            "Invalid architecture config: {e}"
1065                        ))
1066                    })?;
1067
1068                let ser_config: SerializableGPTConfig = serde_json::from_value(arch_config.config)
1069                    .map_err(|e| {
1070                        NeuralError::DeserializationError(format!("Invalid GPT config: {e}"))
1071                    })?;
1072
1073                let gpt_config = ser_config.to_gpt_config();
1074                GPTModel::new(gpt_config)
1075            }
1076        }
1077    }
1078}
1079
1080// ============================================================================
1081// ResNet deserialization
1082// ============================================================================
1083
1084impl<F> ModelDeserialize for ResNet<F>
1085where
1086    F: Float
1087        + Debug
1088        + ScalarOperand
1089        + NumAssign
1090        + ToPrimitive
1091        + FromPrimitive
1092        + Send
1093        + Sync
1094        + 'static,
1095{
1096    fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1097        let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1098        Self::from_bytes(&bytes, format)
1099    }
1100
1101    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1102        match format {
1103            ModelFormat::SafeTensors => {
1104                let reader = SafeTensorsReader::from_bytes(bytes)?;
1105                let meta = reader.metadata();
1106                let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1107                    NeuralError::DeserializationError("Missing architecture_config".to_string())
1108                })?;
1109                let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1110                    .map_err(|e| {
1111                        NeuralError::DeserializationError(format!(
1112                            "Invalid architecture config: {e}"
1113                        ))
1114                    })?;
1115
1116                let ser_config: SerializableResNetConfig =
1117                    serde_json::from_value(arch_config.config).map_err(|e| {
1118                        NeuralError::DeserializationError(format!("Invalid ResNet config: {e}"))
1119                    })?;
1120
1121                let resnet_config = ser_config.to_resnet_config()?;
1122                let mut model = ResNet::new(resnet_config)?;
1123
1124                let params = reader.to_named_parameters()?;
1125                model.load_named_parameters(&params)?;
1126
1127                Ok(model)
1128            }
1129            _ => {
1130                let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1131                    .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1132
1133                let arch_value = raw.get("architecture").ok_or_else(|| {
1134                    NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1135                })?;
1136
1137                let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1138                    .map_err(|e| {
1139                        NeuralError::DeserializationError(format!(
1140                            "Invalid architecture config: {e}"
1141                        ))
1142                    })?;
1143
1144                let ser_config: SerializableResNetConfig =
1145                    serde_json::from_value(arch_config.config).map_err(|e| {
1146                        NeuralError::DeserializationError(format!("Invalid ResNet config: {e}"))
1147                    })?;
1148
1149                let resnet_config = ser_config.to_resnet_config()?;
1150                ResNet::new(resnet_config)
1151            }
1152        }
1153    }
1154}
1155
1156// ============================================================================
1157// Mamba serialization
1158// ============================================================================
1159
1160impl<F> ExtractParameters for Mamba<F>
1161where
1162    F: Float
1163        + Debug
1164        + ScalarOperand
1165        + NumAssign
1166        + ToPrimitive
1167        + Send
1168        + Sync
1169        + SimdUnifiedOps
1170        + 'static,
1171{
1172    fn extract_named_parameters(&self) -> Result<NamedParameters> {
1173        let mut all_params = NamedParameters::new();
1174
1175        let layer_ref: &dyn Layer<F> = self;
1176        let params = layer_ref.params();
1177
1178        for (i, param) in params.iter().enumerate() {
1179            let name = format!("mamba.param_{i}");
1180            let shape: Vec<usize> = param.shape().to_vec();
1181            let values: Vec<f64> = param
1182                .iter()
1183                .map(|&x| {
1184                    x.to_f64().ok_or_else(|| {
1185                        NeuralError::SerializationError(
1186                            "Cannot convert parameter to f64".to_string(),
1187                        )
1188                    })
1189                })
1190                .collect::<Result<Vec<f64>>>()?;
1191            all_params.add(&name, values, shape);
1192        }
1193
1194        Ok(all_params)
1195    }
1196
1197    fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1198        Ok(())
1199    }
1200}
1201
1202impl<F> ModelSerialize for Mamba<F>
1203where
1204    F: Float
1205        + Debug
1206        + ScalarOperand
1207        + NumAssign
1208        + ToPrimitive
1209        + Send
1210        + Sync
1211        + SimdUnifiedOps
1212        + 'static,
1213{
1214    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1215        let bytes = self.to_bytes(format)?;
1216        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1217        Ok(())
1218    }
1219
1220    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1221        let config = self.config();
1222        let ser_config = SerializableMambaConfig::from(config);
1223
1224        let arch_config = ArchitectureConfig {
1225            architecture: "Mamba".to_string(),
1226            format_version: "1.0".to_string(),
1227            config: serde_json::to_value(&ser_config)
1228                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1229        };
1230
1231        let params = self.extract_named_parameters()?;
1232
1233        match format {
1234            ModelFormat::SafeTensors => {
1235                let metadata = ModelMetadata::new("Mamba", "f64", params.total_parameters())
1236                    .with_extra(
1237                        "architecture_config",
1238                        &serde_json::to_string(&arch_config)
1239                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1240                    );
1241
1242                let mut writer = SafeTensorsWriter::new();
1243                writer.add_model_metadata(&metadata);
1244                writer.add_named_parameters(&params)?;
1245                writer.to_bytes()
1246            }
1247            _ => {
1248                let mut result = HashMap::new();
1249                result.insert(
1250                    "architecture",
1251                    serde_json::to_value(&arch_config)
1252                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1253                );
1254
1255                let params_value: Vec<serde_json::Value> = params
1256                    .parameters
1257                    .iter()
1258                    .map(|(name, values, shape)| {
1259                        serde_json::json!({
1260                            "name": name,
1261                            "shape": shape,
1262                            "data": values,
1263                        })
1264                    })
1265                    .collect();
1266                result.insert("parameters", serde_json::Value::Array(params_value));
1267
1268                serde_json::to_vec_pretty(&result)
1269                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
1270            }
1271        }
1272    }
1273
1274    fn architecture_name(&self) -> &str {
1275        "Mamba"
1276    }
1277}
1278
1279impl<F> ModelDeserialize for Mamba<F>
1280where
1281    F: Float
1282        + Debug
1283        + ScalarOperand
1284        + NumAssign
1285        + ToPrimitive
1286        + Send
1287        + Sync
1288        + SimdUnifiedOps
1289        + 'static,
1290{
1291    fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1292        let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1293        Self::from_bytes(&bytes, format)
1294    }
1295
1296    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1297        match format {
1298            ModelFormat::SafeTensors => {
1299                let reader = SafeTensorsReader::from_bytes(bytes)?;
1300                let meta = reader.metadata();
1301                let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1302                    NeuralError::DeserializationError("Missing architecture_config".to_string())
1303                })?;
1304                let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1305                    .map_err(|e| {
1306                        NeuralError::DeserializationError(format!(
1307                            "Invalid architecture config: {e}"
1308                        ))
1309                    })?;
1310
1311                let ser_config: SerializableMambaConfig =
1312                    serde_json::from_value(arch_config.config).map_err(|e| {
1313                        NeuralError::DeserializationError(format!("Invalid Mamba config: {e}"))
1314                    })?;
1315
1316                let mamba_config = ser_config.to_mamba_config();
1317                let mut rng = scirs2_core::ChaCha8Rng::seed_from_u64(42);
1318                let mut model = Mamba::new(mamba_config, &mut rng)?;
1319
1320                let params = reader.to_named_parameters()?;
1321                model.load_named_parameters(&params)?;
1322
1323                Ok(model)
1324            }
1325            _ => {
1326                let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1327                    .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1328
1329                let arch_value = raw.get("architecture").ok_or_else(|| {
1330                    NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1331                })?;
1332
1333                let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1334                    .map_err(|e| {
1335                        NeuralError::DeserializationError(format!(
1336                            "Invalid architecture config: {e}"
1337                        ))
1338                    })?;
1339
1340                let ser_config: SerializableMambaConfig =
1341                    serde_json::from_value(arch_config.config).map_err(|e| {
1342                        NeuralError::DeserializationError(format!("Invalid Mamba config: {e}"))
1343                    })?;
1344
1345                let mamba_config = ser_config.to_mamba_config();
1346                let mut rng = scirs2_core::ChaCha8Rng::seed_from_u64(42);
1347                Mamba::new(mamba_config, &mut rng)
1348            }
1349        }
1350    }
1351}
1352
1353// ============================================================================
1354// EfficientNet serialization
1355// ============================================================================
1356
1357impl<F> ExtractParameters for EfficientNet<F>
1358where
1359    F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1360{
1361    fn extract_named_parameters(&self) -> Result<NamedParameters> {
1362        let mut all_params = NamedParameters::new();
1363
1364        let layer_ref: &dyn Layer<F> = self;
1365        let params = layer_ref.params();
1366
1367        for (i, param) in params.iter().enumerate() {
1368            let name = format!("efficientnet.param_{i}");
1369            let shape: Vec<usize> = param.shape().to_vec();
1370            let values: Vec<f64> = param
1371                .iter()
1372                .map(|&x| {
1373                    x.to_f64().ok_or_else(|| {
1374                        NeuralError::SerializationError(
1375                            "Cannot convert parameter to f64".to_string(),
1376                        )
1377                    })
1378                })
1379                .collect::<Result<Vec<f64>>>()?;
1380            all_params.add(&name, values, shape);
1381        }
1382
1383        Ok(all_params)
1384    }
1385
1386    fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1387        Ok(())
1388    }
1389}
1390
1391impl<F> ModelSerialize for EfficientNet<F>
1392where
1393    F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1394{
1395    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1396        let bytes = self.to_bytes(format)?;
1397        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1398        Ok(())
1399    }
1400
1401    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1402        let config = self.config();
1403        let ser_config = SerializableEfficientNetConfig {
1404            width_coefficient: config.width_coefficient,
1405            depth_coefficient: config.depth_coefficient,
1406            resolution: config.resolution,
1407            dropout_rate: config.dropout_rate,
1408            input_channels: config.input_channels,
1409            num_classes: config.num_classes,
1410        };
1411
1412        let arch_config = ArchitectureConfig {
1413            architecture: "EfficientNet".to_string(),
1414            format_version: "1.0".to_string(),
1415            config: serde_json::to_value(&ser_config)
1416                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1417        };
1418
1419        let params = self.extract_named_parameters()?;
1420
1421        match format {
1422            ModelFormat::SafeTensors => {
1423                let metadata = ModelMetadata::new("EfficientNet", "f64", params.total_parameters())
1424                    .with_extra(
1425                        "architecture_config",
1426                        &serde_json::to_string(&arch_config)
1427                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1428                    );
1429
1430                let mut writer = SafeTensorsWriter::new();
1431                writer.add_model_metadata(&metadata);
1432                writer.add_named_parameters(&params)?;
1433                writer.to_bytes()
1434            }
1435            _ => {
1436                let mut result = HashMap::new();
1437                result.insert(
1438                    "architecture",
1439                    serde_json::to_value(&arch_config)
1440                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1441                );
1442
1443                let params_value: Vec<serde_json::Value> = params
1444                    .parameters
1445                    .iter()
1446                    .map(|(name, values, shape)| {
1447                        serde_json::json!({
1448                            "name": name,
1449                            "shape": shape,
1450                            "data": values,
1451                        })
1452                    })
1453                    .collect();
1454                result.insert("parameters", serde_json::Value::Array(params_value));
1455
1456                serde_json::to_vec_pretty(&result)
1457                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
1458            }
1459        }
1460    }
1461
1462    fn architecture_name(&self) -> &str {
1463        "EfficientNet"
1464    }
1465}
1466
1467// ============================================================================
1468// MobileNet serialization
1469// ============================================================================
1470
1471impl<F> ExtractParameters for MobileNet<F>
1472where
1473    F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1474{
1475    fn extract_named_parameters(&self) -> Result<NamedParameters> {
1476        let mut all_params = NamedParameters::new();
1477
1478        let layer_ref: &dyn Layer<F> = self;
1479        let params = layer_ref.params();
1480
1481        for (i, param) in params.iter().enumerate() {
1482            let name = format!("mobilenet.param_{i}");
1483            let shape: Vec<usize> = param.shape().to_vec();
1484            let values: Vec<f64> = param
1485                .iter()
1486                .map(|&x| {
1487                    x.to_f64().ok_or_else(|| {
1488                        NeuralError::SerializationError(
1489                            "Cannot convert parameter to f64".to_string(),
1490                        )
1491                    })
1492                })
1493                .collect::<Result<Vec<f64>>>()?;
1494            all_params.add(&name, values, shape);
1495        }
1496
1497        Ok(all_params)
1498    }
1499
1500    fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1501        Ok(())
1502    }
1503}
1504
1505impl<F> ModelSerialize for MobileNet<F>
1506where
1507    F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1508{
1509    fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1510        let bytes = self.to_bytes(format)?;
1511        fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1512        Ok(())
1513    }
1514
1515    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1516        let config = self.config();
1517        let ser_config = SerializableMobileNetConfig {
1518            version: match config.version {
1519                MobileNetVersion::V1 => "V1".to_string(),
1520                MobileNetVersion::V2 => "V2".to_string(),
1521                MobileNetVersion::V3Small => "V3Small".to_string(),
1522                MobileNetVersion::V3Large => "V3Large".to_string(),
1523            },
1524            width_multiplier: config.width_multiplier,
1525            resolution_multiplier: config.resolution_multiplier,
1526            dropout_rate: config.dropout_rate,
1527            input_channels: config.input_channels,
1528            num_classes: config.num_classes,
1529        };
1530
1531        let arch_config = ArchitectureConfig {
1532            architecture: "MobileNet".to_string(),
1533            format_version: "1.0".to_string(),
1534            config: serde_json::to_value(&ser_config)
1535                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1536        };
1537
1538        let params = self.extract_named_parameters()?;
1539
1540        match format {
1541            ModelFormat::SafeTensors => {
1542                let metadata = ModelMetadata::new("MobileNet", "f64", params.total_parameters())
1543                    .with_extra(
1544                        "architecture_config",
1545                        &serde_json::to_string(&arch_config)
1546                            .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1547                    );
1548
1549                let mut writer = SafeTensorsWriter::new();
1550                writer.add_model_metadata(&metadata);
1551                writer.add_named_parameters(&params)?;
1552                writer.to_bytes()
1553            }
1554            _ => {
1555                let mut result = HashMap::new();
1556                result.insert(
1557                    "architecture",
1558                    serde_json::to_value(&arch_config)
1559                        .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1560                );
1561
1562                let params_value: Vec<serde_json::Value> = params
1563                    .parameters
1564                    .iter()
1565                    .map(|(name, values, shape)| {
1566                        serde_json::json!({
1567                            "name": name,
1568                            "shape": shape,
1569                            "data": values,
1570                        })
1571                    })
1572                    .collect();
1573                result.insert("parameters", serde_json::Value::Array(params_value));
1574
1575                serde_json::to_vec_pretty(&result)
1576                    .map_err(|e| NeuralError::SerializationError(e.to_string()))
1577            }
1578        }
1579    }
1580
1581    fn architecture_name(&self) -> &str {
1582        "MobileNet"
1583    }
1584}
1585
1586// ============================================================================
1587// Utility: detect architecture from file
1588// ============================================================================
1589
1590/// Detect the architecture type from a serialized model file
1591pub fn detect_architecture(path: &Path) -> Result<String> {
1592    let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1593    detect_architecture_from_bytes(&bytes)
1594}
1595
1596/// Detect the architecture type from serialized bytes
1597pub fn detect_architecture_from_bytes(bytes: &[u8]) -> Result<String> {
1598    // Try SafeTensors first (starts with 8-byte header size)
1599    if bytes.len() >= 8 {
1600        if let Ok(reader) = SafeTensorsReader::from_bytes(bytes) {
1601            let meta = reader.metadata();
1602            if let Some(arch) = meta.get("architecture") {
1603                return Ok(arch.clone());
1604            }
1605        }
1606    }
1607
1608    // Try JSON
1609    if let Ok(raw) = serde_json::from_slice::<HashMap<String, serde_json::Value>>(bytes) {
1610        if let Some(arch_value) = raw.get("architecture") {
1611            if let Ok(arch_config) =
1612                serde_json::from_value::<ArchitectureConfig>(arch_value.clone())
1613            {
1614                return Ok(arch_config.architecture);
1615            }
1616        }
1617    }
1618
1619    Err(NeuralError::DeserializationError(
1620        "Cannot detect architecture from file: unrecognized format".to_string(),
1621    ))
1622}
1623
#[cfg(test)]
mod tests {
    // Round-trip and detection tests for the architecture serialization layer.
    use super::*;

    /// ResNet config survives a JSON round-trip through its serializable form.
    #[test]
    fn test_serializable_resnet_config_roundtrip() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 1000);
        let ser = SerializableResNetConfig::from(&config);

        // Serialize to JSON
        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        // Deserialize back
        let deser: SerializableResNetConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_resnet_config()?;
        assert_eq!(restored.input_channels, 3);
        assert_eq!(restored.num_classes, 1000);
        assert_eq!(restored.layers.len(), 4);

        Ok(())
    }

    /// BERT-base config survives a JSON round-trip.
    #[test]
    fn test_serializable_bert_config_roundtrip() -> Result<()> {
        let config = BertConfig::bert_base_uncased();
        let ser = SerializableBertConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableBertConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_bert_config();
        assert_eq!(restored.vocab_size, 30522);
        assert_eq!(restored.hidden_size, 768);
        assert_eq!(restored.num_hidden_layers, 12);

        Ok(())
    }

    /// GPT-2 small config survives a JSON round-trip.
    #[test]
    fn test_serializable_gpt_config_roundtrip() -> Result<()> {
        let config = GPTConfig::gpt2_small();
        let ser = SerializableGPTConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableGPTConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_gpt_config();
        assert_eq!(restored.vocab_size, 50257);
        assert_eq!(restored.hidden_size, 768);

        Ok(())
    }

    /// Mamba config survives a JSON round-trip.
    #[test]
    fn test_serializable_mamba_config_roundtrip() -> Result<()> {
        let config = MambaConfig::new(256);
        let ser = SerializableMambaConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableMambaConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_mamba_config();
        assert_eq!(restored.d_model, 256);
        assert_eq!(restored.d_state, 16);

        Ok(())
    }

    /// The `ArchitectureConfig` envelope survives a JSON round-trip.
    #[test]
    fn test_architecture_config_envelope() -> Result<()> {
        let config = ArchitectureConfig {
            architecture: "ResNet".to_string(),
            format_version: "1.0".to_string(),
            config: serde_json::json!({
                "block": "Basic",
                "input_channels": 3,
                "num_classes": 10,
            }),
        };

        let json = serde_json::to_string(&config)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let restored: ArchitectureConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        assert_eq!(restored.architecture, "ResNet");
        assert_eq!(restored.format_version, "1.0");

        Ok(())
    }

    /// A ResNet serializes to SafeTensors (tagged with its architecture) and to JSON.
    #[test]
    fn test_resnet_model_serialize() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;

        // Test SafeTensors serialization
        let bytes = model.to_bytes(ModelFormat::SafeTensors)?;
        assert!(!bytes.is_empty());

        // Verify it can be read back
        let reader = SafeTensorsReader::from_bytes(&bytes)?;
        let meta = reader.metadata();
        assert_eq!(meta.get("architecture"), Some(&"ResNet".to_string()));

        // Test JSON serialization: non-empty and well-formed JSON.
        let json_bytes = model.to_bytes(ModelFormat::Json)?;
        assert!(!json_bytes.is_empty());
        serde_json::from_slice::<serde_json::Value>(&json_bytes)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        Ok(())
    }

    /// Saving a ResNet to disk and loading it back preserves its config.
    #[test]
    fn test_resnet_save_load_roundtrip() -> Result<()> {
        /// Removes the scratch directory when dropped, so it is cleaned up even
        /// when an assertion panics or an operation returns early with an error.
        struct ScratchDir(std::path::PathBuf);

        impl Drop for ScratchDir {
            fn drop(&mut self) {
                let _ = fs::remove_dir_all(&self.0);
            }
        }

        // The process id keeps concurrent test runs on the same machine from
        // sharing (and deleting) one another's directory.
        let scratch = ScratchDir(
            std::env::temp_dir().join(format!("scirs2_arch_resnet_{}", std::process::id())),
        );
        fs::create_dir_all(&scratch.0).map_err(|e| NeuralError::IOError(e.to_string()))?;
        let path = scratch.0.join("resnet18.safetensors");

        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        model.save(&path, ModelFormat::SafeTensors)?;

        let loaded = ResNet::<f64>::load(&path, ModelFormat::SafeTensors)?;
        assert_eq!(loaded.config().input_channels, 3);
        assert_eq!(loaded.config().num_classes, 10);

        Ok(())
    }

    /// Architecture detection works on SafeTensors output.
    #[test]
    fn test_detect_architecture_safetensors() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        let bytes = model.to_bytes(ModelFormat::SafeTensors)?;

        let arch = detect_architecture_from_bytes(&bytes)?;
        assert_eq!(arch, "ResNet");
        Ok(())
    }

    /// Architecture detection works on JSON output.
    #[test]
    fn test_detect_architecture_json() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        let bytes = model.to_bytes(ModelFormat::Json)?;

        let arch = detect_architecture_from_bytes(&bytes)?;
        assert_eq!(arch, "ResNet");
        Ok(())
    }

    /// Inputs that carry no recognisable architecture are rejected with an error.
    #[test]
    fn test_detect_architecture_rejects_unrecognized_bytes() {
        // Neither SafeTensors nor JSON.
        assert!(detect_architecture_from_bytes(b"definitely not a model").is_err());
        // Empty input.
        assert!(detect_architecture_from_bytes(&[]).is_err());
        // Valid JSON without an "architecture" field.
        assert!(detect_architecture_from_bytes(br#"{"weights": []}"#).is_err());
        // Valid JSON whose "architecture" field has an unusable shape.
        assert!(detect_architecture_from_bytes(br#"{"architecture": 42}"#).is_err());
    }
}