1use std::collections::HashMap;
19use std::fs::{self, File};
20use std::io::{Read, Write};
21use std::path::{Path, PathBuf};
22
23use ndarray::IxDyn;
24
25#[cfg(feature = "serialization")]
26use serde::{Deserialize, Serialize};
27#[cfg(feature = "serialization")]
28use serde_json;
29
30use chrono;
31
32use crate::array_protocol::grad::{Optimizer, SGD};
33use crate::array_protocol::ml_ops::ActivationFunc;
34use crate::array_protocol::neural::{
35 BatchNorm, Conv2D, Dropout, Layer, Linear, MaxPool2D, Sequential,
36};
37use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
38use crate::error::{CoreError, CoreResult, ErrorContext};
39
/// Trait for values that can be persisted to and restored from a raw
/// byte representation (used for model/optimizer state persistence).
pub trait Serializable {
    /// Serializes `self` into a byte buffer.
    fn serialize(&self) -> CoreResult<Vec<u8>>;

    /// Reconstructs a value from bytes previously produced by
    /// [`Serializable::serialize`].
    fn deserialize(bytes: &[u8]) -> CoreResult<Self>
    where
        Self: Sized;

    /// Returns a short identifier for the concrete type being serialized.
    fn type_name(&self) -> &str;
}
53
54#[derive(Serialize, Deserialize)]
56pub struct ModelFile {
57 pub metadata: ModelMetadata,
59
60 pub architecture: ModelArchitecture,
62
63 pub parameter_files: HashMap<String, String>,
65
66 pub optimizer_state: Option<String>,
68}
69
/// Descriptive metadata stored in a model manifest.
//
// Gated with `cfg_attr` for consistency with the other serialized structs
// in this module; the plain derive breaks builds without the
// `serialization` feature, since the serde imports are feature-gated.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct ModelMetadata {
    /// Model name (also the first path component under the base dir).
    pub name: String,

    /// Model version (second path component under the base dir).
    pub version: String,

    /// Version of the framework that wrote the file.
    pub framework_version: String,

    /// RFC 3339 creation timestamp.
    pub created_at: String,

    /// Expected input shape (may be empty if not tracked).
    pub inputshape: Vec<usize>,

    /// Expected output shape (may be empty if not tracked).
    pub outputshape: Vec<usize>,

    /// Free-form extra key/value metadata.
    pub additional_info: HashMap<String, String>,
}
94
/// Serializable description of a model's structure.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct ModelArchitecture {
    /// Top-level model kind; currently always "Sequential".
    pub model_type: String,

    /// Per-layer configurations, in forward order.
    pub layers: Vec<LayerConfig>,
}
104
/// Configuration of a single layer within a serialized architecture.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct LayerConfig {
    /// Layer kind, e.g. "Linear", "Conv2D", "MaxPool2D".
    pub layer_type: String,

    /// Layer instance name, unique within the model.
    pub name: String,

    /// Layer hyper-parameters as free-form JSON.
    #[cfg(feature = "serialization")]
    pub config: serde_json::Value,
    /// Fallback representation of hyper-parameters when the
    /// `serialization` feature is disabled.
    #[cfg(not(feature = "serialization"))]
    pub config: HashMap<String, String>,
}
120
/// Saves and loads [`Sequential`] models under a directory tree laid out
/// as `basedir/<name>/<version>/` (manifest, parameter files, optimizer).
pub struct ModelSerializer {
    // Root directory under which all models are stored.
    basedir: PathBuf,
}
126
127impl ModelSerializer {
128 pub fn new(basedir: impl AsRef<Path>) -> Self {
130 Self {
131 basedir: basedir.as_ref().to_path_buf(),
132 }
133 }
134
135 pub fn save_model(
137 &self,
138 model: &Sequential,
139 name: &str,
140 version: &str,
141 optimizer: Option<&dyn Optimizer>,
142 ) -> CoreResult<PathBuf> {
143 let modeldir = self.basedir.join(name).join(version);
145 fs::create_dir_all(&modeldir)?;
146
147 let metadata = ModelMetadata {
149 name: name.to_string(),
150 version: version.to_string(),
151 framework_version: "0.1.0".to_string(),
152 created_at: chrono::Utc::now().to_rfc3339(),
153 inputshape: vec![], outputshape: vec![], additional_info: HashMap::new(),
156 };
157
158 let architecture = self.create_architecture(model)?;
160
161 let mut parameter_files = HashMap::new();
163 self.save_parameters(model, &modeldir, &mut parameter_files)?;
164
165 let optimizer_state = if let Some(optimizer) = optimizer {
167 let optimizerpath = self.save_optimizer(optimizer, &modeldir)?;
168 Some(
169 optimizerpath
170 .file_name()
171 .unwrap()
172 .to_string_lossy()
173 .to_string(),
174 )
175 } else {
176 None
177 };
178
179 let model_file = ModelFile {
181 metadata,
182 architecture,
183 parameter_files,
184 optimizer_state,
185 };
186
187 let model_file_path = modeldir.join("model.json");
189 let model_file_json = serde_json::to_string_pretty(&model_file)?;
190 let mut file = File::create(&model_file_path)?;
191 file.write_all(model_file_json.as_bytes())?;
192
193 Ok(model_file_path)
194 }
195
196 pub fn loadmodel(
198 &self,
199 name: &str,
200 version: &str,
201 ) -> CoreResult<(Sequential, Option<Box<dyn Optimizer>>)> {
202 let modeldir = self.basedir.join(name).join(version);
204
205 let model_file_path = modeldir.join("model.json");
207 let mut file = File::open(&model_file_path)?;
208 let mut model_file_json = String::new();
209 file.read_to_string(&mut model_file_json)?;
210
211 let model_file: ModelFile = serde_json::from_str(&model_file_json)?;
212
213 let model = self.create_model_from_architecture(&model_file.architecture)?;
215
216 self.load_parameters(&model, &modeldir, &model_file.parameter_files)?;
218
219 let optimizer = if let Some(optimizer_state) = &model_file.optimizer_state {
221 let optimizerpath = modeldir.join(optimizer_state);
222 Some(self.load_optimizer(&optimizerpath)?)
223 } else {
224 None
225 };
226
227 Ok((model, optimizer))
228 }
229
230 fn create_architecture(&self, model: &Sequential) -> CoreResult<ModelArchitecture> {
232 let mut layers = Vec::new();
233
234 for layer in model.layers() {
235 let layer_config = self.create_layer_config(layer.as_ref())?;
236 layers.push(layer_config);
237 }
238
239 Ok(ModelArchitecture {
240 model_type: "Sequential".to_string(),
241 layers,
242 })
243 }
244
245 fn create_layer_config(&self, layer: &dyn Layer) -> CoreResult<LayerConfig> {
247 let layer_type = layer.layer_type();
248 if !["Linear", "Conv2D", "MaxPool2D", "BatchNorm", "Dropout"].contains(&layer_type) {
249 return Err(CoreError::NotImplementedError(ErrorContext::new(format!(
250 "Serialization not implemented for layer type: {}",
251 layer.name()
252 ))));
253 };
254
255 let config = match layer_type {
257 "Linear" => {
258 serde_json::json!({
261 "in_features": 0,
262 "out_features": 0,
263 "bias": true,
264 "activation": "relu",
265 })
266 }
267 "Conv2D" => {
268 serde_json::json!({
269 "filter_height": 3,
270 "filter_width": 3,
271 "in_channels": 0,
272 "out_channels": 0,
273 "stride": [1, 1],
274 "padding": [0, 0],
275 "bias": true,
276 "activation": "relu",
277 })
278 }
279 "MaxPool2D" => {
280 serde_json::json!({
281 "kernel_size": [2, 2],
282 "stride": [2, 2],
283 "padding": [0, 0],
284 })
285 }
286 "BatchNorm" => {
287 serde_json::json!({
288 "num_features": 0,
289 "epsilon": 1e-5,
290 "momentum": 0.1,
291 })
292 }
293 "Dropout" => {
294 serde_json::json!({
295 "rate": 0.5,
296 "seed": null,
297 })
298 }
299 _ => serde_json::json!({}),
300 };
301
302 Ok(LayerConfig {
303 layer_type: layer_type.to_string(),
304 name: layer.name().to_string(),
305 config,
306 })
307 }
308
309 fn save_parameters(
311 &self,
312 model: &Sequential,
313 modeldir: &Path,
314 parameter_files: &mut HashMap<String, String>,
315 ) -> CoreResult<()> {
316 let params_dir = modeldir.join("parameters");
318 fs::create_dir_all(¶ms_dir)?;
319
320 for (i, layer) in model.layers().iter().enumerate() {
322 for (j, param) in layer.parameters().iter().enumerate() {
323 let param_name = format!("layer_{i}_param_{j}");
325 let param_file = format!("{param_name}.npz");
326 let param_path = params_dir.join(¶m_file);
327
328 self.save_parameter(param.as_ref(), ¶m_path)?;
330
331 parameter_files.insert(param_name, format!("parameters/{param_file}"));
333 }
334 }
335
336 Ok(())
337 }
338
339 fn save_parameter(&self, param: &dyn ArrayProtocol, path: &Path) -> CoreResult<()> {
341 if let Some(array) = param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
343 let ndarray = array.as_array();
344
345 let shape: Vec<usize> = ndarray.shape().to_vec();
347 let data: Vec<f64> = ndarray.iter().cloned().collect();
348
349 let save_data = serde_json::json!({
350 "shape": shape,
351 "data": data,
352 });
353
354 let mut file = File::create(path)?;
355 let json_str = serde_json::to_string(&save_data)?;
356 file.write_all(json_str.as_bytes())?;
357
358 Ok(())
359 } else {
360 Err(CoreError::NotImplementedError(ErrorContext::new(
361 "Parameter serialization not implemented for this array type".to_string(),
362 )))
363 }
364 }
365
366 fn save_optimizer(&self, _optimizer: &dyn Optimizer, modeldir: &Path) -> CoreResult<PathBuf> {
368 let optimizerpath = modeldir.join("optimizer.json");
370
371 let optimizer_data = serde_json::json!({
375 "type": "SGD", "config": {
377 "learningrate": 0.01,
378 "momentum": null
379 },
380 "state": {} });
382
383 let mut file = File::create(&optimizerpath)?;
384 let json_str = serde_json::to_string_pretty(&optimizer_data)?;
385 file.write_all(json_str.as_bytes())?;
386
387 Ok(optimizerpath)
388 }
389
390 fn create_model_from_architecture(
392 &self,
393 architecture: &ModelArchitecture,
394 ) -> CoreResult<Sequential> {
395 let mut model = Sequential::new(&architecture.model_type, Vec::new());
396
397 for layer_config in &architecture.layers {
399 let layer = self.create_layer_from_config(layer_config)?;
400 model.add_layer(layer);
401 }
402
403 Ok(model)
404 }
405
406 fn create_layer_from_config(&self, config: &LayerConfig) -> CoreResult<Box<dyn Layer>> {
408 match config.layer_type.as_str() {
409 "Linear" => {
410 let in_features = config.config["in_features"].as_u64().unwrap_or(0) as usize;
412 let out_features = config.config["out_features"].as_u64().unwrap_or(0) as usize;
413 let bias = config.config["bias"].as_bool().unwrap_or(true);
414 let activation = match config.config["activation"].as_str() {
415 Some("relu") => Some(ActivationFunc::ReLU),
416 Some("sigmoid") => Some(ActivationFunc::Sigmoid),
417 Some("tanh") => Some(ActivationFunc::Tanh),
418 _ => None,
419 };
420
421 Ok(Box::new(Linear::new_random(
423 &config.name,
424 in_features,
425 out_features,
426 bias,
427 activation,
428 )))
429 }
430 "Conv2D" => {
431 let filter_height = config.config["filter_height"].as_u64().unwrap_or(3) as usize;
433 let filter_width = config.config["filter_width"].as_u64().unwrap_or(3) as usize;
434 let in_channels = config.config["in_channels"].as_u64().unwrap_or(0) as usize;
435 let out_channels = config.config["out_channels"].as_u64().unwrap_or(0) as usize;
436 let stride = (
437 config.config["stride"][0].as_u64().unwrap_or(1) as usize,
438 config.config["stride"][1].as_u64().unwrap_or(1) as usize,
439 );
440 let padding = (
441 config.config["padding"][0].as_u64().unwrap_or(0) as usize,
442 config.config["padding"][1].as_u64().unwrap_or(0) as usize,
443 );
444 let bias = config.config["bias"].as_bool().unwrap_or(true);
445 let activation = match config.config["activation"].as_str() {
446 Some("relu") => Some(ActivationFunc::ReLU),
447 Some("sigmoid") => Some(ActivationFunc::Sigmoid),
448 Some("tanh") => Some(ActivationFunc::Tanh),
449 _ => None,
450 };
451
452 Ok(Box::new(Conv2D::withshape(
454 &config.name,
455 filter_height,
456 filter_width,
457 in_channels,
458 out_channels,
459 stride,
460 padding,
461 bias,
462 activation,
463 )))
464 }
465 "MaxPool2D" => {
466 let kernel_size = (
468 config.config["kernel_size"][0].as_u64().unwrap_or(2) as usize,
469 config.config["kernel_size"][1].as_u64().unwrap_or(2) as usize,
470 );
471 let stride = if config.config["stride"].is_array() {
472 Some((
473 config.config["stride"][0].as_u64().unwrap_or(2) as usize,
474 config.config["stride"][1].as_u64().unwrap_or(2) as usize,
475 ))
476 } else {
477 None
478 };
479 let padding = (
480 config.config["padding"][0].as_u64().unwrap_or(0) as usize,
481 config.config["padding"][1].as_u64().unwrap_or(0) as usize,
482 );
483
484 Ok(Box::new(MaxPool2D::new(
486 &config.name,
487 kernel_size,
488 stride,
489 padding,
490 )))
491 }
492 "BatchNorm" => {
493 let num_features = config.config["num_features"].as_u64().unwrap_or(0) as usize;
495 let epsilon = config.config["epsilon"].as_f64().unwrap_or(1e-5);
496 let momentum = config.config["momentum"].as_f64().unwrap_or(0.1);
497
498 Ok(Box::new(BatchNorm::withshape(
500 &config.name,
501 num_features,
502 Some(epsilon),
503 Some(momentum),
504 )))
505 }
506 "Dropout" => {
507 let rate = config.config["rate"].as_f64().unwrap_or(0.5);
509 let seed = config.config["seed"].as_u64();
510
511 Ok(Box::new(Dropout::new(&config.name, rate, seed)))
513 }
514 _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
515 "Deserialization not implemented for layer type: {layer_type}",
516 layer_type = config.layer_type
517 )))),
518 }
519 }
520
521 fn load_parameters(
523 &self,
524 model: &Sequential,
525 modeldir: &Path,
526 parameter_files: &HashMap<String, String>,
527 ) -> CoreResult<()> {
528 for (i, layer) in model.layers().iter().enumerate() {
530 let params = layer.parameters();
531 for (j, param) in params.iter().enumerate() {
532 let param_name = format!("layer_{i}_param_{j}");
534 if let Some(param_file) = parameter_files.get(¶m_name) {
535 let param_path = modeldir.join(param_file);
536
537 if param_path.exists() {
539 let mut file = File::open(¶m_path)?;
540 let mut json_str = String::new();
541 file.read_to_string(&mut json_str)?;
542
543 let load_data: serde_json::Value = serde_json::from_str(&json_str)?;
544 let shape: Vec<usize> = serde_json::from_value(load_data["shape"].clone())?;
545 let _data: Vec<f64> = serde_json::from_value(load_data["data"].clone())?;
546
547 if let Some(_array) =
553 param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
554 {
555 }
558 } else {
559 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
560 "Parameter file not found: {path}",
561 path = param_path.display()
562 ))));
563 }
564 }
565 }
566 }
567
568 Ok(())
569 }
570
571 fn load_optimizer(&self, optimizerpath: &Path) -> CoreResult<Box<dyn Optimizer>> {
573 if !optimizerpath.exists() {
575 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
576 "Optimizer file not found: {path}",
577 path = optimizerpath.display()
578 ))));
579 }
580
581 let mut file = File::open(optimizerpath)?;
583 let mut json_str = String::new();
584 file.read_to_string(&mut json_str)?;
585
586 let optimizer_data: serde_json::Value = serde_json::from_str(&json_str)?;
587
588 match optimizer_data["type"].as_str() {
590 Some("SGD") => {
591 let config = &optimizer_data["config"];
592 let learningrate = config["learningrate"].as_f64().unwrap_or(0.01);
593 let momentum = config["momentum"].as_f64();
594 Ok(Box::new(SGD::new(learningrate, momentum)))
595 }
596 _ => {
597 Ok(Box::new(SGD::new(0.01, None)))
599 }
600 }
601 }
602}
603
/// Exporter for the ONNX model format (currently a stub; see `export`).
pub struct OnnxExporter;
606
impl OnnxExporter {
    /// Exports `_model` to ONNX format at `path`.
    ///
    /// NOTE(review): this is a placeholder — it only creates an empty file
    /// and ignores both the model and the input shape entirely.
    ///
    /// # Errors
    /// Fails only if the output file cannot be created.
    pub fn export(
        &self,
        _model: &Sequential,
        path: impl AsRef<Path>,
        _inputshape: &[usize],
    ) -> CoreResult<()> {
        // Stub: real ONNX serialization is not implemented yet.
        File::create(path.as_ref())?;

        Ok(())
    }
}
624
/// Saves a training checkpoint: the model, its optimizer state, and a
/// JSON metadata file recording the epoch and metric values.
///
/// The model is stored via [`ModelSerializer`] under
/// `<parent>/checkpoint/epoch_<epoch>/`; metadata is written next to
/// `path` with a `.json` extension.
///
/// # Errors
/// Fails if the checkpoint directory or any of its files cannot be
/// written.
// Feature-gated like `load_checkpoint`: the body relies on `serde_json`
// and on serde derives that are only available with `serialization`.
#[cfg(feature = "serialization")]
#[allow(dead_code)]
pub fn save_checkpoint(
    model: &Sequential,
    optimizer: &dyn Optimizer,
    path: impl AsRef<Path>,
    epoch: usize,
    metrics: HashMap<String, f64>,
) -> CoreResult<()> {
    // Fall back to the current directory when `path` has no parent.
    let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
    fs::create_dir_all(checkpoint_dir)?;

    let metadata = serde_json::json!({
        "epoch": epoch,
        "metrics": metrics,
        "timestamp": chrono::Utc::now().to_rfc3339(),
    });

    let metadata_path = path.as_ref().with_extension("json");
    let metadata_json = serde_json::to_string_pretty(&metadata)?;
    let mut file = File::create(&metadata_path)?;
    file.write_all(metadata_json.as_bytes())?;

    let serializer = ModelSerializer::new(checkpoint_dir);

    // Model weights/optimizer are stored as a versioned model named
    // "checkpoint" so `load_checkpoint` can find them by epoch.
    let model_name = "checkpoint";
    let model_version = format!("epoch_{epoch}");
    serializer.save_model(model, model_name, &model_version, Some(optimizer))?;

    Ok(())
}
661
/// A loaded checkpoint: (model, optimizer, epoch, metrics).
pub type ModelCheckpoint = (Sequential, Box<dyn Optimizer>, usize, HashMap<String, f64>);
664
/// Loads a checkpoint previously written by [`save_checkpoint`], returning
/// the model, optimizer, epoch number, and recorded metrics.
///
/// # Errors
/// Fails when the metadata or model files are missing or unparseable, or
/// when the checkpoint contains no optimizer state (instead of panicking
/// via `unwrap`, as this function previously did).
#[cfg(feature = "serialization")]
#[allow(dead_code)]
pub fn load_checkpoint(path: impl AsRef<Path>) -> CoreResult<ModelCheckpoint> {
    // Read the metadata JSON stored next to `path`.
    let metadata_path = path.as_ref().with_extension("json");
    let mut file = File::open(&metadata_path)?;
    let mut metadata_json = String::new();
    file.read_to_string(&mut metadata_json)?;

    let metadata: serde_json::Value = serde_json::from_str(&metadata_json)?;

    let epoch = metadata["epoch"].as_u64().unwrap_or(0) as usize;
    // Missing/garbled metrics degrade to an empty map rather than failing.
    let metrics: HashMap<String, f64> =
        serde_json::from_value(metadata["metrics"].clone()).unwrap_or_else(|_| HashMap::new());

    let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
    let serializer = ModelSerializer::new(checkpoint_dir);

    // Mirror the layout used by `save_checkpoint`.
    let model_name = "checkpoint";
    let model_version = format!("epoch_{epoch}");
    let (model, optimizer) = serializer.loadmodel(model_name, &model_version)?;

    // Checkpoints written by `save_checkpoint` always include optimizer
    // state; report its absence as an error instead of panicking.
    let optimizer = optimizer.ok_or_else(|| {
        CoreError::InvalidArgument(ErrorContext::new(
            "Checkpoint is missing optimizer state".to_string(),
        ))
    })?;

    Ok((model, optimizer, epoch, metrics))
}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol;
    use crate::array_protocol::grad::SGD;
    use crate::array_protocol::ml_ops::ActivationFunc;
    use crate::array_protocol::neural::{Linear, Sequential};
    use tempfile::tempdir;

    /// Round-trips a two-layer model through `ModelSerializer`.
    #[test]
    fn test_model_serializer() {
        array_protocol::init();

        // Work in a throw-away directory; skip if one cannot be created.
        let dir = match tempdir() {
            Err(e) => {
                println!("Skipping test_model_serializer (temp dir creation failed): {e}");
                return;
            }
            Ok(d) => d,
        };

        // Two-layer MLP: 10 -> 5 (ReLU) -> 2.
        let mut net = Sequential::new("test_model", Vec::new());
        net.add_layer(Box::new(Linear::new_random(
            "fc1",
            10,
            5,
            true,
            Some(ActivationFunc::ReLU),
        )));
        net.add_layer(Box::new(Linear::new_random("fc2", 5, 2, true, None)));

        let sgd = SGD::new(0.01, Some(0.9));
        let serializer = ModelSerializer::new(dir.path());

        let saved = serializer.save_model(&net, "test_model", "v1", Some(&sgd));
        if saved.is_err() {
            println!("Save model failed: {:?}", saved.err());
            return;
        }

        // Reload and verify the structure survived the round trip.
        let (restored, restored_opt) = serializer.loadmodel("test_model", "v1").unwrap();
        assert_eq!(restored.layers().len(), 2);
        assert!(restored_opt.is_some());
    }

    /// Saves and reloads a full checkpoint (model, optimizer, epoch,
    /// metrics) and verifies every piece comes back intact.
    #[test]
    fn test_save_load_checkpoint() {
        array_protocol::init();

        let dir = match tempdir() {
            Err(e) => {
                println!("Skipping test_save_load_checkpoint (temp dir creation failed): {e}");
                return;
            }
            Ok(d) => d,
        };

        // A single-layer model is enough to exercise the checkpoint path.
        let mut net = Sequential::new("test_model", Vec::new());
        net.add_layer(Box::new(Linear::new_random(
            "fc1",
            10,
            5,
            true,
            Some(ActivationFunc::ReLU),
        )));

        let sgd = SGD::new(0.01, Some(0.9));

        let mut metrics = HashMap::new();
        metrics.insert("accuracy".to_string(), 0.9);
        metrics.insert("loss".to_string(), 0.1);

        let ckpt_path = dir.path().join("checkpoint");
        if let Err(e) = save_checkpoint(&net, &sgd, &ckpt_path, 10, metrics.clone()) {
            println!("Skipping test_save_load_checkpoint (save failed): {e}");
            return;
        }

        let loaded = load_checkpoint(&ckpt_path);
        if let Err(e) = &loaded {
            println!("Skipping test_save_load_checkpoint (load failed): {e}");
            return;
        }

        let (restored, _restored_opt, restored_epoch, restored_metrics) = loaded.unwrap();
        assert_eq!(restored.layers().len(), 1);
        assert_eq!(restored_epoch, 10);
        assert_eq!(restored_metrics.get("loss"), metrics.get("loss"));
        assert_eq!(restored_metrics.get("accuracy"), metrics.get("accuracy"));
    }
}