// sklears_inspection/serialization.rs

1//! Serialization support for explanation results
2//!
3//! This module provides serialization and deserialization capabilities
4//! for explanation results, allowing them to be saved to disk and loaded later.
5
6use crate::types::*;
7// ✅ SciRS2 Policy Compliant Import
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use sklears_core::prelude::SklearsError;
11use std::collections::HashMap;
12use std::fs;
13use std::path::Path;
14
15/// Serializable wrapper for explanation results
/// Serializable wrapper for explanation results
///
/// Flattens array-based results (`Array1`/`Array2`) into plain `Vec`s so the
/// whole record can be round-tripped through serde formats such as JSON.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializableExplanationResult {
    /// Unique identifier for this explanation
    pub id: String,
    /// Explanation method used (e.g. "permutation", "shap")
    pub method: String,
    /// Timestamp when explanation was generated (UTC, set at construction)
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Feature importance values, one per feature
    pub feature_importance: Vec<Float>,
    /// Feature names (if available); indices align with `feature_importance`
    pub feature_names: Option<Vec<String>>,
    /// SHAP values (if computed), stored as one inner `Vec` per matrix row
    pub shap_values: Option<Vec<Vec<Float>>>,
    /// Confidence intervals — presumably (lower, upper) pairs; confirm with producer
    pub confidence_intervals: Option<Vec<(Float, Float)>>,
    /// Model information
    pub model_info: ModelMetadata,
    /// Dataset information
    pub dataset_info: DatasetMetadata,
    /// Explanation configuration
    pub config: ExplanationConfiguration,
    /// Additional free-form metadata
    pub metadata: HashMap<String, String>,
}
41
42/// Model metadata for serialization
/// Model metadata for serialization
///
/// Defaults to an "unknown" model (see the `Default` impl below).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ModelMetadata {
    /// Model type (e.g. "linear_regression"); "unknown" by default
    pub model_type: String,
    /// Model version string
    pub version: String,
    /// Training accuracy (if available)
    pub training_accuracy: Option<Float>,
    /// Validation accuracy (if available)
    pub validation_accuracy: Option<Float>,
    /// Number of parameters
    pub num_parameters: Option<usize>,
    /// Additional model-specific metadata as free-form key/value pairs
    pub additional_info: HashMap<String, String>,
}
58
59/// Dataset metadata for serialization
/// Dataset metadata for serialization
///
/// `Default` yields zero samples/features with all optional fields unset.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DatasetMetadata {
    /// Number of samples
    pub num_samples: usize,
    /// Number of features
    pub num_features: usize,
    /// Dataset name
    pub name: Option<String>,
    /// Feature types, one per feature
    pub feature_types: Option<Vec<String>>,
    /// Target variable info (free-form description)
    pub target_info: Option<String>,
    /// Per-feature data statistics
    pub statistics: Option<DataStatistics>,
}
75
76/// Data statistics
77#[derive(Serialize, Deserialize, Clone, Debug)]
78pub struct DataStatistics {
79    /// Feature means
80    pub feature_means: Vec<Float>,
81    /// Feature standard deviations
82    pub feature_stds: Vec<Float>,
83    /// Feature min values
84    pub feature_mins: Vec<Float>,
85    /// Feature max values
86    pub feature_maxs: Vec<Float>,
87    /// Missing value counts
88    pub missing_counts: Vec<usize>,
89}
90
91/// Explanation configuration
/// Explanation configuration
///
/// Records how an explanation was computed; defaults to an "unknown" method
/// (see the `Default` impl below).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ExplanationConfiguration {
    /// Method name
    pub method_name: String,
    /// Configuration parameters as stringified key/value pairs
    pub parameters: HashMap<String, String>,
    /// Random seed (if used)
    pub random_seed: Option<u64>,
    /// Computation time in milliseconds
    pub computation_time_ms: Option<u64>,
    /// Number of samples used
    pub num_samples_used: Option<usize>,
}
105
106/// Serialization format options
/// Serialization format options
///
/// Only `Json` and `Csv` are currently implemented by the save/load paths in
/// this module; `Binary` and `Parquet` return "not yet implemented" errors.
///
/// Fieldless enum, so it derives `Copy` and `Eq` (backward-compatible
/// additions; `Clone`/`PartialEq` behavior is unchanged).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SerializationFormat {
    /// JSON format
    Json,
    /// Binary format (MessagePack)
    Binary,
    /// CSV format (limited functionality)
    Csv,
    /// Parquet format (for large datasets)
    Parquet,
}
118
119/// Compression options
/// Compression options
///
/// Only `None` is currently implemented by the save/load paths in this
/// module; the other variants return "not yet implemented" errors.
///
/// Fieldless enum, so it derives `Copy` and `Eq` (backward-compatible
/// additions; `Clone`/`PartialEq` behavior is unchanged).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionType {
    /// No compression
    None,
    /// Gzip compression
    Gzip,
    /// LZ4 compression
    Lz4,
    /// Zstd compression
    Zstd,
}
131
132/// Serialization configuration
/// Serialization configuration
///
/// Controls the on-disk representation used by `save_to_file` /
/// `load_from_file` and the batch directory helpers.
#[derive(Clone, Debug)]
pub struct SerializationConfig {
    /// Output format
    pub format: SerializationFormat,
    /// Compression type
    pub compression: CompressionType,
    /// Include raw data
    // NOTE(review): not yet consumed by the save/load paths in this module.
    pub include_raw_data: bool,
    /// Include intermediate results
    // NOTE(review): not yet consumed by the save/load paths in this module.
    pub include_intermediate: bool,
    /// Precision for floating point numbers
    // NOTE(review): not yet applied by the JSON/CSV writers in this module.
    pub float_precision: usize,
}
146
147impl Default for SerializationConfig {
148    fn default() -> Self {
149        Self {
150            format: SerializationFormat::Json,
151            compression: CompressionType::None,
152            include_raw_data: false,
153            include_intermediate: false,
154            float_precision: 6,
155        }
156    }
157}
158
159impl SerializableExplanationResult {
160    /// Create a new serializable explanation result
161    pub fn new(
162        id: String,
163        method: String,
164        feature_importance: Array1<Float>,
165        feature_names: Option<Vec<String>>,
166    ) -> Self {
167        Self {
168            id,
169            method,
170            timestamp: chrono::Utc::now(),
171            feature_importance: feature_importance.to_vec(),
172            feature_names,
173            shap_values: None,
174            confidence_intervals: None,
175            model_info: ModelMetadata::default(),
176            dataset_info: DatasetMetadata::default(),
177            config: ExplanationConfiguration::default(),
178            metadata: HashMap::new(),
179        }
180    }
181
182    /// Add SHAP values to the result
183    pub fn with_shap_values(mut self, shap_values: Array2<Float>) -> Self {
184        self.shap_values = Some(
185            shap_values
186                .rows()
187                .into_iter()
188                .map(|row| row.to_vec())
189                .collect(),
190        );
191        self
192    }
193
194    /// Add confidence intervals
195    pub fn with_confidence_intervals(mut self, intervals: Vec<(Float, Float)>) -> Self {
196        self.confidence_intervals = Some(intervals);
197        self
198    }
199
200    /// Add model metadata
201    pub fn with_model_info(mut self, model_info: ModelMetadata) -> Self {
202        self.model_info = model_info;
203        self
204    }
205
206    /// Add dataset metadata
207    pub fn with_dataset_info(mut self, dataset_info: DatasetMetadata) -> Self {
208        self.dataset_info = dataset_info;
209        self
210    }
211
212    /// Add configuration information
213    pub fn with_config(mut self, config: ExplanationConfiguration) -> Self {
214        self.config = config;
215        self
216    }
217
218    /// Add custom metadata
219    pub fn with_metadata(mut self, key: String, value: String) -> Self {
220        self.metadata.insert(key, value);
221        self
222    }
223
224    /// Get feature importance as Array1
225    pub fn get_feature_importance(&self) -> Array1<Float> {
226        Array1::from_vec(self.feature_importance.clone())
227    }
228
229    /// Get SHAP values as Array2 (if available)
230    pub fn get_shap_values(&self) -> Option<Array2<Float>> {
231        self.shap_values.as_ref().map(|values| {
232            let rows = values.len();
233            let cols = values.first().map(|row| row.len()).unwrap_or(0);
234            let flat: Vec<Float> = values.iter().flatten().copied().collect();
235            Array2::from_shape_vec((rows, cols), flat).unwrap()
236        })
237    }
238
239    /// Serialize to JSON string
240    pub fn to_json(&self) -> crate::SklResult<String> {
241        serde_json::to_string_pretty(self)
242            .map_err(|e| SklearsError::InvalidInput(format!("Failed to serialize to JSON: {}", e)))
243    }
244
245    /// Deserialize from JSON string
246    pub fn from_json(json: &str) -> crate::SklResult<Self> {
247        serde_json::from_str(json).map_err(|e| {
248            SklearsError::InvalidInput(format!("Failed to deserialize from JSON: {}", e))
249        })
250    }
251
252    /// Save to file
253    pub fn save_to_file<P: AsRef<Path>>(
254        &self,
255        path: P,
256        config: &SerializationConfig,
257    ) -> crate::SklResult<()> {
258        let content = match config.format {
259            SerializationFormat::Json => self.to_json()?,
260            SerializationFormat::Binary => {
261                return Err(SklearsError::InvalidInput(
262                    "Binary format not yet implemented".to_string(),
263                ));
264            }
265            SerializationFormat::Csv => self.to_csv()?,
266            SerializationFormat::Parquet => {
267                return Err(SklearsError::InvalidInput(
268                    "Parquet format not yet implemented".to_string(),
269                ));
270            }
271        };
272
273        // Apply compression if needed
274        let final_content = match config.compression {
275            CompressionType::None => content.into_bytes(),
276            CompressionType::Gzip => {
277                return Err(SklearsError::InvalidInput(
278                    "Gzip compression not yet implemented".to_string(),
279                ));
280            }
281            CompressionType::Lz4 => {
282                return Err(SklearsError::InvalidInput(
283                    "LZ4 compression not yet implemented".to_string(),
284                ));
285            }
286            CompressionType::Zstd => {
287                return Err(SklearsError::InvalidInput(
288                    "Zstd compression not yet implemented".to_string(),
289                ));
290            }
291        };
292
293        fs::write(path, final_content)
294            .map_err(|e| SklearsError::InvalidInput(format!("Failed to write file: {}", e)))?;
295
296        Ok(())
297    }
298
299    /// Load from file
300    pub fn load_from_file<P: AsRef<Path>>(
301        path: P,
302        config: &SerializationConfig,
303    ) -> crate::SklResult<Self> {
304        let content = fs::read(path)
305            .map_err(|e| SklearsError::InvalidInput(format!("Failed to read file: {}", e)))?;
306
307        // Decompress if needed
308        let decompressed_content = match config.compression {
309            CompressionType::None => String::from_utf8(content).map_err(|e| {
310                SklearsError::InvalidInput(format!("Failed to decode UTF-8: {}", e))
311            })?,
312            CompressionType::Gzip => {
313                return Err(SklearsError::InvalidInput(
314                    "Gzip decompression not yet implemented".to_string(),
315                ));
316            }
317            CompressionType::Lz4 => {
318                return Err(SklearsError::InvalidInput(
319                    "LZ4 decompression not yet implemented".to_string(),
320                ));
321            }
322            CompressionType::Zstd => {
323                return Err(SklearsError::InvalidInput(
324                    "Zstd decompression not yet implemented".to_string(),
325                ));
326            }
327        };
328
329        match config.format {
330            SerializationFormat::Json => Self::from_json(&decompressed_content),
331            SerializationFormat::Binary => Err(SklearsError::InvalidInput(
332                "Binary format not yet implemented".to_string(),
333            )),
334            SerializationFormat::Csv => Self::from_csv(&decompressed_content),
335            SerializationFormat::Parquet => Err(SklearsError::InvalidInput(
336                "Parquet format not yet implemented".to_string(),
337            )),
338        }
339    }
340
341    /// Convert to CSV format (simplified)
342    pub fn to_csv(&self) -> crate::SklResult<String> {
343        let mut csv = String::new();
344
345        // Header
346        csv.push_str("feature_index,feature_name,importance\n");
347
348        // Data rows
349        for (idx, importance) in self.feature_importance.iter().enumerate() {
350            let feature_name = self
351                .feature_names
352                .as_ref()
353                .and_then(|names| names.get(idx))
354                .map(|s| s.as_str())
355                .unwrap_or("unknown");
356
357            csv.push_str(&format!("{},{},{}\n", idx, feature_name, importance));
358        }
359
360        Ok(csv)
361    }
362
363    /// Load from CSV format (simplified)
364    pub fn from_csv(csv: &str) -> crate::SklResult<Self> {
365        let lines: Vec<&str> = csv.lines().collect();
366        if lines.is_empty() {
367            return Err(SklearsError::InvalidInput("Empty CSV content".to_string()));
368        }
369
370        let mut feature_importance = Vec::new();
371        let mut feature_names = Vec::new();
372
373        // Skip header and parse data
374        for line in lines.iter().skip(1) {
375            let parts: Vec<&str> = line.split(',').collect();
376            if parts.len() >= 3 {
377                let importance: Float = parts[2].parse().map_err(|_| {
378                    SklearsError::InvalidInput("Invalid importance value in CSV".to_string())
379                })?;
380                feature_importance.push(importance);
381                feature_names.push(parts[1].to_string());
382            }
383        }
384
385        Ok(Self::new(
386            "csv_import".to_string(),
387            "unknown".to_string(),
388            Array1::from_vec(feature_importance),
389            Some(feature_names),
390        ))
391    }
392
393    /// Get summary statistics
394    pub fn get_summary(&self) -> SerializationSummary {
395        let importance_array = self.get_feature_importance();
396
397        SerializationSummary {
398            method: self.method.clone(),
399            num_features: self.feature_importance.len(),
400            timestamp: self.timestamp,
401            max_importance: importance_array
402                .iter()
403                .cloned()
404                .fold(Float::NEG_INFINITY, Float::max),
405            min_importance: importance_array
406                .iter()
407                .cloned()
408                .fold(Float::INFINITY, Float::min),
409            mean_importance: importance_array.mean().unwrap_or(0.0),
410            std_importance: importance_array.std(0.0),
411            has_shap_values: self.shap_values.is_some(),
412            has_confidence_intervals: self.confidence_intervals.is_some(),
413        }
414    }
415}
416
417/// Summary information about an explanation
/// Summary information about an explanation
///
/// Produced by `SerializableExplanationResult::get_summary`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializationSummary {
    /// Method used
    pub method: String,
    /// Number of features
    pub num_features: usize,
    /// Timestamp copied from the source explanation
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Maximum importance value (-inf when there are no features)
    pub max_importance: Float,
    /// Minimum importance value (+inf when there are no features)
    pub min_importance: Float,
    /// Mean importance value (0.0 when there are no features)
    pub mean_importance: Float,
    /// Standard deviation of importance (population, ddof = 0)
    pub std_importance: Float,
    /// Whether SHAP values are available
    pub has_shap_values: bool,
    /// Whether confidence intervals are available
    pub has_confidence_intervals: bool,
}
439
440impl Default for ModelMetadata {
441    fn default() -> Self {
442        Self {
443            model_type: "unknown".to_string(),
444            version: "1.0.0".to_string(),
445            training_accuracy: None,
446            validation_accuracy: None,
447            num_parameters: None,
448            additional_info: HashMap::new(),
449        }
450    }
451}
452
453impl Default for ExplanationConfiguration {
454    fn default() -> Self {
455        Self {
456            method_name: "unknown".to_string(),
457            parameters: HashMap::new(),
458            random_seed: None,
459            computation_time_ms: None,
460            num_samples_used: None,
461        }
462    }
463}
464
465/// Batch serialization for multiple explanation results
466pub struct ExplanationBatch {
467    /// List of explanation results
468    pub results: Vec<SerializableExplanationResult>,
469    /// Batch metadata
470    pub metadata: HashMap<String, String>,
471    /// Creation timestamp
472    pub created_at: chrono::DateTime<chrono::Utc>,
473}
474
475impl Default for ExplanationBatch {
476    fn default() -> Self {
477        Self::new()
478    }
479}
480
481impl ExplanationBatch {
482    /// Create a new explanation batch
483    pub fn new() -> Self {
484        Self {
485            results: Vec::new(),
486            metadata: HashMap::new(),
487            created_at: chrono::Utc::now(),
488        }
489    }
490
491    /// Add an explanation result to the batch
492    pub fn add_result(&mut self, result: SerializableExplanationResult) {
493        self.results.push(result);
494    }
495
496    /// Add metadata to the batch
497    pub fn add_metadata(&mut self, key: String, value: String) {
498        self.metadata.insert(key, value);
499    }
500
501    /// Save batch to directory
502    pub fn save_to_directory<P: AsRef<Path>>(
503        &self,
504        directory: P,
505        config: &SerializationConfig,
506    ) -> crate::SklResult<()> {
507        let dir_path = directory.as_ref();
508        fs::create_dir_all(dir_path).map_err(|e| {
509            SklearsError::InvalidInput(format!("Failed to create directory: {}", e))
510        })?;
511
512        // Save each result as a separate file
513        for (idx, result) in self.results.iter().enumerate() {
514            let filename = format!("explanation_{:03}.json", idx);
515            let file_path = dir_path.join(filename);
516            result.save_to_file(file_path, config)?;
517        }
518
519        // Save batch metadata
520        let metadata_path = dir_path.join("batch_metadata.json");
521        let metadata_json = serde_json::to_string_pretty(&self.metadata).map_err(|e| {
522            SklearsError::InvalidInput(format!("Failed to serialize metadata: {}", e))
523        })?;
524        fs::write(metadata_path, metadata_json)
525            .map_err(|e| SklearsError::InvalidInput(format!("Failed to write metadata: {}", e)))?;
526
527        Ok(())
528    }
529
530    /// Load batch from directory
531    pub fn load_from_directory<P: AsRef<Path>>(
532        directory: P,
533        config: &SerializationConfig,
534    ) -> crate::SklResult<Self> {
535        let dir_path = directory.as_ref();
536
537        let mut batch = ExplanationBatch::new();
538
539        // Load all explanation files
540        let entries = fs::read_dir(dir_path)
541            .map_err(|e| SklearsError::InvalidInput(format!("Failed to read directory: {}", e)))?;
542
543        for entry in entries {
544            let entry = entry.map_err(|e| {
545                SklearsError::InvalidInput(format!("Failed to read directory entry: {}", e))
546            })?;
547
548            let path = entry.path();
549            if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
550                if filename.starts_with("explanation_") && filename.ends_with(".json") {
551                    let result = SerializableExplanationResult::load_from_file(&path, config)?;
552                    batch.add_result(result);
553                }
554            }
555        }
556
557        // Load metadata if available
558        let metadata_path = dir_path.join("batch_metadata.json");
559        if metadata_path.exists() {
560            let metadata_content = fs::read_to_string(metadata_path).map_err(|e| {
561                SklearsError::InvalidInput(format!("Failed to read metadata: {}", e))
562            })?;
563
564            let metadata: HashMap<String, String> = serde_json::from_str(&metadata_content)
565                .map_err(|e| {
566                    SklearsError::InvalidInput(format!("Failed to parse metadata: {}", e))
567                })?;
568
569            batch.metadata = metadata;
570        }
571
572        Ok(batch)
573    }
574
575    /// Get summary of all results in the batch
576    pub fn get_batch_summary(&self) -> BatchSummary {
577        let summaries: Vec<SerializationSummary> = self
578            .results
579            .iter()
580            .map(|result| result.get_summary())
581            .collect();
582
583        let methods: Vec<String> = summaries.iter().map(|s| s.method.clone()).collect();
584        let unique_methods: std::collections::HashSet<String> = methods.iter().cloned().collect();
585
586        BatchSummary {
587            num_results: self.results.len(),
588            methods_used: unique_methods.into_iter().collect(),
589            created_at: self.created_at,
590            total_features: summaries.iter().map(|s| s.num_features).sum(),
591            has_metadata: !self.metadata.is_empty(),
592        }
593    }
594}
595
596/// Batch summary information
597#[derive(Serialize, Deserialize, Clone, Debug)]
598pub struct BatchSummary {
599    /// Number of results in the batch
600    pub num_results: usize,
601    /// Methods used in the batch
602    pub methods_used: Vec<String>,
603    /// Batch creation timestamp
604    pub created_at: chrono::DateTime<chrono::Utc>,
605    /// Total number of features across all results
606    pub total_features: usize,
607    /// Whether the batch has metadata
608    pub has_metadata: bool,
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;
    use tempfile::tempdir;

    // Constructor stores importances/names verbatim and echoes id/method.
    #[test]
    fn test_serializable_explanation_result_creation() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let feature_names = vec![
            "feature1".to_string(),
            "feature2".to_string(),
            "feature3".to_string(),
        ];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance.clone(),
            Some(feature_names.clone()),
        );

        assert_eq!(result.id, "test_id");
        assert_eq!(result.method, "permutation");
        assert_eq!(result.feature_importance, vec![0.5, 0.3, 0.2]);
        assert_eq!(result.feature_names, Some(feature_names));
    }

    // Array2 -> Vec<Vec> -> Array2 SHAP round trip is lossless.
    #[test]
    fn test_shap_values_conversion() {
        let feature_importance = array![0.5, 0.3];
        let shap_values = array![[0.1, 0.2], [0.3, 0.4]];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "shap".to_string(),
            feature_importance,
            None,
        )
        .with_shap_values(shap_values.clone());

        let recovered_shap = result.get_shap_values().unwrap();
        assert_eq!(recovered_shap, shap_values);
    }

    // to_json/from_json round trip preserves id, method, and importances.
    #[test]
    fn test_json_serialization() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let json = result.to_json().unwrap();
        let recovered = SerializableExplanationResult::from_json(&json).unwrap();

        assert_eq!(result.id, recovered.id);
        assert_eq!(result.method, recovered.method);
        assert_eq!(result.feature_importance, recovered.feature_importance);
    }

    // CSV round trip preserves importances to within float-formatting error.
    #[test]
    fn test_csv_serialization() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let feature_names = vec!["f1".to_string(), "f2".to_string(), "f3".to_string()];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            Some(feature_names),
        );

        let csv = result.to_csv().unwrap();
        let recovered = SerializableExplanationResult::from_csv(&csv).unwrap();

        assert_eq!(
            result.feature_importance.len(),
            recovered.feature_importance.len()
        );
        for (a, b) in result
            .feature_importance
            .iter()
            .zip(recovered.feature_importance.iter())
        {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-6);
        }
    }

    // save_to_file + load_from_file (default JSON config) round trips.
    #[test]
    fn test_file_save_and_load() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test_explanation.json");

        let feature_importance = array![0.5, 0.3, 0.2];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let config = SerializationConfig::default();
        result.save_to_file(&file_path, &config).unwrap();

        let loaded = SerializableExplanationResult::load_from_file(&file_path, &config).unwrap();
        assert_eq!(result.id, loaded.id);
        assert_eq!(result.method, loaded.method);
        assert_eq!(result.feature_importance, loaded.feature_importance);
    }

    // Summary reflects count, extrema, and optional-payload flags.
    #[test]
    fn test_explanation_summary() {
        let feature_importance = array![0.5, 0.3, 0.8, 0.1];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let summary = result.get_summary();
        assert_eq!(summary.method, "permutation");
        assert_eq!(summary.num_features, 4);
        assert_eq!(summary.max_importance, 0.8);
        assert_eq!(summary.min_importance, 0.1);
        assert!(!summary.has_shap_values);
        assert!(!summary.has_confidence_intervals);
    }

    // Batch summary deduplicates methods and flags metadata presence.
    #[test]
    fn test_explanation_batch() {
        let mut batch = ExplanationBatch::new();

        let result1 = SerializableExplanationResult::new(
            "test_1".to_string(),
            "permutation".to_string(),
            array![0.5, 0.3],
            None,
        );

        let result2 = SerializableExplanationResult::new(
            "test_2".to_string(),
            "shap".to_string(),
            array![0.2, 0.8],
            None,
        );

        batch.add_result(result1);
        batch.add_result(result2);
        batch.add_metadata("experiment".to_string(), "test_run".to_string());

        let summary = batch.get_batch_summary();
        assert_eq!(summary.num_results, 2);
        assert!(summary.methods_used.contains(&"permutation".to_string()));
        assert!(summary.methods_used.contains(&"shap".to_string()));
        assert!(summary.has_metadata);
    }

    // Directory round trip restores both results and batch metadata.
    #[test]
    fn test_batch_save_and_load() {
        let temp_dir = tempdir().unwrap();

        let mut batch = ExplanationBatch::new();
        batch.add_result(SerializableExplanationResult::new(
            "test_1".to_string(),
            "permutation".to_string(),
            array![0.5, 0.3],
            None,
        ));
        batch.add_metadata("experiment".to_string(), "test_batch".to_string());

        let config = SerializationConfig::default();
        batch.save_to_directory(temp_dir.path(), &config).unwrap();

        let loaded_batch = ExplanationBatch::load_from_directory(temp_dir.path(), &config).unwrap();
        assert_eq!(loaded_batch.results.len(), 1);
        assert_eq!(
            loaded_batch.metadata.get("experiment"),
            Some(&"test_batch".to_string())
        );
    }

    // Default config is uncompressed JSON with precision 6.
    #[test]
    fn test_serialization_config_default() {
        let config = SerializationConfig::default();
        assert_eq!(config.format, SerializationFormat::Json);
        assert_eq!(config.compression, CompressionType::None);
        assert!(!config.include_raw_data);
        assert_eq!(config.float_precision, 6);
    }

    // Default model metadata uses the "unknown"/"1.0.0" placeholders.
    #[test]
    fn test_model_metadata_default() {
        let metadata = ModelMetadata::default();
        assert_eq!(metadata.model_type, "unknown");
        assert_eq!(metadata.version, "1.0.0");
        assert!(metadata.training_accuracy.is_none());
    }

    // Derived Default on DatasetMetadata zeroes counts and clears options.
    #[test]
    fn test_dataset_metadata_default() {
        let metadata = DatasetMetadata::default();
        assert_eq!(metadata.num_samples, 0);
        assert_eq!(metadata.num_features, 0);
        assert!(metadata.name.is_none());
    }

    // Builder-style with_* methods attach metadata without touching the rest.
    #[test]
    fn test_with_methods() {
        let feature_importance = array![0.5, 0.3];
        let model_info = ModelMetadata {
            model_type: "linear_regression".to_string(),
            ..Default::default()
        };

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        )
        .with_model_info(model_info.clone())
        .with_metadata("key1".to_string(), "value1".to_string());

        assert_eq!(result.model_info.model_type, "linear_regression");
        assert_eq!(result.metadata.get("key1"), Some(&"value1".to_string()));
    }
}