1use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::path::Path;
12
13#[cfg(feature = "yaml")]
14use serde_yaml;
15
16use super::core::{EnhancedModelMetadata, PlatformInfo, SerializableModel};
17use super::models::*;
18
/// Supported serialization targets for exporting a trained model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExportFormat {
    // Pretty-printed JSON text.
    Json,
    // JSON compressed with gzip (flate2).
    JsonGz,
    // Raw binary encoding (not currently implemented by KMeansModel).
    Binary,
    // Comma-separated values; for KMeans this is the centroid table only.
    Csv,
    // YAML text; requires the crate's "yaml" feature at compile time.
    Yaml,
    // XML markup (not currently implemented by KMeansModel).
    Xml,
    // HDF5 container (not currently implemented by KMeansModel).
    Hdf5,
    // Caller-defined format identified by a free-form name.
    Custom(String),
}
39
/// Rich export interface for trained clustering models: serialization with
/// optional metadata, file output, human-readable summaries, and conversion
/// to other ML libraries' on-disk conventions.
pub trait AdvancedExport {
    /// Serialize the model to `format`, embedding `metadata` when given.
    ///
    /// Returns the encoded bytes, or an error for unsupported formats or
    /// serialization failures.
    fn export_with_metadata(
        &self,
        format: ExportFormat,
        metadata: Option<ModelMetadata>,
    ) -> Result<Vec<u8>>;

    /// Export to `path`, inferring the format from the file extension.
    fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()>;

    /// Produce a short human-readable summary (JSON text) of the model.
    fn export_summary(&self) -> Result<String>;

    /// Convert the model into a JSON structure matching `target_library`'s
    /// conventions (e.g. "sklearn", "tensorflow", "pytorch").
    fn export_compatible(&self, target_library: &str) -> Result<Value>;

