1use std::collections::HashMap;
13use std::fs::{self, File};
14use std::io::{Read, Write};
15use std::path::{Path, PathBuf};
16
17use ::ndarray::IxDyn;
18
19#[cfg(feature = "serialization")]
20use serde::{Deserialize, Serialize};
21#[cfg(feature = "serialization")]
22use serde_json;
23
24use chrono;
25
26use crate::array_protocol::grad::{Optimizer, SGD};
27use crate::array_protocol::ml_ops::ActivationFunc;
28use crate::array_protocol::neural::{
29 BatchNorm, Conv2D, Dropout, Layer, Linear, MaxPool2D, Sequential,
30};
31use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
32use crate::error::{CoreError, CoreResult, ErrorContext};
33
/// Byte-level (de)serialization contract for model components.
///
/// NOTE(review): `serialize` shares its name with serde's `Serialize` trait
/// method — qualify call sites where both traits are in scope.
pub trait Serializable {
    /// Serialize `self` into a byte buffer.
    fn serialize(&self) -> CoreResult<Vec<u8>>;

    /// Reconstruct a value from bytes produced by [`Serializable::serialize`].
    fn deserialize(bytes: &[u8]) -> CoreResult<Self>
    where
        Self: Sized;

    /// Name identifying the concrete type.
    fn type_name(&self) -> &str;
}
47
48#[derive(Serialize, Deserialize)]
50pub struct ModelFile {
51 pub metadata: ModelMetadata,
53
54 pub architecture: ModelArchitecture,
56
57 pub parameter_files: HashMap<String, String>,
59
60 pub optimizer_state: Option<String>,
62}
63
/// Metadata describing a serialized model.
///
/// The serde derives are feature-gated to match the `#[cfg(feature =
/// "serialization")]` imports above (and the `cfg_attr` style already used by
/// `ModelArchitecture` and `LayerConfig`); an unconditional derive fails to
/// compile when the feature is disabled.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct ModelMetadata {
    /// Human-readable model name.
    pub name: String,

    /// Model version string.
    pub version: String,

    /// Version of the framework that produced the file.
    pub framework_version: String,

    /// Creation timestamp (RFC 3339).
    pub created_at: String,

    /// Expected input shape (may be empty when unknown).
    pub inputshape: Vec<usize>,

    /// Expected output shape (may be empty when unknown).
    pub outputshape: Vec<usize>,

    /// Free-form additional key/value information.
    pub additional_info: HashMap<String, String>,
}
88
/// Architecture description: container type plus ordered per-layer configs.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct ModelArchitecture {
    /// Model container type (this file only writes "Sequential").
    pub model_type: String,

    /// Layer configurations in forward order.
    pub layers: Vec<LayerConfig>,
}
98
/// Serialized configuration for a single layer.
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct LayerConfig {
    /// Layer type tag ("Linear", "Conv2D", "MaxPool2D", "BatchNorm", "Dropout").
    pub layer_type: String,

    /// Layer instance name.
    pub name: String,

    /// Layer hyper-parameters: arbitrary JSON with the serialization feature…
    #[cfg(feature = "serialization")]
    pub config: serde_json::Value,
    /// …or a plain string map without it.
    #[cfg(not(feature = "serialization"))]
    pub config: HashMap<String, String>,
}
114
/// Saves and loads `Sequential` models under a directory tree with layout
/// `<basedir>/<name>/<version>/`.
pub struct ModelSerializer {
    // Root directory under which all models are stored.
    basedir: PathBuf,
}
120
impl ModelSerializer {
    /// Create a serializer rooted at `basedir`; models are stored under
    /// `<basedir>/<name>/<version>/`.
    pub fn new(basedir: impl AsRef<Path>) -> Self {
        Self {
            basedir: basedir.as_ref().to_path_buf(),
        }
    }

