1use crate::error::{NeuralError, Result};
17use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer, LayerNorm, LSTM};
18use crate::models::architectures::{
19 BertConfig, BertModel, EfficientNet, EfficientNetConfig, GPTConfig, GPTModel, Mamba,
20 MambaConfig, MobileNet, MobileNetConfig, MobileNetVersion, ResNet, ResNetBlock, ResNetConfig,
21 ResNetLayer,
22};
23use crate::models::sequential::Sequential;
24use crate::serialization::safetensors::{SafeTensorsReader, SafeTensorsWriter};
25use crate::serialization::traits::{
26 ExtractParameters, ModelDeserialize, ModelFormat, ModelMetadata, ModelSerialize,
27 NamedParameters,
28};
29use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
30use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
31use scirs2_core::random::SeedableRng;
32use scirs2_core::simd_ops::SimdUnifiedOps;
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::fmt::{Debug, Display};
36use std::fs;
37use std::path::Path;
38
/// Envelope that identifies a serialized model and carries its configuration.
///
/// The serializers in this file store it under the `"architecture"` key of the
/// JSON document, or as the `"architecture_config"` metadata entry of a
/// SafeTensors file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureConfig {
    /// Architecture name; this file writes `"Sequential"`, `"ResNet"`, `"BERT"`,
    /// `"GPT"` and `"Mamba"`.
    pub architecture: String,
    /// Version tag of the serialized layout; the serializers in this file write `"1.0"`.
    pub format_version: String,
    /// Architecture-specific configuration, kept as free-form JSON so that one
    /// envelope type can describe every architecture.
    pub config: serde_json::Value,
}
53
/// Serde-friendly mirror of `ResNetConfig`.
///
/// Converted from a `ResNetConfig` with `From<&ResNetConfig>` and back with
/// `to_resnet_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableResNetConfig {
    /// Residual block kind stored by name: `"Basic"` or `"Bottleneck"`.
    pub block: String,
    /// One entry per `ResNetLayer` of the source configuration, in the same order.
    pub layers: Vec<SerializableResNetLayer>,
    /// Copied verbatim from `ResNetConfig::input_channels`.
    pub input_channels: usize,
    /// Copied verbatim from `ResNetConfig::num_classes`.
    pub num_classes: usize,
    /// Copied verbatim from `ResNetConfig::dropout_rate`.
    pub dropout_rate: f64,
}
68
/// Serde-friendly mirror of a single `ResNetLayer` entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableResNetLayer {
    /// Copied verbatim from `ResNetLayer::blocks`.
    pub blocks: usize,
    /// Copied verbatim from `ResNetLayer::channels`.
    pub channels: usize,
    /// Copied verbatim from `ResNetLayer::stride`.
    pub stride: usize,
}
79
80impl From<&ResNetConfig> for SerializableResNetConfig {
81 fn from(config: &ResNetConfig) -> Self {
82 Self {
83 block: match config.block {
84 ResNetBlock::Basic => "Basic".to_string(),
85 ResNetBlock::Bottleneck => "Bottleneck".to_string(),
86 },
87 layers: config
88 .layers
89 .iter()
90 .map(|l| SerializableResNetLayer {
91 blocks: l.blocks,
92 channels: l.channels,
93 stride: l.stride,
94 })
95 .collect(),
96 input_channels: config.input_channels,
97 num_classes: config.num_classes,
98 dropout_rate: config.dropout_rate,
99 }
100 }
101}
102
103impl SerializableResNetConfig {
104 pub fn to_resnet_config(&self) -> Result<ResNetConfig> {
106 let block = match self.block.as_str() {
107 "Basic" => ResNetBlock::Basic,
108 "Bottleneck" => ResNetBlock::Bottleneck,
109 other => {
110 return Err(NeuralError::DeserializationError(format!(
111 "Unknown ResNet block type: {other}"
112 )))
113 }
114 };
115
116 Ok(ResNetConfig {
117 block,
118 layers: self
119 .layers
120 .iter()
121 .map(|l| ResNetLayer {
122 blocks: l.blocks,
123 channels: l.channels,
124 stride: l.stride,
125 })
126 .collect(),
127 input_channels: self.input_channels,
128 num_classes: self.num_classes,
129 dropout_rate: self.dropout_rate,
130 })
131 }
132}
133
/// Serde-friendly mirror of `BertConfig`.
///
/// Every field is copied one-to-one from the `BertConfig` field of the same
/// name (see `From<&BertConfig>` and `to_bert_config`); refer to `BertConfig`
/// for what each value means.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableBertConfig {
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: String,
    pub hidden_dropout_prob: f64,
    pub attention_probs_dropout_prob: f64,
    pub type_vocab_size: usize,
    pub layer_norm_eps: f64,
    pub initializer_range: f64,
}
150
151impl From<&BertConfig> for SerializableBertConfig {
152 fn from(config: &BertConfig) -> Self {
153 Self {
154 vocab_size: config.vocab_size,
155 max_position_embeddings: config.max_position_embeddings,
156 hidden_size: config.hidden_size,
157 num_hidden_layers: config.num_hidden_layers,
158 num_attention_heads: config.num_attention_heads,
159 intermediate_size: config.intermediate_size,
160 hidden_act: config.hidden_act.clone(),
161 hidden_dropout_prob: config.hidden_dropout_prob,
162 attention_probs_dropout_prob: config.attention_probs_dropout_prob,
163 type_vocab_size: config.type_vocab_size,
164 layer_norm_eps: config.layer_norm_eps,
165 initializer_range: config.initializer_range,
166 }
167 }
168}
169
170impl SerializableBertConfig {
171 pub fn to_bert_config(&self) -> BertConfig {
173 BertConfig {
174 vocab_size: self.vocab_size,
175 max_position_embeddings: self.max_position_embeddings,
176 hidden_size: self.hidden_size,
177 num_hidden_layers: self.num_hidden_layers,
178 num_attention_heads: self.num_attention_heads,
179 intermediate_size: self.intermediate_size,
180 hidden_act: self.hidden_act.clone(),
181 hidden_dropout_prob: self.hidden_dropout_prob,
182 attention_probs_dropout_prob: self.attention_probs_dropout_prob,
183 type_vocab_size: self.type_vocab_size,
184 layer_norm_eps: self.layer_norm_eps,
185 initializer_range: self.initializer_range,
186 }
187 }
188}
189
/// Serde-friendly mirror of `GPTConfig`.
///
/// Every field is copied one-to-one from the `GPTConfig` field of the same
/// name (see `From<&GPTConfig>` and `to_gpt_config`); refer to `GPTConfig`
/// for what each value means.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableGPTConfig {
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: String,
    pub hidden_dropout_prob: f64,
    pub attention_probs_dropout_prob: f64,
    pub layer_norm_eps: f64,
    pub initializer_range: f64,
}
205
206impl From<&GPTConfig> for SerializableGPTConfig {
207 fn from(config: &GPTConfig) -> Self {
208 Self {
209 vocab_size: config.vocab_size,
210 max_position_embeddings: config.max_position_embeddings,
211 hidden_size: config.hidden_size,
212 num_hidden_layers: config.num_hidden_layers,
213 num_attention_heads: config.num_attention_heads,
214 intermediate_size: config.intermediate_size,
215 hidden_act: config.hidden_act.clone(),
216 hidden_dropout_prob: config.hidden_dropout_prob,
217 attention_probs_dropout_prob: config.attention_probs_dropout_prob,
218 layer_norm_eps: config.layer_norm_eps,
219 initializer_range: config.initializer_range,
220 }
221 }
222}
223
224impl SerializableGPTConfig {
225 pub fn to_gpt_config(&self) -> GPTConfig {
227 GPTConfig {
228 vocab_size: self.vocab_size,
229 max_position_embeddings: self.max_position_embeddings,
230 hidden_size: self.hidden_size,
231 num_hidden_layers: self.num_hidden_layers,
232 num_attention_heads: self.num_attention_heads,
233 intermediate_size: self.intermediate_size,
234 hidden_act: self.hidden_act.clone(),
235 hidden_dropout_prob: self.hidden_dropout_prob,
236 attention_probs_dropout_prob: self.attention_probs_dropout_prob,
237 layer_norm_eps: self.layer_norm_eps,
238 initializer_range: self.initializer_range,
239 }
240 }
241}
242
/// Serde-friendly mirror of `MambaConfig`.
///
/// Every field is copied one-to-one from the `MambaConfig` field of the same
/// name (see `From<&MambaConfig>` and `to_mamba_config`); refer to
/// `MambaConfig` for what each value means.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMambaConfig {
    pub d_model: usize,
    pub d_state: usize,
    pub d_conv: usize,
    pub expand: usize,
    pub n_layers: usize,
    pub dropout_prob: f64,
    pub vocab_size: Option<usize>,
    pub num_classes: Option<usize>,
    pub dt_rank: Option<usize>,
    pub bias: bool,
    pub dt_min: f64,
    pub dt_max: f64,
}
259
260impl From<&MambaConfig> for SerializableMambaConfig {
261 fn from(config: &MambaConfig) -> Self {
262 Self {
263 d_model: config.d_model,
264 d_state: config.d_state,
265 d_conv: config.d_conv,
266 expand: config.expand,
267 n_layers: config.n_layers,
268 dropout_prob: config.dropout_prob,
269 vocab_size: config.vocab_size,
270 num_classes: config.num_classes,
271 dt_rank: config.dt_rank,
272 bias: config.bias,
273 dt_min: config.dt_min,
274 dt_max: config.dt_max,
275 }
276 }
277}
278
279impl SerializableMambaConfig {
280 pub fn to_mamba_config(&self) -> MambaConfig {
282 MambaConfig {
283 d_model: self.d_model,
284 d_state: self.d_state,
285 d_conv: self.d_conv,
286 expand: self.expand,
287 n_layers: self.n_layers,
288 dropout_prob: self.dropout_prob,
289 vocab_size: self.vocab_size,
290 num_classes: self.num_classes,
291 dt_rank: self.dt_rank,
292 bias: self.bias,
293 dt_min: self.dt_min,
294 dt_max: self.dt_max,
295 }
296 }
297}
298
/// Serde-friendly description of an EfficientNet configuration.
///
/// NOTE(review): no conversion to or from `EfficientNetConfig` is visible in
/// this part of the file, so how these fields map onto it should be confirmed
/// against that type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableEfficientNetConfig {
    pub width_coefficient: f64,
    pub depth_coefficient: f64,
    pub resolution: usize,
    pub dropout_rate: f64,
    pub input_channels: usize,
    pub num_classes: usize,
}
309
/// Serde-friendly description of a MobileNet configuration.
///
/// NOTE(review): no conversion to or from `MobileNetConfig` is visible in this
/// part of the file; `version` is presumably the name of a `MobileNetVersion`
/// variant — confirm against the code that builds this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMobileNetConfig {
    pub version: String,
    pub width_multiplier: f64,
    pub resolution_multiplier: f64,
    pub dropout_rate: f64,
    pub input_channels: usize,
    pub num_classes: usize,
}
320
321fn extract_layer_params<F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive>(
327 layer: &dyn Layer<F>,
328 prefix: &str,
329) -> Result<NamedParameters> {
330 let mut named = NamedParameters::new();
331 let params = layer.params();
332
333 if params.is_empty() {
334 return Ok(named);
335 }
336
337 for (i, param) in params.iter().enumerate() {
342 let param_name = match i {
343 0 => format!("{prefix}.weight"),
344 1 => format!("{prefix}.bias"),
345 2 => format!("{prefix}.running_mean"),
346 3 => format!("{prefix}.running_var"),
347 n => format!("{prefix}.param_{n}"),
348 };
349
350 let shape: Vec<usize> = param.shape().to_vec();
351 let values: Vec<f64> = param
352 .iter()
353 .map(|&x| {
354 x.to_f64().ok_or_else(|| {
355 NeuralError::SerializationError("Cannot convert parameter to f64".to_string())
356 })
357 })
358 .collect::<Result<Vec<f64>>>()?;
359
360 named.add(¶m_name, values, shape);
361 }
362
363 Ok(named)
364}
365
/// Serializable description of a `Sequential` model: one entry per layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableSequentialConfig {
    /// Layer descriptions, in model order.
    pub layers: Vec<SerializableLayerInfo>,
}
376
/// Serializable description of a single layer inside a `Sequential` model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableLayerInfo {
    /// Value reported by the layer's `layer_type()` at save time.
    pub layer_type: String,
    /// Zero-based position of the layer inside the model.
    pub index: usize,
    /// Layer-specific configuration. Defaults when absent from the input; the
    /// `Sequential` serializer in this file always writes an empty JSON object.
    #[serde(default)]
    pub config: serde_json::Value,
}
388
389impl<F> ExtractParameters for Sequential<F>
390where
391 F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + ToPrimitive + 'static,
392{
393 fn extract_named_parameters(&self) -> Result<NamedParameters> {
394 let mut all_params = NamedParameters::new();
395
396 for (i, layer) in self.layers().iter().enumerate() {
397 let prefix = format!("layers.{i}");
398 let layer_params = extract_layer_params(layer.as_ref(), &prefix)?;
399 for (name, values, shape) in layer_params.parameters {
400 all_params.add(&name, values, shape);
401 }
402 }
403
404 Ok(all_params)
405 }
406
407 fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
408 let num_layers = self.layers().len();
411
412 for i in 0..num_layers {
413 let prefix = format!("layers.{i}");
414 let mut layer_param_arrays: Vec<Array<F, IxDyn>> = Vec::new();
415
416 let mut matching: Vec<&(String, Vec<f64>, Vec<usize>)> = params
418 .parameters
419 .iter()
420 .filter(|(name, _, _)| name.starts_with(&prefix))
421 .collect();
422 matching.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
423
424 for (_, values, shape) in &matching {
425 let f_vec: Vec<F> = values
426 .iter()
427 .map(|&x| {
428 F::from(x).ok_or_else(|| {
429 NeuralError::DeserializationError(format!(
430 "Cannot convert {x} to target type"
431 ))
432 })
433 })
434 .collect::<Result<Vec<F>>>()?;
435 let arr = Array::from_shape_vec(IxDyn(shape), f_vec)?;
436 layer_param_arrays.push(arr);
437 }
438
439 if !layer_param_arrays.is_empty() {
440 self.layers_mut()[i].set_params(&layer_param_arrays)?;
442 }
443 }
444
445 Ok(())
446 }
447}
448
449impl<F> ModelSerialize for Sequential<F>
450where
451 F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + ToPrimitive + 'static,
452{
453 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
454 let bytes = self.to_bytes(format)?;
455 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
456 Ok(())
457 }
458
459 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
460 let mut layers_info = Vec::new();
462 for (i, layer) in self.layers().iter().enumerate() {
463 layers_info.push(SerializableLayerInfo {
464 layer_type: layer.layer_type().to_string(),
465 index: i,
466 config: serde_json::Value::Object(serde_json::Map::new()),
467 });
468 }
469
470 let seq_config = SerializableSequentialConfig {
471 layers: layers_info,
472 };
473
474 let arch_config = ArchitectureConfig {
475 architecture: "Sequential".to_string(),
476 format_version: "1.0".to_string(),
477 config: serde_json::to_value(&seq_config)
478 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
479 };
480
481 let params = self.extract_named_parameters()?;
482
483 match format {
484 ModelFormat::Json => {
485 let mut result = HashMap::new();
486 result.insert(
487 "architecture",
488 serde_json::to_value(&arch_config)
489 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
490 );
491
492 let params_value: Vec<serde_json::Value> = params
494 .parameters
495 .iter()
496 .map(|(name, values, shape)| {
497 serde_json::json!({
498 "name": name,
499 "shape": shape,
500 "data": values,
501 })
502 })
503 .collect();
504 result.insert("parameters", serde_json::Value::Array(params_value));
505
506 serde_json::to_vec_pretty(&result)
507 .map_err(|e| NeuralError::SerializationError(e.to_string()))
508 }
509 ModelFormat::SafeTensors => {
510 let metadata = ModelMetadata::new("Sequential", "f64", params.total_parameters())
511 .with_extra(
512 "architecture_config",
513 &serde_json::to_string(&arch_config)
514 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
515 );
516
517 let mut writer = SafeTensorsWriter::new();
518 writer.add_model_metadata(&metadata);
519 writer.add_named_parameters(¶ms)?;
520 writer.to_bytes()
521 }
522 ModelFormat::Cbor | ModelFormat::MessagePack => {
523 self.to_bytes(ModelFormat::Json)
525 }
526 }
527 }
528
529 fn architecture_name(&self) -> &str {
530 "Sequential"
531 }
532}
533
534impl<F> ExtractParameters for ResNet<F>
539where
540 F: Float
541 + Debug
542 + ScalarOperand
543 + NumAssign
544 + ToPrimitive
545 + FromPrimitive
546 + Send
547 + Sync
548 + 'static,
549{
550 fn extract_named_parameters(&self) -> Result<NamedParameters> {
551 let mut all_params = NamedParameters::new();
552 let named = self.extract_named_params()?;
553
554 for (name, param) in named {
555 let shape: Vec<usize> = param.shape().to_vec();
556 let values: Vec<f64> = param
557 .iter()
558 .map(|&x| {
559 x.to_f64().ok_or_else(|| {
560 NeuralError::SerializationError(
561 "Cannot convert parameter to f64".to_string(),
562 )
563 })
564 })
565 .collect::<Result<Vec<f64>>>()?;
566 all_params.add(&name, values, shape);
567 }
568
569 Ok(all_params)
570 }
571
572 fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
573 let mut params_map = HashMap::new();
574 for (name, values, shape) in ¶ms.parameters {
575 let f_values: Vec<F> = values
576 .iter()
577 .map(|&x| {
578 F::from(x).ok_or_else(|| {
579 NeuralError::DeserializationError(format!(
580 "Cannot convert {x} to target type"
581 ))
582 })
583 })
584 .collect::<Result<Vec<F>>>()?;
585 let arr = Array::from_shape_vec(IxDyn(shape), f_values)?;
586 params_map.insert(name.clone(), arr);
587 }
588 self.load_named_params(¶ms_map)
589 }
590}
591
592impl<F> ModelSerialize for ResNet<F>
593where
594 F: Float
595 + Debug
596 + ScalarOperand
597 + NumAssign
598 + ToPrimitive
599 + FromPrimitive
600 + Send
601 + Sync
602 + 'static,
603{
604 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
605 let bytes = self.to_bytes(format)?;
606 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
607 Ok(())
608 }
609
610 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
611 let config = self.config();
612 let ser_config = SerializableResNetConfig::from(config);
613
614 let arch_config = ArchitectureConfig {
615 architecture: "ResNet".to_string(),
616 format_version: "1.0".to_string(),
617 config: serde_json::to_value(&ser_config)
618 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
619 };
620
621 let params = self.extract_named_parameters()?;
622
623 match format {
624 ModelFormat::SafeTensors => {
625 let metadata = ModelMetadata::new("ResNet", "f64", params.total_parameters())
626 .with_extra(
627 "architecture_config",
628 &serde_json::to_string(&arch_config)
629 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
630 );
631
632 let mut writer = SafeTensorsWriter::new();
633 writer.add_model_metadata(&metadata);
634 writer.add_named_parameters(¶ms)?;
635 writer.to_bytes()
636 }
637 ModelFormat::Json | ModelFormat::Cbor | ModelFormat::MessagePack => {
638 let mut result = HashMap::new();
639 result.insert(
640 "architecture",
641 serde_json::to_value(&arch_config)
642 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
643 );
644
645 let params_value: Vec<serde_json::Value> = params
646 .parameters
647 .iter()
648 .map(|(name, values, shape)| {
649 serde_json::json!({
650 "name": name,
651 "shape": shape,
652 "data": values,
653 })
654 })
655 .collect();
656 result.insert("parameters", serde_json::Value::Array(params_value));
657
658 serde_json::to_vec_pretty(&result)
659 .map_err(|e| NeuralError::SerializationError(e.to_string()))
660 }
661 }
662 }
663
664 fn architecture_name(&self) -> &str {
665 "ResNet"
666 }
667}
668
669impl<F> ExtractParameters for BertModel<F>
674where
675 F: Float
676 + Debug
677 + ScalarOperand
678 + NumAssign
679 + ToPrimitive
680 + FromPrimitive
681 + Send
682 + Sync
683 + SimdUnifiedOps
684 + 'static,
685{
686 fn extract_named_parameters(&self) -> Result<NamedParameters> {
687 let mut all_params = NamedParameters::new();
688 let named = self.extract_named_params()?;
689
690 for (name, param) in named {
691 let shape: Vec<usize> = param.shape().to_vec();
692 let values: Vec<f64> = param
693 .iter()
694 .map(|&x| {
695 x.to_f64().ok_or_else(|| {
696 NeuralError::SerializationError(
697 "Cannot convert parameter to f64".to_string(),
698 )
699 })
700 })
701 .collect::<Result<Vec<f64>>>()?;
702 all_params.add(&name, values, shape);
703 }
704
705 Ok(all_params)
706 }
707
708 fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()> {
709 let mut params_map = HashMap::new();
710 for (name, values, shape) in ¶ms.parameters {
711 let f_values: Vec<F> = values
712 .iter()
713 .map(|&x| {
714 F::from(x).ok_or_else(|| {
715 NeuralError::DeserializationError(format!(
716 "Cannot convert {x} to target type"
717 ))
718 })
719 })
720 .collect::<Result<Vec<F>>>()?;
721 let arr = Array::from_shape_vec(IxDyn(shape), f_values)?;
722 params_map.insert(name.clone(), arr);
723 }
724 self.load_named_params(¶ms_map)
725 }
726}
727
728impl<F> ModelSerialize for BertModel<F>
729where
730 F: Float
731 + Debug
732 + ScalarOperand
733 + NumAssign
734 + ToPrimitive
735 + FromPrimitive
736 + Send
737 + Sync
738 + SimdUnifiedOps
739 + 'static,
740{
741 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
742 let bytes = self.to_bytes(format)?;
743 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
744 Ok(())
745 }
746
747 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
748 let config = self.config();
749 let ser_config = SerializableBertConfig::from(config);
750
751 let arch_config = ArchitectureConfig {
752 architecture: "BERT".to_string(),
753 format_version: "1.0".to_string(),
754 config: serde_json::to_value(&ser_config)
755 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
756 };
757
758 let params = self.extract_named_parameters()?;
759
760 match format {
761 ModelFormat::SafeTensors => {
762 let metadata = ModelMetadata::new("BERT", "f64", params.total_parameters())
763 .with_extra(
764 "architecture_config",
765 &serde_json::to_string(&arch_config)
766 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
767 );
768
769 let mut writer = SafeTensorsWriter::new();
770 writer.add_model_metadata(&metadata);
771 writer.add_named_parameters(¶ms)?;
772 writer.to_bytes()
773 }
774 _ => {
775 let mut result = HashMap::new();
776 result.insert(
777 "architecture",
778 serde_json::to_value(&arch_config)
779 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
780 );
781
782 let params_value: Vec<serde_json::Value> = params
783 .parameters
784 .iter()
785 .map(|(name, values, shape)| {
786 serde_json::json!({
787 "name": name,
788 "shape": shape,
789 "data": values,
790 })
791 })
792 .collect();
793 result.insert("parameters", serde_json::Value::Array(params_value));
794
795 serde_json::to_vec_pretty(&result)
796 .map_err(|e| NeuralError::SerializationError(e.to_string()))
797 }
798 }
799 }
800
801 fn architecture_name(&self) -> &str {
802 "BERT"
803 }
804}
805
806impl<F> ModelDeserialize for BertModel<F>
807where
808 F: Float
809 + Debug
810 + ScalarOperand
811 + NumAssign
812 + ToPrimitive
813 + FromPrimitive
814 + Send
815 + Sync
816 + SimdUnifiedOps
817 + 'static,
818{
819 fn load(path: &Path, format: ModelFormat) -> Result<Self> {
820 let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
821 Self::from_bytes(&bytes, format)
822 }
823
824 fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
825 match format {
826 ModelFormat::SafeTensors => {
827 let reader = SafeTensorsReader::from_bytes(bytes)?;
828 let meta = reader.metadata();
829 let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
830 NeuralError::DeserializationError(
831 "Missing architecture_config in SafeTensors metadata".to_string(),
832 )
833 })?;
834 let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
835 .map_err(|e| {
836 NeuralError::DeserializationError(format!(
837 "Invalid architecture config: {e}"
838 ))
839 })?;
840
841 let ser_config: SerializableBertConfig = serde_json::from_value(arch_config.config)
842 .map_err(|e| {
843 NeuralError::DeserializationError(format!("Invalid BERT config: {e}"))
844 })?;
845
846 let bert_config = ser_config.to_bert_config();
847 let mut model = BertModel::new(bert_config)?;
848
849 let params = reader.to_named_parameters()?;
850 model.load_named_parameters(¶ms)?;
851
852 Ok(model)
853 }
854 _ => {
855 let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
856 .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
857
858 let arch_value = raw.get("architecture").ok_or_else(|| {
859 NeuralError::DeserializationError(
860 "Missing 'architecture' key in JSON".to_string(),
861 )
862 })?;
863
864 let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
865 .map_err(|e| {
866 NeuralError::DeserializationError(format!(
867 "Invalid architecture config: {e}"
868 ))
869 })?;
870
871 let ser_config: SerializableBertConfig = serde_json::from_value(arch_config.config)
872 .map_err(|e| {
873 NeuralError::DeserializationError(format!("Invalid BERT config: {e}"))
874 })?;
875
876 let bert_config = ser_config.to_bert_config();
877 BertModel::new(bert_config)
878 }
879 }
880 }
881}
882
883impl<F> ExtractParameters for GPTModel<F>
888where
889 F: Float
890 + Debug
891 + ScalarOperand
892 + NumAssign
893 + ToPrimitive
894 + Send
895 + Sync
896 + SimdUnifiedOps
897 + 'static,
898{
899 fn extract_named_parameters(&self) -> Result<NamedParameters> {
900 let mut all_params = NamedParameters::new();
901
902 let layer_ref: &dyn Layer<F> = self;
903 let params = layer_ref.params();
904
905 for (i, param) in params.iter().enumerate() {
906 let name = format!("gpt.param_{i}");
907 let shape: Vec<usize> = param.shape().to_vec();
908 let values: Vec<f64> = param
909 .iter()
910 .map(|&x| {
911 x.to_f64().ok_or_else(|| {
912 NeuralError::SerializationError(
913 "Cannot convert parameter to f64".to_string(),
914 )
915 })
916 })
917 .collect::<Result<Vec<f64>>>()?;
918 all_params.add(&name, values, shape);
919 }
920
921 Ok(all_params)
922 }
923
924 fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
925 Ok(())
926 }
927}
928
929impl<F> ModelSerialize for GPTModel<F>
930where
931 F: Float
932 + Debug
933 + ScalarOperand
934 + NumAssign
935 + ToPrimitive
936 + Send
937 + Sync
938 + SimdUnifiedOps
939 + 'static,
940{
941 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
942 let bytes = self.to_bytes(format)?;
943 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
944 Ok(())
945 }
946
947 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
948 let config = self.config();
949 let ser_config = SerializableGPTConfig::from(config);
950
951 let arch_config = ArchitectureConfig {
952 architecture: "GPT".to_string(),
953 format_version: "1.0".to_string(),
954 config: serde_json::to_value(&ser_config)
955 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
956 };
957
958 let params = self.extract_named_parameters()?;
959
960 match format {
961 ModelFormat::SafeTensors => {
962 let metadata = ModelMetadata::new("GPT", "f64", params.total_parameters())
963 .with_extra(
964 "architecture_config",
965 &serde_json::to_string(&arch_config)
966 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
967 );
968
969 let mut writer = SafeTensorsWriter::new();
970 writer.add_model_metadata(&metadata);
971 writer.add_named_parameters(¶ms)?;
972 writer.to_bytes()
973 }
974 _ => {
975 let mut result = HashMap::new();
976 result.insert(
977 "architecture",
978 serde_json::to_value(&arch_config)
979 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
980 );
981
982 let params_value: Vec<serde_json::Value> = params
983 .parameters
984 .iter()
985 .map(|(name, values, shape)| {
986 serde_json::json!({
987 "name": name,
988 "shape": shape,
989 "data": values,
990 })
991 })
992 .collect();
993 result.insert("parameters", serde_json::Value::Array(params_value));
994
995 serde_json::to_vec_pretty(&result)
996 .map_err(|e| NeuralError::SerializationError(e.to_string()))
997 }
998 }
999 }
1000
1001 fn architecture_name(&self) -> &str {
1002 "GPT"
1003 }
1004}
1005
1006impl<F> ModelDeserialize for GPTModel<F>
1007where
1008 F: Float
1009 + Debug
1010 + ScalarOperand
1011 + NumAssign
1012 + ToPrimitive
1013 + Send
1014 + Sync
1015 + SimdUnifiedOps
1016 + 'static,
1017{
1018 fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1019 let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1020 Self::from_bytes(&bytes, format)
1021 }
1022
1023 fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1024 match format {
1025 ModelFormat::SafeTensors => {
1026 let reader = SafeTensorsReader::from_bytes(bytes)?;
1027 let meta = reader.metadata();
1028 let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1029 NeuralError::DeserializationError(
1030 "Missing architecture_config in SafeTensors metadata".to_string(),
1031 )
1032 })?;
1033 let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1034 .map_err(|e| {
1035 NeuralError::DeserializationError(format!(
1036 "Invalid architecture config: {e}"
1037 ))
1038 })?;
1039
1040 let ser_config: SerializableGPTConfig = serde_json::from_value(arch_config.config)
1041 .map_err(|e| {
1042 NeuralError::DeserializationError(format!("Invalid GPT config: {e}"))
1043 })?;
1044
1045 let gpt_config = ser_config.to_gpt_config();
1046 let mut model = GPTModel::new(gpt_config)?;
1047
1048 let params = reader.to_named_parameters()?;
1049 model.load_named_parameters(¶ms)?;
1050
1051 Ok(model)
1052 }
1053 _ => {
1054 let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1055 .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1056
1057 let arch_value = raw.get("architecture").ok_or_else(|| {
1058 NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1059 })?;
1060
1061 let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1062 .map_err(|e| {
1063 NeuralError::DeserializationError(format!(
1064 "Invalid architecture config: {e}"
1065 ))
1066 })?;
1067
1068 let ser_config: SerializableGPTConfig = serde_json::from_value(arch_config.config)
1069 .map_err(|e| {
1070 NeuralError::DeserializationError(format!("Invalid GPT config: {e}"))
1071 })?;
1072
1073 let gpt_config = ser_config.to_gpt_config();
1074 GPTModel::new(gpt_config)
1075 }
1076 }
1077 }
1078}
1079
1080impl<F> ModelDeserialize for ResNet<F>
1085where
1086 F: Float
1087 + Debug
1088 + ScalarOperand
1089 + NumAssign
1090 + ToPrimitive
1091 + FromPrimitive
1092 + Send
1093 + Sync
1094 + 'static,
1095{
1096 fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1097 let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1098 Self::from_bytes(&bytes, format)
1099 }
1100
1101 fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1102 match format {
1103 ModelFormat::SafeTensors => {
1104 let reader = SafeTensorsReader::from_bytes(bytes)?;
1105 let meta = reader.metadata();
1106 let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1107 NeuralError::DeserializationError("Missing architecture_config".to_string())
1108 })?;
1109 let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1110 .map_err(|e| {
1111 NeuralError::DeserializationError(format!(
1112 "Invalid architecture config: {e}"
1113 ))
1114 })?;
1115
1116 let ser_config: SerializableResNetConfig =
1117 serde_json::from_value(arch_config.config).map_err(|e| {
1118 NeuralError::DeserializationError(format!("Invalid ResNet config: {e}"))
1119 })?;
1120
1121 let resnet_config = ser_config.to_resnet_config()?;
1122 let mut model = ResNet::new(resnet_config)?;
1123
1124 let params = reader.to_named_parameters()?;
1125 model.load_named_parameters(¶ms)?;
1126
1127 Ok(model)
1128 }
1129 _ => {
1130 let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1131 .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1132
1133 let arch_value = raw.get("architecture").ok_or_else(|| {
1134 NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1135 })?;
1136
1137 let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1138 .map_err(|e| {
1139 NeuralError::DeserializationError(format!(
1140 "Invalid architecture config: {e}"
1141 ))
1142 })?;
1143
1144 let ser_config: SerializableResNetConfig =
1145 serde_json::from_value(arch_config.config).map_err(|e| {
1146 NeuralError::DeserializationError(format!("Invalid ResNet config: {e}"))
1147 })?;
1148
1149 let resnet_config = ser_config.to_resnet_config()?;
1150 ResNet::new(resnet_config)
1151 }
1152 }
1153 }
1154}
1155
1156impl<F> ExtractParameters for Mamba<F>
1161where
1162 F: Float
1163 + Debug
1164 + ScalarOperand
1165 + NumAssign
1166 + ToPrimitive
1167 + Send
1168 + Sync
1169 + SimdUnifiedOps
1170 + 'static,
1171{
1172 fn extract_named_parameters(&self) -> Result<NamedParameters> {
1173 let mut all_params = NamedParameters::new();
1174
1175 let layer_ref: &dyn Layer<F> = self;
1176 let params = layer_ref.params();
1177
1178 for (i, param) in params.iter().enumerate() {
1179 let name = format!("mamba.param_{i}");
1180 let shape: Vec<usize> = param.shape().to_vec();
1181 let values: Vec<f64> = param
1182 .iter()
1183 .map(|&x| {
1184 x.to_f64().ok_or_else(|| {
1185 NeuralError::SerializationError(
1186 "Cannot convert parameter to f64".to_string(),
1187 )
1188 })
1189 })
1190 .collect::<Result<Vec<f64>>>()?;
1191 all_params.add(&name, values, shape);
1192 }
1193
1194 Ok(all_params)
1195 }
1196
1197 fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1198 Ok(())
1199 }
1200}
1201
1202impl<F> ModelSerialize for Mamba<F>
1203where
1204 F: Float
1205 + Debug
1206 + ScalarOperand
1207 + NumAssign
1208 + ToPrimitive
1209 + Send
1210 + Sync
1211 + SimdUnifiedOps
1212 + 'static,
1213{
1214 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1215 let bytes = self.to_bytes(format)?;
1216 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1217 Ok(())
1218 }
1219
1220 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1221 let config = self.config();
1222 let ser_config = SerializableMambaConfig::from(config);
1223
1224 let arch_config = ArchitectureConfig {
1225 architecture: "Mamba".to_string(),
1226 format_version: "1.0".to_string(),
1227 config: serde_json::to_value(&ser_config)
1228 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1229 };
1230
1231 let params = self.extract_named_parameters()?;
1232
1233 match format {
1234 ModelFormat::SafeTensors => {
1235 let metadata = ModelMetadata::new("Mamba", "f64", params.total_parameters())
1236 .with_extra(
1237 "architecture_config",
1238 &serde_json::to_string(&arch_config)
1239 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1240 );
1241
1242 let mut writer = SafeTensorsWriter::new();
1243 writer.add_model_metadata(&metadata);
1244 writer.add_named_parameters(¶ms)?;
1245 writer.to_bytes()
1246 }
1247 _ => {
1248 let mut result = HashMap::new();
1249 result.insert(
1250 "architecture",
1251 serde_json::to_value(&arch_config)
1252 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1253 );
1254
1255 let params_value: Vec<serde_json::Value> = params
1256 .parameters
1257 .iter()
1258 .map(|(name, values, shape)| {
1259 serde_json::json!({
1260 "name": name,
1261 "shape": shape,
1262 "data": values,
1263 })
1264 })
1265 .collect();
1266 result.insert("parameters", serde_json::Value::Array(params_value));
1267
1268 serde_json::to_vec_pretty(&result)
1269 .map_err(|e| NeuralError::SerializationError(e.to_string()))
1270 }
1271 }
1272 }
1273
1274 fn architecture_name(&self) -> &str {
1275 "Mamba"
1276 }
1277}
1278
1279impl<F> ModelDeserialize for Mamba<F>
1280where
1281 F: Float
1282 + Debug
1283 + ScalarOperand
1284 + NumAssign
1285 + ToPrimitive
1286 + Send
1287 + Sync
1288 + SimdUnifiedOps
1289 + 'static,
1290{
1291 fn load(path: &Path, format: ModelFormat) -> Result<Self> {
1292 let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1293 Self::from_bytes(&bytes, format)
1294 }
1295
1296 fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self> {
1297 match format {
1298 ModelFormat::SafeTensors => {
1299 let reader = SafeTensorsReader::from_bytes(bytes)?;
1300 let meta = reader.metadata();
1301 let arch_config_str = meta.get("architecture_config").ok_or_else(|| {
1302 NeuralError::DeserializationError("Missing architecture_config".to_string())
1303 })?;
1304 let arch_config: ArchitectureConfig = serde_json::from_str(arch_config_str)
1305 .map_err(|e| {
1306 NeuralError::DeserializationError(format!(
1307 "Invalid architecture config: {e}"
1308 ))
1309 })?;
1310
1311 let ser_config: SerializableMambaConfig =
1312 serde_json::from_value(arch_config.config).map_err(|e| {
1313 NeuralError::DeserializationError(format!("Invalid Mamba config: {e}"))
1314 })?;
1315
1316 let mamba_config = ser_config.to_mamba_config();
1317 let mut rng = scirs2_core::ChaCha8Rng::seed_from_u64(42);
1318 let mut model = Mamba::new(mamba_config, &mut rng)?;
1319
1320 let params = reader.to_named_parameters()?;
1321 model.load_named_parameters(¶ms)?;
1322
1323 Ok(model)
1324 }
1325 _ => {
1326 let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(bytes)
1327 .map_err(|e| NeuralError::DeserializationError(format!("Invalid JSON: {e}")))?;
1328
1329 let arch_value = raw.get("architecture").ok_or_else(|| {
1330 NeuralError::DeserializationError("Missing 'architecture' key".to_string())
1331 })?;
1332
1333 let arch_config: ArchitectureConfig = serde_json::from_value(arch_value.clone())
1334 .map_err(|e| {
1335 NeuralError::DeserializationError(format!(
1336 "Invalid architecture config: {e}"
1337 ))
1338 })?;
1339
1340 let ser_config: SerializableMambaConfig =
1341 serde_json::from_value(arch_config.config).map_err(|e| {
1342 NeuralError::DeserializationError(format!("Invalid Mamba config: {e}"))
1343 })?;
1344
1345 let mamba_config = ser_config.to_mamba_config();
1346 let mut rng = scirs2_core::ChaCha8Rng::seed_from_u64(42);
1347 Mamba::new(mamba_config, &mut rng)
1348 }
1349 }
1350 }
1351}
1352
1353impl<F> ExtractParameters for EfficientNet<F>
1358where
1359 F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1360{
1361 fn extract_named_parameters(&self) -> Result<NamedParameters> {
1362 let mut all_params = NamedParameters::new();
1363
1364 let layer_ref: &dyn Layer<F> = self;
1365 let params = layer_ref.params();
1366
1367 for (i, param) in params.iter().enumerate() {
1368 let name = format!("efficientnet.param_{i}");
1369 let shape: Vec<usize> = param.shape().to_vec();
1370 let values: Vec<f64> = param
1371 .iter()
1372 .map(|&x| {
1373 x.to_f64().ok_or_else(|| {
1374 NeuralError::SerializationError(
1375 "Cannot convert parameter to f64".to_string(),
1376 )
1377 })
1378 })
1379 .collect::<Result<Vec<f64>>>()?;
1380 all_params.add(&name, values, shape);
1381 }
1382
1383 Ok(all_params)
1384 }
1385
1386 fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1387 Ok(())
1388 }
1389}
1390
1391impl<F> ModelSerialize for EfficientNet<F>
1392where
1393 F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1394{
1395 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1396 let bytes = self.to_bytes(format)?;
1397 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1398 Ok(())
1399 }
1400
1401 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1402 let config = self.config();
1403 let ser_config = SerializableEfficientNetConfig {
1404 width_coefficient: config.width_coefficient,
1405 depth_coefficient: config.depth_coefficient,
1406 resolution: config.resolution,
1407 dropout_rate: config.dropout_rate,
1408 input_channels: config.input_channels,
1409 num_classes: config.num_classes,
1410 };
1411
1412 let arch_config = ArchitectureConfig {
1413 architecture: "EfficientNet".to_string(),
1414 format_version: "1.0".to_string(),
1415 config: serde_json::to_value(&ser_config)
1416 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1417 };
1418
1419 let params = self.extract_named_parameters()?;
1420
1421 match format {
1422 ModelFormat::SafeTensors => {
1423 let metadata = ModelMetadata::new("EfficientNet", "f64", params.total_parameters())
1424 .with_extra(
1425 "architecture_config",
1426 &serde_json::to_string(&arch_config)
1427 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1428 );
1429
1430 let mut writer = SafeTensorsWriter::new();
1431 writer.add_model_metadata(&metadata);
1432 writer.add_named_parameters(¶ms)?;
1433 writer.to_bytes()
1434 }
1435 _ => {
1436 let mut result = HashMap::new();
1437 result.insert(
1438 "architecture",
1439 serde_json::to_value(&arch_config)
1440 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1441 );
1442
1443 let params_value: Vec<serde_json::Value> = params
1444 .parameters
1445 .iter()
1446 .map(|(name, values, shape)| {
1447 serde_json::json!({
1448 "name": name,
1449 "shape": shape,
1450 "data": values,
1451 })
1452 })
1453 .collect();
1454 result.insert("parameters", serde_json::Value::Array(params_value));
1455
1456 serde_json::to_vec_pretty(&result)
1457 .map_err(|e| NeuralError::SerializationError(e.to_string()))
1458 }
1459 }
1460 }
1461
1462 fn architecture_name(&self) -> &str {
1463 "EfficientNet"
1464 }
1465}
1466
1467impl<F> ExtractParameters for MobileNet<F>
1472where
1473 F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1474{
1475 fn extract_named_parameters(&self) -> Result<NamedParameters> {
1476 let mut all_params = NamedParameters::new();
1477
1478 let layer_ref: &dyn Layer<F> = self;
1479 let params = layer_ref.params();
1480
1481 for (i, param) in params.iter().enumerate() {
1482 let name = format!("mobilenet.param_{i}");
1483 let shape: Vec<usize> = param.shape().to_vec();
1484 let values: Vec<f64> = param
1485 .iter()
1486 .map(|&x| {
1487 x.to_f64().ok_or_else(|| {
1488 NeuralError::SerializationError(
1489 "Cannot convert parameter to f64".to_string(),
1490 )
1491 })
1492 })
1493 .collect::<Result<Vec<f64>>>()?;
1494 all_params.add(&name, values, shape);
1495 }
1496
1497 Ok(all_params)
1498 }
1499
1500 fn load_named_parameters(&mut self, _params: &NamedParameters) -> Result<()> {
1501 Ok(())
1502 }
1503}
1504
1505impl<F> ModelSerialize for MobileNet<F>
1506where
1507 F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + Send + Sync + 'static,
1508{
1509 fn save(&self, path: &Path, format: ModelFormat) -> Result<()> {
1510 let bytes = self.to_bytes(format)?;
1511 fs::write(path, bytes).map_err(|e| NeuralError::IOError(e.to_string()))?;
1512 Ok(())
1513 }
1514
1515 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>> {
1516 let config = self.config();
1517 let ser_config = SerializableMobileNetConfig {
1518 version: match config.version {
1519 MobileNetVersion::V1 => "V1".to_string(),
1520 MobileNetVersion::V2 => "V2".to_string(),
1521 MobileNetVersion::V3Small => "V3Small".to_string(),
1522 MobileNetVersion::V3Large => "V3Large".to_string(),
1523 },
1524 width_multiplier: config.width_multiplier,
1525 resolution_multiplier: config.resolution_multiplier,
1526 dropout_rate: config.dropout_rate,
1527 input_channels: config.input_channels,
1528 num_classes: config.num_classes,
1529 };
1530
1531 let arch_config = ArchitectureConfig {
1532 architecture: "MobileNet".to_string(),
1533 format_version: "1.0".to_string(),
1534 config: serde_json::to_value(&ser_config)
1535 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1536 };
1537
1538 let params = self.extract_named_parameters()?;
1539
1540 match format {
1541 ModelFormat::SafeTensors => {
1542 let metadata = ModelMetadata::new("MobileNet", "f64", params.total_parameters())
1543 .with_extra(
1544 "architecture_config",
1545 &serde_json::to_string(&arch_config)
1546 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1547 );
1548
1549 let mut writer = SafeTensorsWriter::new();
1550 writer.add_model_metadata(&metadata);
1551 writer.add_named_parameters(¶ms)?;
1552 writer.to_bytes()
1553 }
1554 _ => {
1555 let mut result = HashMap::new();
1556 result.insert(
1557 "architecture",
1558 serde_json::to_value(&arch_config)
1559 .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
1560 );
1561
1562 let params_value: Vec<serde_json::Value> = params
1563 .parameters
1564 .iter()
1565 .map(|(name, values, shape)| {
1566 serde_json::json!({
1567 "name": name,
1568 "shape": shape,
1569 "data": values,
1570 })
1571 })
1572 .collect();
1573 result.insert("parameters", serde_json::Value::Array(params_value));
1574
1575 serde_json::to_vec_pretty(&result)
1576 .map_err(|e| NeuralError::SerializationError(e.to_string()))
1577 }
1578 }
1579 }
1580
1581 fn architecture_name(&self) -> &str {
1582 "MobileNet"
1583 }
1584}
1585
1586pub fn detect_architecture(path: &Path) -> Result<String> {
1592 let bytes = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
1593 detect_architecture_from_bytes(&bytes)
1594}
1595
1596pub fn detect_architecture_from_bytes(bytes: &[u8]) -> Result<String> {
1598 if bytes.len() >= 8 {
1600 if let Ok(reader) = SafeTensorsReader::from_bytes(bytes) {
1601 let meta = reader.metadata();
1602 if let Some(arch) = meta.get("architecture") {
1603 return Ok(arch.clone());
1604 }
1605 }
1606 }
1607
1608 if let Ok(raw) = serde_json::from_slice::<HashMap<String, serde_json::Value>>(bytes) {
1610 if let Some(arch_value) = raw.get("architecture") {
1611 if let Ok(arch_config) =
1612 serde_json::from_value::<ArchitectureConfig>(arch_value.clone())
1613 {
1614 return Ok(arch_config.architecture);
1615 }
1616 }
1617 }
1618
1619 Err(NeuralError::DeserializationError(
1620 "Cannot detect architecture from file: unrecognized format".to_string(),
1621 ))
1622}
1623
// Unit tests for the serializable config wrappers, the architecture envelope,
// ResNet serialization and architecture detection.
#[cfg(test)]
mod tests {
    use super::*;

