scirs2_io/ml_framework/converters/
jax.rs1#![allow(dead_code)]
3
4use crate::error::{IoError, Result};
5use crate::ml_framework::converters::MLFrameworkConverter;
6use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
7use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::fs::File;
9use std::path::Path;
10
/// Stateless converter that saves and loads models and tensors in this crate's JSON-based
/// "jax" format (see its `MLFrameworkConverter` implementation).
#[derive(Debug, Clone, Copy, Default)]
pub struct JAXConverter;
13
14impl MLFrameworkConverter for JAXConverter {
15 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
16 let jax_model = serde_json::json!({
18 "format": "jax",
19 "version": "0.4.0",
20 "pytree": {
21 "params": model.weights.iter().map(|(name, tensor)| {
22 (name.clone(), serde_json::json!({
23 "shape": tensor.metadata.shape,
24 "dtype": format!("{:?}", tensor.metadata.dtype),
25 "data": tensor.data.as_slice().unwrap().to_vec()
26 }))
27 }).collect::<serde_json::Map<String, serde_json::Value>>(),
28 "config": model.config
29 },
30 "metadata": model.metadata
31 });
32
33 let file = File::create(path).map_err(IoError::Io)?;
34 serde_json::to_writer_pretty(file, &jax_model)
35 .map_err(|e| IoError::SerializationError(e.to_string()))
36 }
37
38 fn load_model(&self, path: &Path) -> Result<MLModel> {
39 let file = File::open(path).map_err(IoError::Io)?;
40 let jax_model: serde_json::Value = serde_json::from_reader(file)
41 .map_err(|e| IoError::SerializationError(e.to_string()))?;
42
43 let mut model = MLModel::new(MLFramework::JAX);
44
45 if let Some(metadata) = jax_model.get("metadata") {
46 model.metadata = serde_json::from_value(metadata.clone())
47 .map_err(|e| IoError::SerializationError(e.to_string()))?;
48 }
49
50 if let Some(pytree) = jax_model.get("pytree") {
51 if let Some(params) = pytree.get("params").and_then(|v| v.as_object()) {
52 for (name, param_data) in params {
53 let shape: Vec<usize> = serde_json::from_value(param_data["shape"].clone())
54 .map_err(|e| IoError::SerializationError(e.to_string()))?;
55
56 let data: Vec<f32> = serde_json::from_value(param_data["data"].clone())
57 .map_err(|e| IoError::SerializationError(e.to_string()))?;
58
59 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
60 .map_err(|e| IoError::Other(e.to_string()))?;
61
62 model
63 .weights
64 .insert(name.clone(), MLTensor::new(array, Some(name.clone())));
65 }
66 }
67
68 if let Some(config) = pytree.get("config") {
69 model.config = serde_json::from_value(config.clone())
70 .map_err(|e| IoError::SerializationError(e.to_string()))?;
71 }
72 }
73
74 Ok(model)
75 }
76
77 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
78 let tensor_data = serde_json::json!({
79 "jax_array": {
80 "shape": tensor.metadata.shape,
81 "dtype": format!("{:?}", tensor.metadata.dtype),
82 "data": tensor.data.as_slice().unwrap().to_vec(),
83 "weak_type": false
84 }
85 });
86
87 let file = File::create(path).map_err(IoError::Io)?;
88 serde_json::to_writer_pretty(file, &tensor_data)
89 .map_err(|e| IoError::SerializationError(e.to_string()))
90 }
91
92 fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
93 let file = File::open(path).map_err(IoError::Io)?;
94 let tensor_data: serde_json::Value = serde_json::from_reader(file)
95 .map_err(|e| IoError::SerializationError(e.to_string()))?;
96
97 if let Some(jax_array) = tensor_data.get("jax_array") {
98 let shape: Vec<usize> = serde_json::from_value(jax_array["shape"].clone())
99 .map_err(|e| IoError::SerializationError(e.to_string()))?;
100
101 let data: Vec<f32> = serde_json::from_value(jax_array["data"].clone())
102 .map_err(|e| IoError::SerializationError(e.to_string()))?;
103
104 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
105 .map_err(|e| IoError::Other(e.to_string()))?;
106
107 return Ok(MLTensor::new(array, None));
108 }
109
110 Err(IoError::Other("Invalid JAX tensor format".to_string()))
111 }
112}