    /// Save a model (and optionally its optimizer) under
    /// `<basedir>/<name>/<version>/`, returning the path of the written
    /// `model.json` manifest.
    ///
    /// Layout produced:
    /// - `model.json` — manifest (metadata, architecture, file references)
    /// - `parameters/layer_<i>_param_<j>.npz` — one file per parameter
    /// - `optimizer.json` — optimizer state (only when `optimizer` is `Some`)
    pub fn save_model(
        &self,
        model: &Sequential,
        name: &str,
        version: &str,
        optimizer: Option<&dyn Optimizer>,
    ) -> CoreResult<PathBuf> {
        let modeldir = self.basedir.join(name).join(version);
        fs::create_dir_all(&modeldir)?;

        // NOTE(review): input/output shapes are always written empty — nothing
        // here derives them from the model; confirm downstream consumers do
        // not rely on them.
        let metadata = ModelMetadata {
            name: name.to_string(),
            version: version.to_string(),
            framework_version: "0.1.0".to_string(),
            created_at: chrono::Utc::now().to_rfc3339(),
            inputshape: vec![], outputshape: vec![], additional_info: HashMap::new(),
        };

        let architecture = self.create_architecture(model)?;

        // Write each parameter to its own file, collecting the relative paths.
        let mut parameter_files = HashMap::new();
        self.save_parameters(model, &modeldir, &mut parameter_files)?;

        // Only the file name (relative to the model dir) goes into the
        // manifest; `file_name()` is always Some here because the path ends
        // in "optimizer.json".
        let optimizer_state = if let Some(optimizer) = optimizer {
            let optimizerpath = self.save_optimizer(optimizer, &modeldir)?;
            Some(
                optimizerpath
                    .file_name()
                    .expect("Operation failed")
                    .to_string_lossy()
                    .to_string(),
            )
        } else {
            None
        };

        let model_file = ModelFile {
            metadata,
            architecture,
            parameter_files,
            optimizer_state,
        };

        // The manifest ties everything together; `loadmodel` reads it back.
        let model_file_path = modeldir.join("model.json");
        let model_file_json = serde_json::to_string_pretty(&model_file)?;
        let mut file = File::create(&model_file_path)?;
        file.write_all(model_file_json.as_bytes())?;

        Ok(model_file_path)
    }

    /// Load a model (and optional optimizer) previously written by
    /// [`ModelSerializer::save_model`] from `<basedir>/<name>/<version>/`.
    pub fn loadmodel(
        &self,
        name: &str,
        version: &str,
    ) -> CoreResult<(Sequential, Option<Box<dyn Optimizer>>)> {
        let modeldir = self.basedir.join(name).join(version);

        // Read and parse the manifest.
        let model_file_path = modeldir.join("model.json");
        let mut file = File::open(&model_file_path)?;
        let mut model_file_json = String::new();
        file.read_to_string(&mut model_file_json)?;

        let model_file: ModelFile = serde_json::from_str(&model_file_json)?;

        // Rebuild the layer stack from the recorded architecture.
        let model = self.create_model_from_architecture(&model_file.architecture)?;

        // NOTE(review): `load_parameters` currently only checks that the
        // parameter files exist and parse — weights are not copied back into
        // the model (see its body).
        self.load_parameters(&model, &modeldir, &model_file.parameter_files)?;

        let optimizer = if let Some(optimizer_state) = &model_file.optimizer_state {
            let optimizerpath = modeldir.join(optimizer_state);
            Some(self.load_optimizer(&optimizerpath)?)
        } else {
            None
        };

        Ok((model, optimizer))
    }

    /// Build a serializable architecture description from the model's layers.
    fn create_architecture(&self, model: &Sequential) -> CoreResult<ModelArchitecture> {
        let mut layers = Vec::new();

        for layer in model.layers() {
            let layer_config = self.create_layer_config(layer.as_ref())?;
            layers.push(layer_config);
        }

        Ok(ModelArchitecture {
            model_type: "Sequential".to_string(),
            layers,
        })
    }