    /// A ResNet-18 config survives a trip through its serializable form and JSON.
    #[test]
    fn test_serializable_resnet_config_roundtrip() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 1000);
        let ser = SerializableResNetConfig::from(&config);

        // Serialize to a JSON string.
        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        // Parse the JSON back into the serializable form.
        let deser: SerializableResNetConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_resnet_config()?;
        assert_eq!(restored.input_channels, 3);
        assert_eq!(restored.num_classes, 1000);
        assert_eq!(restored.layers.len(), 4);

        Ok(())
    }

    /// The `bert_base_uncased` config keeps its vocabulary, hidden size and
    /// layer count across a JSON round trip.
    #[test]
    fn test_serializable_bert_config_roundtrip() -> Result<()> {
        let config = BertConfig::bert_base_uncased();
        let ser = SerializableBertConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableBertConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_bert_config();
        assert_eq!(restored.vocab_size, 30522);
        assert_eq!(restored.hidden_size, 768);
        assert_eq!(restored.num_hidden_layers, 12);

        Ok(())
    }

    /// The `gpt2_small` config keeps its vocabulary and hidden size across a
    /// JSON round trip.
    #[test]
    fn test_serializable_gpt_config_roundtrip() -> Result<()> {
        let config = GPTConfig::gpt2_small();
        let ser = SerializableGPTConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableGPTConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_gpt_config();
        assert_eq!(restored.vocab_size, 50257);
        assert_eq!(restored.hidden_size, 768);

        Ok(())
    }

    /// A Mamba config built with `MambaConfig::new(256)` keeps `d_model` and
    /// `d_state` across a JSON round trip.
    #[test]
    fn test_serializable_mamba_config_roundtrip() -> Result<()> {
        let config = MambaConfig::new(256);
        let ser = SerializableMambaConfig::from(&config);

        let json = serde_json::to_string(&ser)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let deser: SerializableMambaConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        let restored = deser.to_mamba_config();
        assert_eq!(restored.d_model, 256);
        assert_eq!(restored.d_state, 16);

        Ok(())
    }

    /// The `ArchitectureConfig` envelope itself round-trips through JSON; the
    /// nested `config` payload is an arbitrary JSON value here.
    #[test]
    fn test_architecture_config_envelope() -> Result<()> {
        let config = ArchitectureConfig {
            architecture: "ResNet".to_string(),
            format_version: "1.0".to_string(),
            config: serde_json::json!({
                "block": "Basic",
                "input_channels": 3,
                "num_classes": 10,
            }),
        };

        let json = serde_json::to_string(&config)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;

        let restored: ArchitectureConfig = serde_json::from_str(&json)
            .map_err(|e| NeuralError::DeserializationError(e.to_string()))?;

        assert_eq!(restored.architecture, "ResNet");
        assert_eq!(restored.format_version, "1.0");

        Ok(())
    }

    /// A ResNet-18 model serializes to non-empty SafeTensors and JSON buffers,
    /// and the SafeTensors metadata names the architecture.
    #[test]
    fn test_resnet_model_serialize() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;

        // SafeTensors output.
        let bytes = model.to_bytes(ModelFormat::SafeTensors)?;
        assert!(!bytes.is_empty());

        // The buffer must be readable and carry the architecture metadata.
        let reader = SafeTensorsReader::from_bytes(&bytes)?;
        let meta = reader.metadata();
        assert_eq!(meta.get("architecture"), Some(&"ResNet".to_string()));

        // JSON output.
        let json_bytes = model.to_bytes(ModelFormat::Json)?;
        assert!(!json_bytes.is_empty());

        Ok(())
    }

    /// Saving a ResNet-18 to a SafeTensors file and loading it back restores
    /// the configuration. Only config fields are compared, not weights.
    #[test]
    fn test_resnet_save_load_roundtrip() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_arch_resnet");
        fs::create_dir_all(&test_dir).map_err(|e| NeuralError::IOError(e.to_string()))?;
        let path = test_dir.join("resnet18.safetensors");

        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        model.save(&path, ModelFormat::SafeTensors)?;

        let loaded = ResNet::<f64>::load(&path, ModelFormat::SafeTensors)?;
        assert_eq!(loaded.config().input_channels, 3);
        assert_eq!(loaded.config().num_classes, 10);

        // Best-effort cleanup; only reached when the steps above succeed.
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    /// Architecture detection recognises a SafeTensors-encoded ResNet.
    #[test]
    fn test_detect_architecture_safetensors() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        let bytes = model.to_bytes(ModelFormat::SafeTensors)?;

        let arch = detect_architecture_from_bytes(&bytes)?;
        assert_eq!(arch, "ResNet");
        Ok(())
    }

    /// Architecture detection recognises a JSON-encoded ResNet.
    #[test]
    fn test_detect_architecture_json() -> Result<()> {
        let config = ResNetConfig::resnet18(3, 10);
        let model = ResNet::<f64>::new(config)?;
        let bytes = model.to_bytes(ModelFormat::Json)?;

        let arch = detect_architecture_from_bytes(&bytes)?;
        assert_eq!(arch, "ResNet");
        Ok(())
    }
}