scirs2_cluster/serialization/
compatibility.rs1use crate::error::{ClusteringError, Result};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11use super::models::*;
12
13pub fn create_sklearn_param_grid(
15 algorithm: &str,
16 param_ranges: HashMap<String, Vec<Value>>,
17) -> Result<HashMap<String, Vec<Value>>> {
18 match algorithm {
19 "kmeans" => {
20 let mut grid = HashMap::new();
21 if let Some(n_clusters) = param_ranges.get("n_clusters") {
22 grid.insert("n_clusters".to_string(), n_clusters.clone());
23 }
24 if let Some(init) = param_ranges.get("init") {
25 grid.insert("init".to_string(), init.clone());
26 }
27 Ok(grid)
28 }
29 "dbscan" => {
30 let mut grid = HashMap::new();
31 if let Some(eps) = param_ranges.get("eps") {
32 grid.insert("eps".to_string(), eps.clone());
33 }
34 if let Some(min_samples) = param_ranges.get("min_samples") {
35 grid.insert("min_samples".to_string(), min_samples.clone());
36 }
37 Ok(grid)
38 }
39 _ => Err(ClusteringError::InvalidInput(format!(
40 "Unsupported algorithm for sklearn parameter grid: {}",
41 algorithm
42 ))),
43 }
44}
45
46pub fn from_joblib_format(data: Vec<u8>) -> Result<Value> {
48 serde_json::from_slice(&data)
51 .map_err(|e| ClusteringError::InvalidInput(format!("Failed to parse joblib format: {}", e)))
52}
53
54pub fn from_numpy_format(data: Vec<u8>) -> Result<scirs2_core::ndarray::Array2<f64>> {
56 let json_data: Value = serde_json::from_slice(&data).map_err(|e| {
59 ClusteringError::InvalidInput(format!("Failed to parse numpy format: {}", e))
60 })?;
61
62 if let Value::Array(array) = json_data {
63 let mut flat_data = Vec::new();
64 let mut ncols = 0;
65
66 if let Some(Value::Array(first_row)) = array.first() {
67 ncols = first_row.len();
68 }
69 let nrows = array.len();
70
71 for row in array {
72 if let Value::Array(row_values) = row {
73 for val in row_values {
74 if let Value::Number(num) = val {
75 flat_data.push(num.as_f64().unwrap_or(0.0));
76 }
77 }
78 }
79 }
80
81 scirs2_core::ndarray::Array2::from_shape_vec((nrows, ncols), flat_data).map_err(|e| {
82 ClusteringError::InvalidInput(format!("Failed to create array from numpy data: {}", e))
83 })
84 } else {
85 Err(ClusteringError::InvalidInput(
86 "Invalid numpy format".to_string(),
87 ))
88 }
89}
90
91pub fn from_sklearn_format(data: Value) -> Result<Value> {
93 Ok(data)
95}
96
97pub fn generate_sklearn_model_summary(model_type: &str, model_data: &Value) -> Result<String> {
99 match model_type {
100 "KMeans" => {
101 let summary = serde_json::json!({
102 "model_type": "KMeans",
103 "n_clusters": model_data.get("n_clusters").unwrap_or(&Value::Null),
104 "inertia": model_data.get("inertia_").unwrap_or(&Value::Null),
105 "n_iter": model_data.get("n_iter_").unwrap_or(&Value::Null)
106 });
107 Ok(serde_json::to_string_pretty(&summary)?)
108 }
109 "DBSCAN" => {
110 let summary = serde_json::json!({
111 "model_type": "DBSCAN",
112 "eps": model_data.get("eps").unwrap_or(&Value::Null),
113 "min_samples": model_data.get("min_samples").unwrap_or(&Value::Null)
114 });
115 Ok(serde_json::to_string_pretty(&summary)?)
116 }
117 _ => Err(ClusteringError::InvalidInput(format!(
118 "Unsupported sklearn model type: {}",
119 model_type
120 ))),
121 }
122}
123
124pub fn to_arrow_schema<T: ClusteringModel>(model: &T) -> Result<Value> {
126 let schema = serde_json::json!({
127 "type": "struct",
128 "fields": [
129 {
130 "name": "cluster_id",
131 "type": {
132 "name": "int",
133 "bitWidth": 32
134 },
135 "nullable": false
136 },
137 {
138 "name": "features",
139 "type": {
140 "name": "list",
141 "valueType": {
142 "name": "floatingpoint",
143 "precision": "DOUBLE"
144 }
145 },
146 "nullable": false
147 }
148 ]
149 });
150 Ok(schema)
151}
152
153pub fn to_huggingface_card<T: ClusteringModel>(model: &T) -> Result<String> {
155 let summary = model.summary()?;
156 let card = format!(
157 r#"
158---
159tags:
160- clustering
161- unsupervised-learning
162- scirs2-cluster
163library_name: scirs2-cluster
164model_summary: {}
165---
166
167# Clustering Model
168
169This is a clustering model trained using scirs2-cluster.
170
171## Model Details
172
173{}
174
175## Usage
176
177```rust
178use scirs2_cluster::serialization::SerializableModel;
179
180// Load the model
181let model = Model::load_from_file("model.json")?;
182
183// Use for prediction
184let predictions = model.predict(data.view())?;
185```
186"#,
187 serde_json::to_string_pretty(&summary)?,
188 serde_json::to_string_pretty(&summary)?
189 );
190
191 Ok(card)
192}
193
194pub fn to_joblib_format<T: ClusteringModel>(model: &T) -> Result<Vec<u8>> {
196 let summary = model.summary()?;
198 Ok(serde_json::to_vec(&summary)?)
199}
200
201pub fn to_mlflow_format<T: ClusteringModel>(model: &T) -> Result<Value> {
203 let summary = model.summary()?;
204 Ok(serde_json::json!({
205 "artifact_path": "model",
206 "flavors": {
207 "scirs2_cluster": {
208 "model_type": "clustering",
209 "scirs2_version": env!("CARGO_PKG_VERSION"),
210 "data": summary
211 }
212 },
213 "model_uuid": uuid::Uuid::new_v4().to_string(),
214 "run_id": "unknown",
215 "utc_time_created": chrono::Utc::now().to_rfc3339()
216 }))
217}
218
219pub fn to_numpy_format(data: &scirs2_core::ndarray::Array2<f64>) -> Result<Vec<u8>> {
221 let shape = data.shape();
224 let numpy_data = serde_json::json!({
225 "shape": shape,
226 "data": data.as_slice().unwrap_or(&[])
227 });
228 Ok(serde_json::to_vec(&numpy_data)?)
229}
230
231pub fn to_onnx_metadata<T: ClusteringModel>(model: &T) -> Result<Value> {
233 let summary = model.summary()?;
234 Ok(serde_json::json!({
235 "ir_version": 7,
236 "producer_name": "scirs2-cluster",
237 "producer_version": env!("CARGO_PKG_VERSION"),
238 "model_version": 1,
239 "doc_string": "Clustering model exported from scirs2-cluster",
240 "metadata_props": {
241 "model_summary": summary
242 }
243 }))
244}
245
246pub fn to_pandas_clustering_report<T: ClusteringModel>(model: &T) -> Result<Value> {
248 let summary = model.summary()?;
249 Ok(serde_json::json!({
250 "model_type": "clustering",
251 "n_clusters": model.n_clusters(),
252 "summary": summary,
253 "pandas_version": "1.0.0",
254 "created_at": chrono::Utc::now().to_rfc3339()
255 }))
256}
257
258pub fn to_pandas_format<T: ClusteringModel>(model: &T) -> Result<Value> {
260 to_pandas_clustering_report(model)
261}
262
263pub fn to_pickle_like_format<T: ClusteringModel>(model: &T) -> Result<Vec<u8>> {
265 let summary = model.summary()?;
267 Ok(serde_json::to_vec(&summary)?)
268}
269
270pub fn to_pytorch_checkpoint<T: ClusteringModel>(model: &T) -> Result<Value> {
272 let summary = model.summary()?;
273 Ok(serde_json::json!({
274 "model_state_dict": summary,
275 "optimizer_state_dict": {},
276 "epoch": 1,
277 "loss": 0.0,
278 "pytorch_version": "1.10.0",
279 "scirs2_cluster_version": env!("CARGO_PKG_VERSION")
280 }))
281}
282
283pub fn to_r_format<T: ClusteringModel>(model: &T) -> Result<Value> {
285 let summary = model.summary()?;
286 Ok(serde_json::json!({
287 "class": "clustering_model",
288 "data": summary,
289 "r_version": "4.0.0",
290 "created_by": "scirs2-cluster"
291 }))
292}
293
294pub fn to_scipy_dendrogram_format(
296 linkage_matrix: &scirs2_core::ndarray::Array2<f64>,
297) -> Result<Value> {
298 Ok(serde_json::json!({
299 "linkage": linkage_matrix.as_slice().unwrap_or(&[]),
300 "format": "scipy_dendrogram",
301 "shape": linkage_matrix.shape()
302 }))
303}
304
305pub fn to_scipy_linkage_format(
307 linkage_matrix: &scirs2_core::ndarray::Array2<f64>,
308) -> Result<Value> {
309 Ok(serde_json::json!({
310 "linkage_matrix": linkage_matrix.as_slice().unwrap_or(&[]),
311 "shape": linkage_matrix.shape(),
312 "method": "ward",
313 "metric": "euclidean"
314 }))
315}
316
317pub fn to_sklearn_clustering_result<T: ClusteringModel>(model: &T) -> Result<Value> {
319 let summary = model.summary()?;
320 Ok(serde_json::json!({
321 "labels_": [],
322 "n_clusters_": model.n_clusters(),
323 "model_summary": summary,
324 "_sklearn_version": "1.0.0"
325 }))
326}
327
328pub fn to_sklearn_format<T: ClusteringModel>(model: &T) -> Result<Value> {
330 to_sklearn_clustering_result(model)
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Recognised keys are forwarded verbatim; keys belonging to another
    /// algorithm are dropped.
    #[test]
    fn test_create_sklearn_param_grid() {
        let mut params = HashMap::new();
        params.insert(
            "n_clusters".to_string(),
            vec![serde_json::json!(2), serde_json::json!(3)],
        );
        params.insert("eps".to_string(), vec![serde_json::json!(0.5)]);

        let grid = create_sklearn_param_grid("kmeans", params).unwrap();
        assert_eq!(grid.len(), 1);
        assert_eq!(
            grid.get("n_clusters"),
            Some(&vec![serde_json::json!(2), serde_json::json!(3)])
        );
    }

    /// Unknown algorithm names are rejected rather than yielding an empty grid.
    #[test]
    fn test_create_sklearn_param_grid_rejects_unknown_algorithm() {
        assert!(create_sklearn_param_grid("spectral", HashMap::new()).is_err());
    }

    /// The payload carries both the shape and the row-major values.
    #[test]
    fn test_to_numpy_format() {
        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let bytes = to_numpy_format(&data).unwrap();

        let decoded: Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(decoded["shape"], serde_json::json!([2, 2]));
        assert_eq!(decoded["data"], serde_json::json!([1.0, 2.0, 3.0, 4.0]));
    }

    /// A nested array of rows decodes into the matching matrix.
    #[test]
    fn test_from_numpy_format_nested_rows() {
        let parsed = from_numpy_format(b"[[1.0, 2.0], [3.0, 4.0]]".to_vec()).unwrap();
        let expected = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(parsed, expected);
    }

    /// Malformed bytes and non-matrix JSON values are reported as errors.
    #[test]
    fn test_from_numpy_format_rejects_invalid_input() {
        assert!(from_numpy_format(b"not json".to_vec()).is_err());
        assert!(from_numpy_format(b"42".to_vec()).is_err());
    }

    /// The linkage export keeps the values, the shape and the fixed labels.
    #[test]
    fn test_to_scipy_linkage_format() {
        let linkage = Array2::from_shape_vec((1, 3), vec![0.0, 1.0, 0.5]).unwrap();
        let result = to_scipy_linkage_format(&linkage).unwrap();

        assert_eq!(result["linkage_matrix"], serde_json::json!([0.0, 1.0, 0.5]));
        assert_eq!(result["shape"], serde_json::json!([1, 3]));
        assert_eq!(result["method"], serde_json::json!("ward"));
        assert_eq!(result["metric"], serde_json::json!("euclidean"));
    }
}