scirs2_cluster/serialization/
export.rs

1//! Advanced export functionality for clustering models
2//!
3//! This module provides sophisticated export capabilities including
4//! multiple formats, metadata enrichment, and cross-platform compatibility.
5
6use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::path::Path;
12
13#[cfg(feature = "yaml")]
14use serde_yaml;
15
16use super::core::{EnhancedModelMetadata, PlatformInfo, SerializableModel};
17use super::models::*;
18
/// Export formats supported by the serialization system.
///
/// Not every variant is implemented by every exporter: `Yaml` requires the
/// `yaml` cargo feature, and formats a model does not support produce an
/// error at export time rather than at construction.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExportFormat {
    /// JSON format with full metadata
    Json,
    /// Gzip-compressed JSON format
    JsonGz,
    /// Binary format (MessagePack)
    Binary,
    /// CSV format (for simple models; typically data only, no metadata)
    Csv,
    /// YAML format (requires the `yaml` feature at compile time)
    Yaml,
    /// XML format
    Xml,
    /// HDF5 format (for large datasets)
    Hdf5,
    /// Custom format with a user-defined structure, identified by name
    Custom(String),
}
39
/// Trait for advanced export capabilities.
///
/// Implementors provide serialization into multiple [`ExportFormat`]s,
/// quick-inspection summaries, and conversion into JSON structures laid
/// out to match other libraries' conventions.
pub trait AdvancedExport {
    /// Export the model to the specified format, optionally embedding the
    /// supplied [`ModelMetadata`] alongside the model data.
    fn export_with_metadata(
        &self,
        format: ExportFormat,
        metadata: Option<ModelMetadata>,
    ) -> Result<Vec<u8>>;

    /// Export to a file, inferring the format from the file extension.
    fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()>;

    /// Export a short, human-readable model summary for quick inspection.
    fn export_summary(&self) -> Result<String>;

    /// Export the model as a JSON value compatible with another library
    /// (e.g. "sklearn", "tensorflow", "pytorch").
    fn export_compatible(&self, target_library: &str) -> Result<Value>;