    /// Check internal invariants before export; error if the model is not
    /// in an exportable state.
    fn validate_for_export(&self) -> Result<()>;
}
61
/// Full metadata bundle attached to an exported model.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelMetadata {
    // Identity: name, version, authorship, timestamps.
    pub model_info: ModelInfo,
    // Algorithm name and the hyperparameters it was trained with.
    pub algorithm_config: AlgorithmConfig,
    // Resource usage and quality metrics recorded during training.
    pub performance_metrics: PerformanceMetrics,
    // Shape and per-feature statistics of the training data.
    pub data_characteristics: ModelDataCharacteristics,
    // Options controlling how the export itself is produced.
    pub export_settings: ExportSettings,
}
76
/// Identifying information for an exported model.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelInfo {
    // Human-readable model name.
    pub name: String,
    // Model version string (semver by convention; not enforced here).
    pub version: String,
    // Creation timestamp; utils::create_default_metadata fills RFC 3339.
    pub created_at: String,
    // Optional author attribution.
    pub author: Option<String>,
    // Optional free-form description.
    pub description: Option<String>,
}
91
/// Training configuration captured at export time.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AlgorithmConfig {
    // Algorithm name, e.g. "kmeans".
    pub algorithm: String,
    // Arbitrary hyperparameters as JSON values, keyed by name.
    pub hyperparameters: HashMap<String, Value>,
    // Names of preprocessing steps applied before training, in order.
    pub preprocessing: Vec<String>,
    // RNG seed used for training, if the run was seeded.
    pub random_seed: Option<u64>,
    // Named convergence thresholds (e.g. tolerance), if recorded.
    pub convergence_criteria: Option<HashMap<String, f64>>,
}
106
/// Resource usage and quality measurements from the training run.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PerformanceMetrics {
    // Wall-clock training duration, in seconds.
    pub training_time_seconds: f64,
    // Peak memory consumption during training, in megabytes.
    pub peak_memory_mb: f64,
    // CPU utilization during training (unit unspecified here — presumably
    // a fraction or percentage; confirm against the producer).
    pub cpu_utilization: f64,
    // Named clustering-quality scores (e.g. silhouette), keyed by metric name.
    pub quality_metrics: HashMap<String, f64>,
    // Convergence details, when the algorithm tracks them.
    pub convergence_info: Option<ConvergenceInfo>,
}
121
/// Outcome of the iterative optimization loop.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ConvergenceInfo {
    // Whether the run met its convergence criterion before hitting the
    // iteration limit.
    pub converged: bool,
    // Number of iterations actually performed.
    pub iterations: usize,
    // Objective function value at the final iteration.
    pub final_objective: f64,
    // Tolerance level actually reached at termination.
    pub tolerance_achieved: f64,
}
134
/// Shape and descriptive statistics of the dataset the model was fit on.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelDataCharacteristics {
    // Number of training samples (rows).
    pub n_samples: usize,
    // Number of features (columns).
    pub n_features: usize,
    // Optional feature names, parallel to the feature axis.
    pub feature_names: Option<Vec<String>>,
    // Optional per-feature type labels (free-form strings).
    pub feature_types: Option<Vec<String>>,
    // Optional per-feature summary statistics, keyed by feature name.
    pub feature_statistics: Option<HashMap<String, FeatureStats>>,
}
149
/// Summary statistics for a single feature column.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FeatureStats {
    // Arithmetic mean of the observed values.
    pub mean: f64,
    // Standard deviation of the observed values.
    pub std: f64,
    // Minimum observed value.
    pub min: f64,
    // Maximum observed value.
    pub max: f64,
    // Count of missing entries for this feature.
    pub missing_count: usize,
}
164
/// Options controlling how an export is produced.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExportSettings {
    // Include the model's raw parameter data (e.g. centroids) in the export.
    pub include_raw_data: bool,
    // Include the original training data (off by default — can be large).
    pub include_training_data: bool,
    // Compression level for compressed formats; None = format default.
    pub compression_level: Option<u8>,
    // Decimal digits to keep when formatting floats; None = full precision.
    pub float_precision: Option<usize>,
    // Free-form, format-specific options keyed by name.
    pub custom_options: HashMap<String, Value>,
}
179
180impl Default for ExportSettings {
181 fn default() -> Self {
182 Self {
183 include_raw_data: true,
184 include_training_data: false,
185 compression_level: None,
186 float_precision: Some(6),
187 custom_options: HashMap::new(),
188 }
189 }
190}
191
impl AdvancedExport for KMeansModel {
    /// Serialize this model (plus optional metadata, a format-version tag,
    /// and an RFC 3339 export timestamp) into the requested format.
    ///
    /// Supported here: Json, JsonGz, Yaml (feature-gated), Csv. All other
    /// variants return an `InvalidInput` error.
    fn export_with_metadata(
        &self,
        format: ExportFormat,
        metadata: Option<ModelMetadata>,
    ) -> Result<Vec<u8>> {
        // Wrap the model in an envelope so metadata and versioning travel
        // alongside the parameters.
        let export_data = KMeansExportData {
            model: self.clone(),
            metadata,
            format_version: "1.0".to_string(),
            export_timestamp: chrono::Utc::now().to_rfc3339(),
        };

        match format {
            ExportFormat::Json => {
                let json = serde_json::to_string_pretty(&export_data)
                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
                Ok(json.into_bytes())
            }
            ExportFormat::JsonGz => {
                // Compact (non-pretty) JSON before compressing — whitespace
                // would only inflate the payload.
                let json = serde_json::to_string(&export_data)
                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;

                use flate2::write::GzEncoder;
                use flate2::Compression;
                use std::io::Write;

                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
                encoder
                    .write_all(json.as_bytes())
                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
                // finish() flushes the gzip trailer and yields the buffer.
                encoder
                    .finish()
                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
            }
            #[cfg(feature = "yaml")]
            ExportFormat::Yaml => {
                let yaml = serde_yaml::to_string(&export_data)
                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
                Ok(yaml.into_bytes())
            }
            // Without the feature, Yaml is still a valid request — reply with
            // a clear error instead of failing to compile the caller.
            #[cfg(not(feature = "yaml"))]
            ExportFormat::Yaml => Err(ClusteringError::InvalidInput(
                "YAML support not enabled. Enable the 'yaml' feature".to_string(),
            )),
            // NOTE: CSV export emits only the centroid table and ignores the
            // metadata envelope built above.
            ExportFormat::Csv => self.export_csv(),
            _ => Err(ClusteringError::InvalidInput(format!(
                "Unsupported export format: {:?}",
                format
            ))),
        }
    }

    /// Write the model to `path`, choosing the format from the extension
    /// (see `detect_format_from_extension`). No metadata is attached.
    fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let path = path.as_ref();
        let format = detect_format_from_extension(path)?;
        let data = self.export_with_metadata(format, None)?;