    /// Build the serialized config for a single layer.
    ///
    /// Only the five known layer types are supported; anything else yields a
    /// `NotImplementedError`.
    ///
    /// NOTE(review): the configs written here are fixed placeholder values
    /// (e.g. `in_features: 0`), not the layer's actual hyper-parameters —
    /// nothing here queries the layer beyond its type and name. A save/load
    /// round trip therefore reconstructs layers with these defaults; confirm
    /// this is intended.
    fn create_layer_config(&self, layer: &dyn Layer) -> CoreResult<LayerConfig> {
        let layer_type = layer.layer_type();
        if !["Linear", "Conv2D", "MaxPool2D", "BatchNorm", "Dropout"].contains(&layer_type) {
            return Err(CoreError::NotImplementedError(ErrorContext::new(format!(
                "Serialization not implemented for layer type: {}",
                layer.name()
            ))));
        };

        // Placeholder hyper-parameters per layer type (see NOTE above).
        let config = match layer_type {
            "Linear" => {
                serde_json::json!({
                    "in_features": 0,
                    "out_features": 0,
                    "bias": true,
                    "activation": "relu",
                })
            }
            "Conv2D" => {
                serde_json::json!({
                    "filter_height": 3,
                    "filter_width": 3,
                    "in_channels": 0,
                    "out_channels": 0,
                    "stride": [1, 1],
                    "padding": [0, 0],
                    "bias": true,
                    "activation": "relu",
                })
            }
            "MaxPool2D" => {
                serde_json::json!({
                    "kernel_size": [2, 2],
                    "stride": [2, 2],
                    "padding": [0, 0],
                })
            }
            "BatchNorm" => {
                serde_json::json!({
                    "num_features": 0,
                    "epsilon": 1e-5,
                    "momentum": 0.1,
                })
            }
            "Dropout" => {
                serde_json::json!({
                    "rate": 0.5,
                    "seed": null,
                })
            }
            // Unreachable in practice: the guard above already rejected
            // unknown layer types.
            _ => serde_json::json!({}),
        };

        Ok(LayerConfig {
            layer_type: layer_type.to_string(),
            name: layer.name().to_string(),
            config,
        })
    }

    /// Write every layer parameter to `parameters/layer_<i>_param_<j>.npz`
    /// under `modeldir`, recording model-dir-relative paths in
    /// `parameter_files`.
    fn save_parameters(
        &self,
        model: &Sequential,
        modeldir: &Path,
        parameter_files: &mut HashMap<String, String>,
    ) -> CoreResult<()> {
        let params_dir = modeldir.join("parameters");
        fs::create_dir_all(&params_dir)?;

        for (i, layer) in model.layers().iter().enumerate() {
            for (j, param) in layer.parameters().iter().enumerate() {
                // Parameter names are positional: layer index + param index.
                let param_name = format!("layer_{i}_param_{j}");
                let param_file = format!("{param_name}.npz");
                let param_path = params_dir.join(&param_file);

                self.save_parameter(param.as_ref(), &param_path)?;

                // Manifest paths are relative to the model directory.
                parameter_files.insert(param_name, format!("parameters/{param_file}"));
            }
        }

        Ok(())
    }