    /// Check internal consistency before export; returns an error for
    /// models that would produce a malformed artifact.
    fn validate_for_export(&self) -> Result<()>;
}
61
/// Comprehensive model metadata embedded in exports.
///
/// Aggregates identifying info, algorithm configuration, training-time
/// performance numbers, input-data characteristics, and the export options
/// that were in effect.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelMetadata {
    /// Model name and version
    pub model_info: ModelInfo,
    /// Algorithm configuration
    pub algorithm_config: AlgorithmConfig,
    /// Performance metrics
    pub performance_metrics: PerformanceMetrics,
    /// Data characteristics
    pub data_characteristics: ModelDataCharacteristics,
    /// Export settings
    pub export_settings: ExportSettings,
}
76
/// Identifying information about a model.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelInfo {
    /// Model name
    pub name: String,
    /// Model version
    pub version: String,
    /// Creation timestamp (RFC 3339 string, see `create_default_metadata`)
    pub created_at: String,
    /// Author/creator, if known
    pub author: Option<String>,
    /// Free-form description, if provided
    pub description: Option<String>,
}
91
/// Algorithm configuration details recorded at training time.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AlgorithmConfig {
    /// Algorithm name (e.g. "kmeans")
    pub algorithm: String,
    /// Hyperparameters used, keyed by name; values are arbitrary JSON
    pub hyperparameters: HashMap<String, Value>,
    /// Preprocessing steps applied, in order
    pub preprocessing: Vec<String>,
    /// Random seed used, if the run was seeded
    pub random_seed: Option<u64>,
    /// Convergence criteria (name -> threshold), if configured
    pub convergence_criteria: Option<HashMap<String, f64>>,
}
106
/// Performance metrics collected during training.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PerformanceMetrics {
    /// Training time in seconds
    pub training_time_seconds: f64,
    /// Peak memory usage in MB
    pub peak_memory_mb: f64,
    /// CPU utilization percentage
    pub cpu_utilization: f64,
    /// Model quality metrics, keyed by metric name
    pub quality_metrics: HashMap<String, f64>,
    /// Convergence information, if the algorithm tracks it
    pub convergence_info: Option<ConvergenceInfo>,
}
121
/// Convergence information for iterative algorithms.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ConvergenceInfo {
    /// Whether the algorithm converged within its iteration budget
    pub converged: bool,
    /// Number of iterations run until convergence (or until stopping)
    pub iterations: usize,
    /// Final objective value at termination
    pub final_objective: f64,
    /// Convergence tolerance actually achieved
    pub tolerance_achieved: f64,
}
134
/// Characteristics of the training data, recorded for model validation.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelDataCharacteristics {
    /// Number of samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Feature names (if available)
    pub feature_names: Option<Vec<String>>,
    /// Data-type descriptions for each feature (parallel to `feature_names`)
    pub feature_types: Option<Vec<String>>,
    /// Per-feature statistical summaries, keyed by feature name
    pub feature_statistics: Option<HashMap<String, FeatureStats>>,
}
149
/// Statistical summary for a single feature.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FeatureStats {
    /// Mean value
    pub mean: f64,
    /// Standard deviation
    pub std: f64,
    /// Minimum observed value
    pub min: f64,
    /// Maximum observed value
    pub max: f64,
    /// Count of missing values for this feature
    pub missing_count: usize,
}
164
/// Export settings and options; see [`Default`] for the defaults.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExportSettings {
    /// Include raw model data in the export
    pub include_raw_data: bool,
    /// Include the training data in the export
    pub include_training_data: bool,
    /// Compression level (0-9); `None` means no compression requested
    pub compression_level: Option<u8>,
    /// Decimal precision for floating-point values, if limited
    pub float_precision: Option<usize>,
    /// Custom, format-specific export options keyed by name
    pub custom_options: HashMap<String, Value>,
}
179
180impl Default for ExportSettings {
181    fn default() -> Self {
182        Self {
183            include_raw_data: true,
184            include_training_data: false,
185            compression_level: None,
186            float_precision: Some(6),
187            custom_options: HashMap::new(),
188        }
189    }
190}
191
192/// Implementation of AdvancedExport for KMeansModel
193impl AdvancedExport for KMeansModel {
194    fn export_with_metadata(
195        &self,
196        format: ExportFormat,
197        metadata: Option<ModelMetadata>,
198    ) -> Result<Vec<u8>> {
199        let export_data = KMeansExportData {
200            model: self.clone(),
201            metadata,
202            format_version: "1.0".to_string(),
203            export_timestamp: chrono::Utc::now().to_rfc3339(),
204        };
205
206        match format {
207            ExportFormat::Json => {
208                let json = serde_json::to_string_pretty(&export_data)
209                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
210                Ok(json.into_bytes())
211            }
212            ExportFormat::JsonGz => {
213                let json = serde_json::to_string(&export_data)
214                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
215
216                use flate2::write::GzEncoder;
217                use flate2::Compression;
218                use std::io::Write;
219
220                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
221                encoder
222                    .write_all(json.as_bytes())
223                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
224                encoder
225                    .finish()
226                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
227            }
228            #[cfg(feature = "yaml")]
229            ExportFormat::Yaml => {
230                let yaml = serde_yaml::to_string(&export_data)
231                    .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
232                Ok(yaml.into_bytes())
233            }
234            #[cfg(not(feature = "yaml"))]
235            ExportFormat::Yaml => Err(ClusteringError::InvalidInput(
236                "YAML support not enabled. Enable the 'yaml' feature".to_string(),
237            )),
238            ExportFormat::Csv => self.export_csv(),
239            _ => Err(ClusteringError::InvalidInput(format!(
240                "Unsupported export format: {:?}",
241                format
242            ))),
243        }
244    }
245
246    fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
247        let path = path.as_ref();
248        let format = detect_format_from_extension(path)?;
249        let data = self.export_with_metadata(format, None)?;
250
251        std::fs::write(path, data)
252            .map_err(|e| ClusteringError::InvalidInput(format!("Failed to write file: {}", e)))
253    }
254
255    fn export_summary(&self) -> Result<String> {
256        let summary = KMeansSummary {
257            algorithm: "K-Means".to_string(),
258            n_clusters: self.n_clusters,
259            n_features: self.centroids.ncols(),
260            n_iterations: self.n_iter,
261            inertia: self.inertia,
262            has_labels: self.labels.is_some(),
263        };
264
265        serde_json::to_string_pretty(&summary)
266            .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
267    }
268
269    fn export_compatible(&self, target_library: &str) -> Result<Value> {
270        match target_library.to_lowercase().as_str() {
271            "sklearn" | "scikit-learn" => self.to_sklearn_format(),
272            "tensorflow" | "tf" => self.to_tensorflow_format(),
273            "pytorch" => self.to_pytorch_format(),
274            _ => Err(ClusteringError::InvalidInput(format!(
275                "Unsupported target library: {}",
276                target_library
277            ))),
278        }
279    }
280
281    fn validate_for_export(&self) -> Result<()> {
282        if self.centroids.is_empty() {
283            return Err(ClusteringError::InvalidInput(
284                "Cannot export model with empty centroids".to_string(),
285            ));
286        }
287
288        if self.n_clusters == 0 {
289            return Err(ClusteringError::InvalidInput(
290                "Cannot export model with zero clusters".to_string(),
291            ));
292        }
293
294        if self.centroids.nrows() != self.n_clusters {
295            return Err(ClusteringError::InvalidInput(
296                "Centroids shape inconsistent with n_clusters".to_string(),
297            ));
298        }
299
300        Ok(())
301    }
302}
303
/// Full export payload for K-Means: the model itself plus optional
/// metadata and versioning information for forward compatibility.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansExportData {
    // Cloned model being exported
    model: KMeansModel,
    // Optional enriched metadata supplied by the caller
    metadata: Option<ModelMetadata>,
    // Version of this export payload layout (currently "1.0")
    format_version: String,
    // RFC 3339 timestamp taken at export time
    export_timestamp: String,
}
312
/// Compact summary of a K-Means model, serialized by `export_summary`.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct KMeansSummary {
    // Algorithm display name ("K-Means")
    algorithm: String,
    // Configured number of clusters
    n_clusters: usize,
    // Feature dimensionality (centroid columns)
    n_features: usize,
    // Iterations run during fitting
    n_iterations: usize,
    // Final inertia (within-cluster sum of squares)
    inertia: f64,
    // Whether per-sample labels are stored on the model
    has_labels: bool,
}
323
324impl KMeansModel {
325    /// Export as CSV format
326    fn export_csv(&self) -> Result<Vec<u8>> {
327        let mut csv_content = String::new();
328
329        // Header
330        csv_content.push_str("cluster_id");
331        for i in 0..self.centroids.ncols() {
332            csv_content.push_str(&format!(",feature_{}", i));
333        }
334        csv_content.push('\n');
335
336        // Centroids data
337        for (cluster_id, centroid) in self.centroids.rows().into_iter().enumerate() {
338            csv_content.push_str(&cluster_id.to_string());
339            for value in centroid {
340                csv_content.push_str(&format!(",{:.6}", value));
341            }
342            csv_content.push('\n');
343        }
344
345        Ok(csv_content.into_bytes())
346    }
347
348    /// Convert to scikit-learn compatible format
349    fn to_sklearn_format(&self) -> Result<Value> {
350        use serde_json::json;
351
352        Ok(json!({
353            "cluster_centers_": self.centroids.as_slice().unwrap(),
354            "labels_": self.labels.as_ref().map(|l| l.as_slice().unwrap()),
355            "inertia_": self.inertia,
356            "n_iter_": self.n_iter,
357            "n_clusters": self.n_clusters,
358            "_sklearn_version": "1.0.0"
359        }))
360    }
361
362    /// Convert to TensorFlow compatible format
363    fn to_tensorflow_format(&self) -> Result<Value> {
364        use serde_json::json;
365
366        Ok(json!({
367            "centroids": {
368                "data": self.centroids.as_slice().unwrap(),
369                "shape": [self.centroids.nrows(), self.centroids.ncols()],
370                "dtype": "float64"
371            },
372            "metadata": {
373                "n_clusters": self.n_clusters,
374                "inertia": self.inertia,
375                "iterations": self.n_iter
376            }
377        }))
378    }
379
380    /// Convert to PyTorch compatible format
381    fn to_pytorch_format(&self) -> Result<Value> {
382        use serde_json::json;
383
384        Ok(json!({
385            "state_dict": {
386                "centroids": self.centroids.as_slice().unwrap()
387            },
388            "hyperparameters": {
389                "n_clusters": self.n_clusters
390            },
391            "metrics": {
392                "inertia": self.inertia,
393                "n_iter": self.n_iter
394            }
395        }))
396    }
397}
398
399/// Detect export format from file extension
400fn detect_format_from_extension<P: AsRef<Path>>(path: P) -> Result<ExportFormat> {
401    let path = path.as_ref();
402    let extension = path
403        .extension()
404        .and_then(|ext| ext.to_str())
405        .unwrap_or("")
406        .to_lowercase();
407
408    match extension.as_str() {
409        "json" => Ok(ExportFormat::Json),
410        "gz" | "json.gz" => Ok(ExportFormat::JsonGz),
411        "yaml" | "yml" => Ok(ExportFormat::Yaml),
412        "csv" => Ok(ExportFormat::Csv),
413        "xml" => Ok(ExportFormat::Xml),
414        "h5" | "hdf5" => Ok(ExportFormat::Hdf5),
415        _ => Err(ClusteringError::InvalidInput(format!(
416            "Unknown file extension: {}",
417            extension
418        ))),
419    }
420}
421
422/// Export utility functions
423pub mod utils {
424    use super::*;
425
426    /// Create default metadata for a model
427    pub fn create_default_metadata(algorithm_name: &str) -> ModelMetadata {
428        ModelMetadata {
429            model_info: ModelInfo {
430                name: format!("{}_model", algorithm_name),
431                version: "1.0.0".to_string(),
432                created_at: chrono::Utc::now().to_rfc3339(),
433                author: None,
434                description: None,
435            },
436            algorithm_config: AlgorithmConfig {
437                algorithm: algorithm_name.to_string(),
438                hyperparameters: HashMap::new(),
439                preprocessing: Vec::new(),
440                random_seed: None,
441                convergence_criteria: None,
442            },
443            performance_metrics: PerformanceMetrics {
444                training_time_seconds: 0.0,
445                peak_memory_mb: 0.0,
446                cpu_utilization: 0.0,
447                quality_metrics: HashMap::new(),
448                convergence_info: None,
449            },
450            data_characteristics: ModelDataCharacteristics {
451                n_samples: 0,
452                n_features: 0,
453                feature_names: None,
454                feature_types: None,
455                feature_statistics: None,
456            },
457            export_settings: ExportSettings::default(),
458        }
459    }
460
461    /// Compress data using gzip
462    pub fn compress_data(data: &[u8]) -> Result<Vec<u8>> {
463        use flate2::write::GzEncoder;
464        use flate2::Compression;
465        use std::io::Write;
466
467        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
468        encoder
469            .write_all(data)
470            .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
471        encoder
472            .finish()
473            .map_err(|e| ClusteringError::InvalidInput(e.to_string()))
474    }
475
476    /// Decompress gzip data
477    pub fn decompress_data(compressed: &[u8]) -> Result<Vec<u8>> {
478        use flate2::read::GzDecoder;
479        use std::io::Read;
480
481        let mut decoder = GzDecoder::new(compressed);
482        let mut decompressed = Vec::new();
483        decoder
484            .read_to_end(&mut decompressed)
485            .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;
486        Ok(decompressed)
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Shared fixture: a 2-cluster, 2-feature model with known centroids.
    fn sample_model() -> KMeansModel {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        KMeansModel::new(centroids, 2, 10, 0.5, None)
    }

    #[test]
    fn test_kmeans_export_summary() {
        let summary = sample_model().export_summary().unwrap();
        assert!(summary.contains("K-Means"));
        assert!(summary.contains("\"n_clusters\": 2"));
    }

    #[test]
    fn test_format_detection() {
        // Table-driven check of extension -> format mapping.
        let cases = [
            ("model.json", ExportFormat::Json),
            ("model.yaml", ExportFormat::Yaml),
            ("model.csv", ExportFormat::Csv),
        ];
        for (path, expected) in cases {
            assert_eq!(detect_format_from_extension(path).unwrap(), expected);
        }
    }

    #[test]
    fn test_sklearn_compatibility() {
        let sklearn_format = sample_model().export_compatible("sklearn").unwrap();
        assert!(sklearn_format.get("cluster_centers_").is_some());
        assert!(sklearn_format.get("_sklearn_version").is_some());
    }
}