scirs2_io/ml_framework/converters/
jax.rs

1//! JAX format converter
2#![allow(dead_code)]
3
4use crate::error::{IoError, Result};
5use crate::ml_framework::converters::MLFrameworkConverter;
6use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
7use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::fs::File;
9use std::path::Path;
10
/// Converter for the JSON-based JAX representation used by this module.
///
/// The type carries no state; all behaviour lives in its
/// [`MLFrameworkConverter`] implementation, which reads and writes JSON
/// documents tagged with `"format": "jax"` (models) or wrapped in a
/// `"jax_array"` object (single tensors).
#[derive(Debug, Clone, Copy, Default)]
pub struct JAXConverter;
13
14impl MLFrameworkConverter for JAXConverter {
15    fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
16        // JAX uses a simpler pickle-like format
17        let jax_model = serde_json::json!({
18            "format": "jax",
19            "version": "0.4.0",
20            "pytree": {
21                "params": model.weights.iter().map(|(name, tensor)| {
22                    (name.clone(), serde_json::json!({
23                        "shape": tensor.metadata.shape,
24                        "dtype": format!("{:?}", tensor.metadata.dtype),
25                        "data": tensor.data.as_slice().unwrap().to_vec()
26                    }))
27                }).collect::<serde_json::Map<String, serde_json::Value>>(),
28                "config": model.config
29            },
30            "metadata": model.metadata
31        });
32
33        let file = File::create(path).map_err(IoError::Io)?;
34        serde_json::to_writer_pretty(file, &jax_model)
35            .map_err(|e| IoError::SerializationError(e.to_string()))
36    }
37
38    fn load_model(&self, path: &Path) -> Result<MLModel> {
39        let file = File::open(path).map_err(IoError::Io)?;
40        let jax_model: serde_json::Value = serde_json::from_reader(file)
41            .map_err(|e| IoError::SerializationError(e.to_string()))?;
42
43        let mut model = MLModel::new(MLFramework::JAX);
44
45        if let Some(metadata) = jax_model.get("metadata") {
46            model.metadata = serde_json::from_value(metadata.clone())
47                .map_err(|e| IoError::SerializationError(e.to_string()))?;
48        }
49
50        if let Some(pytree) = jax_model.get("pytree") {
51            if let Some(params) = pytree.get("params").and_then(|v| v.as_object()) {
52                for (name, param_data) in params {
53                    let shape: Vec<usize> = serde_json::from_value(param_data["shape"].clone())
54                        .map_err(|e| IoError::SerializationError(e.to_string()))?;
55
56                    let data: Vec<f32> = serde_json::from_value(param_data["data"].clone())
57                        .map_err(|e| IoError::SerializationError(e.to_string()))?;
58
59                    let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
60                        .map_err(|e| IoError::Other(e.to_string()))?;
61
62                    model
63                        .weights
64                        .insert(name.clone(), MLTensor::new(array, Some(name.clone())));
65                }
66            }
67
68            if let Some(config) = pytree.get("config") {
69                model.config = serde_json::from_value(config.clone())
70                    .map_err(|e| IoError::SerializationError(e.to_string()))?;
71            }
72        }
73
74        Ok(model)
75    }
76
77    fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
78        let tensor_data = serde_json::json!({
79            "jax_array": {
80                "shape": tensor.metadata.shape,
81                "dtype": format!("{:?}", tensor.metadata.dtype),
82                "data": tensor.data.as_slice().unwrap().to_vec(),
83                "weak_type": false
84            }
85        });
86
87        let file = File::create(path).map_err(IoError::Io)?;
88        serde_json::to_writer_pretty(file, &tensor_data)
89            .map_err(|e| IoError::SerializationError(e.to_string()))
90    }
91
92    fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
93        let file = File::open(path).map_err(IoError::Io)?;
94        let tensor_data: serde_json::Value = serde_json::from_reader(file)
95            .map_err(|e| IoError::SerializationError(e.to_string()))?;
96
97        if let Some(jax_array) = tensor_data.get("jax_array") {
98            let shape: Vec<usize> = serde_json::from_value(jax_array["shape"].clone())
99                .map_err(|e| IoError::SerializationError(e.to_string()))?;
100
101            let data: Vec<f32> = serde_json::from_value(jax_array["data"].clone())
102                .map_err(|e| IoError::SerializationError(e.to_string()))?;
103
104            let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
105                .map_err(|e| IoError::Other(e.to_string()))?;
106
107            return Ok(MLTensor::new(array, None));
108        }
109
110        Err(IoError::Other("Invalid JAX tensor format".to_string()))
111    }
112}