    /// Write a single parameter array to `path`.
    ///
    /// NOTE(review): despite the `.npz` extension chosen by the caller, the
    /// payload is JSON of the form `{"shape": [...], "data": [...]}`. Only
    /// `NdarrayWrapper<f64, IxDyn>` parameters are supported; anything else
    /// is a `NotImplementedError`.
    fn save_parameter(&self, param: &dyn ArrayProtocol, path: &Path) -> CoreResult<()> {
        if let Some(array) = param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            let ndarray = array.as_array();

            // Flatten to shape + element list for a simple JSON encoding.
            let shape: Vec<usize> = ndarray.shape().to_vec();
            let data: Vec<f64> = ndarray.iter().cloned().collect();

            let save_data = serde_json::json!({
                "shape": shape,
                "data": data,
            });

            let mut file = File::create(path)?;
            let json_str = serde_json::to_string(&save_data)?;
            file.write_all(json_str.as_bytes())?;

            Ok(())
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "Parameter serialization not implemented for this array type".to_string(),
            )))
        }
    }

    /// Write optimizer state to `optimizer.json` under `modeldir`, returning
    /// the file's path.
    ///
    /// NOTE(review): the optimizer argument is ignored (`_optimizer`) — a
    /// fixed SGD config (lr 0.01, no momentum) and an empty state map are
    /// written regardless, so real optimizer state is lost on save.
    fn save_optimizer(&self, _optimizer: &dyn Optimizer, modeldir: &Path) -> CoreResult<PathBuf> {
        let optimizerpath = modeldir.join("optimizer.json");

        let optimizer_data = serde_json::json!({
            "type": "SGD", "config": {
                "learningrate": 0.01,
                "momentum": null
            },
            "state": {} });

        let mut file = File::create(&optimizerpath)?;
        let json_str = serde_json::to_string_pretty(&optimizer_data)?;
        file.write_all(json_str.as_bytes())?;

        Ok(optimizerpath)
    }

    /// Rebuild a `Sequential` model (layers only, no weights) from an
    /// architecture record.
    fn create_model_from_architecture(
        &self,
        architecture: &ModelArchitecture,
    ) -> CoreResult<Sequential> {
        let mut model = Sequential::new(&architecture.model_type, Vec::new());

        for layer_config in &architecture.layers {
            let layer = self.create_layer_from_config(layer_config)?;
            model.add_layer(layer);
        }

        Ok(model)
    }

    /// Construct a single layer from its serialized config, falling back to
    /// the defaults used by `create_layer_config` for any missing fields.
    fn create_layer_from_config(&self, config: &LayerConfig) -> CoreResult<Box<dyn Layer>> {
        match config.layer_type.as_str() {
            "Linear" => {
                let in_features = config.config["in_features"].as_u64().unwrap_or(0) as usize;
                let out_features = config.config["out_features"].as_u64().unwrap_or(0) as usize;
                let bias = config.config["bias"].as_bool().unwrap_or(true);
                // Unrecognized activation strings silently map to None.
                let activation = match config.config["activation"].as_str() {
                    Some("relu") => Some(ActivationFunc::ReLU),
                    Some("sigmoid") => Some(ActivationFunc::Sigmoid),
                    Some("tanh") => Some(ActivationFunc::Tanh),
                    _ => None,
                };

                // Weights are randomly initialized; load_parameters does not
                // currently overwrite them (see NOTE there).
                Ok(Box::new(Linear::new_random(
                    &config.name,
                    in_features,
                    out_features,
                    bias,
                    activation,
                )))
            }
            "Conv2D" => {
                let filter_height = config.config["filter_height"].as_u64().unwrap_or(3) as usize;
                let filter_width = config.config["filter_width"].as_u64().unwrap_or(3) as usize;
                let in_channels = config.config["in_channels"].as_u64().unwrap_or(0) as usize;
                let out_channels = config.config["out_channels"].as_u64().unwrap_or(0) as usize;
                let stride = (
                    config.config["stride"][0].as_u64().unwrap_or(1) as usize,
                    config.config["stride"][1].as_u64().unwrap_or(1) as usize,
                );
                let padding = (
                    config.config["padding"][0].as_u64().unwrap_or(0) as usize,
                    config.config["padding"][1].as_u64().unwrap_or(0) as usize,
                );
                let bias = config.config["bias"].as_bool().unwrap_or(true);
                let activation = match config.config["activation"].as_str() {
                    Some("relu") => Some(ActivationFunc::ReLU),
                    Some("sigmoid") => Some(ActivationFunc::Sigmoid),
                    Some("tanh") => Some(ActivationFunc::Tanh),
                    _ => None,
                };

                Ok(Box::new(Conv2D::withshape(
                    &config.name,
                    filter_height,
                    filter_width,
                    in_channels,
                    out_channels,
                    stride,
                    padding,
                    bias,
                    activation,
                )))
            }
            "MaxPool2D" => {
                let kernel_size = (
                    config.config["kernel_size"][0].as_u64().unwrap_or(2) as usize,
                    config.config["kernel_size"][1].as_u64().unwrap_or(2) as usize,
                );
                // Stride is optional: absent (non-array) means "use default".
                let stride = if config.config["stride"].is_array() {
                    Some((
                        config.config["stride"][0].as_u64().unwrap_or(2) as usize,
                        config.config["stride"][1].as_u64().unwrap_or(2) as usize,
                    ))
                } else {
                    None
                };
                let padding = (
                    config.config["padding"][0].as_u64().unwrap_or(0) as usize,
                    config.config["padding"][1].as_u64().unwrap_or(0) as usize,
                );

                Ok(Box::new(MaxPool2D::new(
                    &config.name,
                    kernel_size,
                    stride,
                    padding,
                )))
            }
            "BatchNorm" => {
                let num_features = config.config["num_features"].as_u64().unwrap_or(0) as usize;
                let epsilon = config.config["epsilon"].as_f64().unwrap_or(1e-5);
                let momentum = config.config["momentum"].as_f64().unwrap_or(0.1);

                Ok(Box::new(BatchNorm::withshape(
                    &config.name,
                    num_features,
                    Some(epsilon),
                    Some(momentum),
                )))
            }
            "Dropout" => {
                let rate = config.config["rate"].as_f64().unwrap_or(0.5);
                let seed = config.config["seed"].as_u64();

                Ok(Box::new(Dropout::new(&config.name, rate, seed)))
            }
            _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
                "Deserialization not implemented for layer type: {layer_type}",
                layer_type = config.layer_type
            )))),
        }
    }

    /// Locate and parse every recorded parameter file for `model`.
    ///
    /// NOTE(review): this is currently validation-only — the parsed data is
    /// discarded (`_data`) and the downcast body is empty, so no weights are
    /// written back into the model. Loaded models therefore keep whatever
    /// initialization `create_layer_from_config` produced.
    fn load_parameters(
        &self,
        model: &Sequential,
        modeldir: &Path,
        parameter_files: &HashMap<String, String>,
    ) -> CoreResult<()> {
        for (i, layer) in model.layers().iter().enumerate() {
            let params = layer.parameters();
            for (j, param) in params.iter().enumerate() {
                // Names must match the positional scheme used on save.
                let param_name = format!("layer_{i}_param_{j}");
                if let Some(param_file) = parameter_files.get(&param_name) {
                    let param_path = modeldir.join(param_file);

                    if param_path.exists() {
                        let mut file = File::open(&param_path)?;
                        let mut json_str = String::new();
                        file.read_to_string(&mut json_str)?;

                        let load_data: serde_json::Value = serde_json::from_str(&json_str)?;
                        let shape: Vec<usize> = serde_json::from_value(load_data["shape"].clone())?;
                        let _data: Vec<f64> = serde_json::from_value(load_data["data"].clone())?;

                        // TODO(review): `shape`/`_data` are parsed but never
                        // applied; the empty downcast block below is where the
                        // copy-back would go.
                        if let Some(_array) =
                            param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
                        {
                        }
                    } else {
                        return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                            "Parameter file not found: {path}",
                            path = param_path.display()
                        ))));
                    }
                }
            }
        }

        Ok(())
    }

    /// Read optimizer state from `optimizerpath` and reconstruct an optimizer.
    ///
    /// Unknown or missing `"type"` values silently fall back to a default
    /// `SGD::new(0.01, None)` rather than erroring.
    fn load_optimizer(&self, optimizerpath: &Path) -> CoreResult<Box<dyn Optimizer>> {
        if !optimizerpath.exists() {
            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Optimizer file not found: {path}",
                path = optimizerpath.display()
            ))));
        }

        let mut file = File::open(optimizerpath)?;
        let mut json_str = String::new();
        file.read_to_string(&mut json_str)?;

        let optimizer_data: serde_json::Value = serde_json::from_str(&json_str)?;

        match optimizer_data["type"].as_str() {
            Some("SGD") => {
                let config = &optimizer_data["config"];
                let learningrate = config["learningrate"].as_f64().unwrap_or(0.01);
                let momentum = config["momentum"].as_f64();
                Ok(Box::new(SGD::new(learningrate, momentum)))
            }
            _ => {
                Ok(Box::new(SGD::new(0.01, None)))
            }
        }
    }
}
597
/// Exporter for the ONNX model format (currently a stub — see `export`).
pub struct OnnxExporter;
600
601impl OnnxExporter {
602 pub fn export(
604 &self,
605 _model: &Sequential,
606 path: impl AsRef<Path>,
607 _inputshape: &[usize],
608 ) -> CoreResult<()> {
609 File::create(path.as_ref())?;
614
615 Ok(())
616 }
617}
618
619#[allow(dead_code)]
621pub fn save_checkpoint(
622 model: &Sequential,
623 optimizer: &dyn Optimizer,
624 path: impl AsRef<Path>,
625 epoch: usize,
626 metrics: HashMap<String, f64>,
627) -> CoreResult<()> {
628 let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
630 fs::create_dir_all(checkpoint_dir)?;
631
632 let metadata = serde_json::json!({
634 "epoch": epoch,
635 "metrics": metrics,
636 "timestamp": chrono::Utc::now().to_rfc3339(),
637 });
638
639 let metadata_path = path.as_ref().with_extension("json");
641 let metadata_json = serde_json::to_string_pretty(&metadata)?;
642 let mut file = File::create(&metadata_path)?;
643 file.write_all(metadata_json.as_bytes())?;
644
645 let serializer = ModelSerializer::new(checkpoint_dir);
647
648 let model_name = "checkpoint";
650 let model_version = format!("epoch_{epoch}");
651 serializer.save_model(model, model_name, &model_version, Some(optimizer))?;
652
653 Ok(())
654}
655
/// A loaded checkpoint: `(model, optimizer, epoch, metrics)`.
pub type ModelCheckpoint = (Sequential, Box<dyn Optimizer>, usize, HashMap<String, f64>);
658
/// Load a training checkpoint previously written by `save_checkpoint`.
///
/// Reads the `<path>.json` metadata sidecar (epoch, metrics) and then the
/// model + optimizer stored under `<dir>/checkpoint/epoch_<epoch>/`.
///
/// # Errors
///
/// Returns an error when the metadata cannot be read or parsed, when the
/// serialized model cannot be loaded, or when the checkpoint contains no
/// optimizer state (checkpoints written by `save_checkpoint` always include
/// it, so its absence indicates a corrupt or foreign checkpoint).
#[cfg(feature = "serialization")]
#[allow(dead_code)]
pub fn load_checkpoint(path: impl AsRef<Path>) -> CoreResult<ModelCheckpoint> {
    let metadata_path = path.as_ref().with_extension("json");
    let mut file = File::open(&metadata_path)?;
    let mut metadata_json = String::new();
    file.read_to_string(&mut metadata_json)?;

    let metadata: serde_json::Value = serde_json::from_str(&metadata_json)?;

    // Missing/invalid fields degrade gracefully: epoch 0, empty metrics.
    let epoch = metadata["epoch"].as_u64().unwrap_or(0) as usize;
    let metrics: HashMap<String, f64> =
        serde_json::from_value(metadata["metrics"].clone()).unwrap_or_default();

    let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
    let serializer = ModelSerializer::new(checkpoint_dir);

    let model_name = "checkpoint";
    let model_version = format!("epoch_{epoch}");
    let (model, optimizer) = serializer.loadmodel(model_name, &model_version)?;

    // Surface a malformed checkpoint as an error instead of panicking
    // (previously `expect("Operation failed")` aborted the caller).
    let optimizer = optimizer.ok_or_else(|| {
        CoreError::InvalidArgument(ErrorContext::new(
            "Checkpoint is missing optimizer state".to_string(),
        ))
    })?;

    Ok((model, optimizer, epoch, metrics))
}
687
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol;
    use crate::array_protocol::grad::SGD;
    use crate::array_protocol::ml_ops::ActivationFunc;
    use crate::array_protocol::neural::{Linear, Sequential};
    use tempfile::tempdir;

    /// Round-trips a two-layer model through save_model/loadmodel and checks
    /// the layer count and optimizer presence. Skips (with a message) rather
    /// than failing when the environment can't provide a temp dir or save
    /// fails.
    #[test]
    fn test_model_serializer() {
        // Register array-protocol implementations before any model work.
        array_protocol::init();

        let temp_dir = match tempdir() {
            Ok(dir) => dir,
            Err(e) => {
                println!("Skipping test_model_serializer (temp dir creation failed): {e}");
                return;
            }
        };

        // Two-layer MLP: 10 -> 5 (ReLU) -> 2.
        let mut model = Sequential::new("test_model", Vec::new());

        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            10,
            5,
            true,
            Some(ActivationFunc::ReLU),
        )));

        model.add_layer(Box::new(Linear::new_random("fc2", 5, 2, true, None)));

        let optimizer = SGD::new(0.01, Some(0.9));

        let serializer = ModelSerializer::new(temp_dir.path());

        let model_path = serializer.save_model(&model, "test_model", "v1", Some(&optimizer));
        if model_path.is_err() {
            println!("Save model failed: {:?}", model_path.err());
            return;
        }

        let (loadedmodel, loaded_optimizer) = serializer
            .loadmodel("test_model", "v1")
            .expect("Operation failed");

        // Architecture (layer count) and optimizer presence survive the
        // round trip; weights are not asserted (see load_parameters).
        assert_eq!(loadedmodel.layers().len(), 2);
        assert!(loaded_optimizer.is_some());
    }

    /// Round-trips a checkpoint (model + optimizer + epoch + metrics) through
    /// save_checkpoint/load_checkpoint. Skips rather than failing on
    /// environment errors.
    #[test]
    fn test_save_load_checkpoint() {
        // Register array-protocol implementations before any model work.
        array_protocol::init();

        let temp_dir = match tempdir() {
            Ok(dir) => dir,
            Err(e) => {
                println!("Skipping test_save_load_checkpoint (temp dir creation failed): {e}");
                return;
            }
        };

        // Single-layer model is enough to exercise the checkpoint path.
        let mut model = Sequential::new("test_model", Vec::new());

        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            10,
            5,
            true,
            Some(ActivationFunc::ReLU),
        )));

        let optimizer = SGD::new(0.01, Some(0.9));

        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.1);
        metrics.insert("accuracy".to_string(), 0.9);

        let checkpoint_path = temp_dir.path().join("checkpoint");
        let result = save_checkpoint(&model, &optimizer, &checkpoint_path, 10, metrics.clone());
        if let Err(e) = result {
            println!("Skipping test_save_load_checkpoint (save failed): {e}");
            return;
        }

        let result = load_checkpoint(&checkpoint_path);
        if let Err(e) = result {
            println!("Skipping test_save_load_checkpoint (load failed): {e}");
            return;
        }

        let (loadedmodel, loaded_optimizer, loaded_epoch, loaded_metrics) =
            result.expect("Operation failed");

        // Epoch and metrics must round-trip exactly; layer count must match.
        assert_eq!(loadedmodel.layers().len(), 1);
        assert_eq!(loaded_epoch, 10);
        assert_eq!(loaded_metrics.get("loss"), metrics.get("loss"));
        assert_eq!(loaded_metrics.get("accuracy"), metrics.get("accuracy"));
    }
}
806}