        std::fs::write(path, data)
            .map_err(|e| ClusteringError::InvalidInput(format!("Failed to write file: {}", e)))
    }

    /// Return a pretty-printed JSON summary of the key model attributes.
    fn export_summary(&self) -> Result<String> {
        let summary = KMeansSummary {
            algorithm: "K-Means".to_string(),
            n_clusters: self.n_clusters,
            n_features: self.centroids.ncols(),
            n_iterations: self.n_iter,
            inertia: self.inertia,
            has_labels: self.labels.is_some(),
        };

        serde_json::to_string_pretty(&summary)
            .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
    }

    /// Convert the model to another library's JSON convention. Matching is
    /// case-insensitive; unknown libraries produce an `InvalidInput` error.
    fn export_compatible(&self, target_library: &str) -> Result<Value> {
        match target_library.to_lowercase().as_str() {
            "sklearn" | "scikit-learn" => self.to_sklearn_format(),
            "tensorflow" | "tf" => self.to_tensorflow_format(),
            "pytorch" => self.to_pytorch_format(),
            _ => Err(ClusteringError::InvalidInput(format!(
                "Unsupported target library: {}",
                target_library
            ))),
        }
    }

    /// Verify the model is exportable: non-empty centroids, a positive
    /// cluster count, and a centroid matrix with one row per cluster.
    fn validate_for_export(&self) -> Result<()> {
        if self.centroids.is_empty() {
            return Err(ClusteringError::InvalidInput(
                "Cannot export model with empty centroids".to_string(),
            ));
        }

        if self.n_clusters == 0 {
            return Err(ClusteringError::InvalidInput(
                "Cannot export model with zero clusters".to_string(),
            ));
        }

        if self.centroids.nrows() != self.n_clusters {
            return Err(ClusteringError::InvalidInput(
                "Centroids shape inconsistent with n_clusters".to_string(),
            ));
        }

        Ok(())
    }
}
303
/// Envelope serialized by `export_with_metadata`: the model itself plus
/// optional metadata, a format-version tag, and an export timestamp.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansExportData {
    // Cloned snapshot of the model being exported.
    model: KMeansModel,
    // Caller-supplied metadata, if any.
    metadata: Option<ModelMetadata>,
    // Version of this envelope layout (currently "1.0").
    format_version: String,
    // RFC 3339 timestamp taken at export time.
    export_timestamp: String,
}
312
/// Compact report emitted by `export_summary` — key model attributes only.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansSummary {
    // Fixed algorithm label ("K-Means").
    algorithm: String,
    // Number of clusters in the fitted model.
    n_clusters: usize,
    // Feature dimensionality (centroid columns).
    n_features: usize,
    // Iterations the training loop ran.
    n_iterations: usize,
    // Final within-cluster sum of squared distances.
    inertia: f64,
    // Whether per-sample labels were retained on the model.
    has_labels: bool,
}
323
324impl KMeansModel {
325 fn export_csv(&self) -> Result<Vec<u8>> {
327 let mut csv_content = String::new();
328
329 csv_content.push_str("cluster_id");
331 for i in 0..self.centroids.ncols() {
332 csv_content.push_str(&format!(",feature_{}", i));
333 }
334 csv_content.push('\n');
335
336 for (cluster_id, centroid) in self.centroids.rows().into_iter().enumerate() {
338 csv_content.push_str(&cluster_id.to_string());
339 for value in centroid {
340 csv_content.push_str(&format!(",{:.6}", value));
341 }
342 csv_content.push('\n');
343 }
344
345 Ok(csv_content.into_bytes())
346 }
347
348 fn to_sklearn_format(&self) -> Result<Value> {
350 use serde_json::json;
351
352 Ok(json!({
353 "cluster_centers_": self.centroids.as_slice().unwrap(),
354 "labels_": self.labels.as_ref().map(|l| l.as_slice().unwrap()),
355 "inertia_": self.inertia,
356 "n_iter_": self.n_iter,
357 "n_clusters": self.n_clusters,
358 "_sklearn_version": "1.0.0"
359 }))
360 }
361
362 fn to_tensorflow_format(&self) -> Result<Value> {
364 use serde_json::json;
365
366 Ok(json!({
367 "centroids": {
368 "data": self.centroids.as_slice().unwrap(),
369 "shape": [self.centroids.nrows(), self.centroids.ncols()],
370 "dtype": "float64"
371 },
372 "metadata": {
373 "n_clusters": self.n_clusters,
374 "inertia": self.inertia,
375 "iterations": self.n_iter
376 }
377 }))
378 }
379
380 fn to_pytorch_format(&self) -> Result<Value> {
382 use serde_json::json;
383
384 Ok(json!({
385 "state_dict": {
386 "centroids": self.centroids.as_slice().unwrap()
387 },
388 "hyperparameters": {
389 "n_clusters": self.n_clusters
390 },
391 "metrics": {
392 "inertia": self.inertia,
393 "n_iter": self.n_iter
394 }
395 }))
396 }
397}
398
399fn detect_format_from_extension<P: AsRef<Path>>(path: P) -> Result<ExportFormat> {
401 let path = path.as_ref();
402 let extension = path
403 .extension()
404 .and_then(|ext| ext.to_str())
405 .unwrap_or("")
406 .to_lowercase();
407
408 match extension.as_str() {
409 "json" => Ok(ExportFormat::Json),
410 "gz" | "json.gz" => Ok(ExportFormat::JsonGz),
411 "yaml" | "yml" => Ok(ExportFormat::Yaml),
412 "csv" => Ok(ExportFormat::Csv),
413 "xml" => Ok(ExportFormat::Xml),
414 "h5" | "hdf5" => Ok(ExportFormat::Hdf5),
415 _ => Err(ClusteringError::InvalidInput(format!(
416 "Unknown file extension: {}",
417 extension
418 ))),
419 }
420}
421
422pub mod utils {
424 use super::*;
425
426 pub fn create_default_metadata(algorithm_name: &str) -> ModelMetadata {
428 ModelMetadata {
429 model_info: ModelInfo {
430 name: format!("{}_model", algorithm_name),
431 version: "1.0.0".to_string(),
432 created_at: chrono::Utc::now().to_rfc3339(),
433 author: None,
434 description: None,
435 },
436 algorithm_config: AlgorithmConfig {
437 algorithm: algorithm_name.to_string(),
438 hyperparameters: HashMap::new(),
439 preprocessing: Vec::new(),
440 random_seed: None,
441 convergence_criteria: None,
442 },
443 performance_metrics: PerformanceMetrics {
444 training_time_seconds: 0.0,
445 peak_memory_mb: 0.0,
446 cpu_utilization: 0.0,
447 quality_metrics: HashMap::new(),
448 convergence_info: None,
449 },
450 data_characteristics: ModelDataCharacteristics {
451 n_samples: 0,
452 n_features: 0,
453 feature_names: None,
454 feature_types: None,
455 feature_statistics: None,
456 },
457 export_settings: ExportSettings::default(),
458 }
459 }
460
461 pub fn compress_data(data: &[u8]) -> Result<Vec<u8>> {
463 use flate2::write::GzEncoder;
464 use flate2::Compression;
465 use std::io::Write;
466
467 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
468 encoder
469 .write_all(data)
470 .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
471 encoder
472 .finish()
473 .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
474 }
475
476 pub fn decompress_data(compressed: &[u8]) -> Result<Vec<u8>> {
478 use flate2::read::GzDecoder;
479 use std::io::Read;
480
481 let mut decoder = GzDecoder::new(compressed);
482 let mut decompressed = Vec::new();
483 decoder
484 .read_to_end(&mut decompressed)
485 .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
486 Ok(decompressed)
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Summary export should be JSON containing the algorithm label and the
    // cluster count.
    #[test]
    fn test_kmeans_export_summary() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let model = KMeansModel::new(centroids, 2, 10, 0.5, None);

        let summary = model.export_summary().unwrap();
        assert!(summary.contains("K-Means"));
        assert!(summary.contains("\"n_clusters\": 2"));
    }

    // Extension-based format detection for the common cases.
    #[test]
    fn test_format_detection() {
        assert_eq!(
            detect_format_from_extension("model.json").unwrap(),
            ExportFormat::Json
        );
        assert_eq!(
            detect_format_from_extension("model.yaml").unwrap(),
            ExportFormat::Yaml
        );
        assert_eq!(
            detect_format_from_extension("model.csv").unwrap(),
            ExportFormat::Csv
        );
    }

    // sklearn conversion should expose the expected top-level keys.
    #[test]
    fn test_sklearn_compatibility() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let model = KMeansModel::new(centroids, 2, 10, 0.5, None);

        let sklearn_format = model.export_compatible("sklearn").unwrap();
        assert!(sklearn_format.get("cluster_centers_").is_some());
        assert!(sklearn_format.get("_sklearn_version").is_some());
    }
}