1use crate::types::*;
7use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use sklears_core::prelude::SklearsError;
11use std::collections::HashMap;
12use std::fs;
13use std::path::Path;
14
/// A self-describing explanation result that can be round-tripped through
/// serde (JSON) or exported/imported as minimal CSV.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializableExplanationResult {
    /// Unique identifier for this explanation run.
    pub id: String,
    /// Name of the explanation method (e.g. "permutation", "shap").
    pub method: String,
    /// UTC timestamp recorded when the result was constructed.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Per-feature importance scores.
    pub feature_importance: Vec<Float>,
    /// Optional feature names, parallel to `feature_importance`.
    pub feature_names: Option<Vec<String>>,
    /// Optional SHAP value matrix stored as one inner `Vec` per row
    /// (presumably samples x features — confirm against producers).
    pub shap_values: Option<Vec<Vec<Float>>>,
    /// Optional (lower, upper) confidence bounds per feature.
    pub confidence_intervals: Option<Vec<(Float, Float)>>,
    /// Metadata about the model that was explained.
    pub model_info: ModelMetadata,
    /// Metadata about the dataset the explanation was computed on.
    pub dataset_info: DatasetMetadata,
    /// Configuration the explanation method was run with.
    pub config: ExplanationConfiguration,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
41
/// Descriptive metadata about the model an explanation refers to.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ModelMetadata {
    /// Model family/type label (defaults to "unknown").
    pub model_type: String,
    /// Model version string (defaults to "1.0.0").
    pub version: String,
    /// Accuracy on the training set, if recorded.
    pub training_accuracy: Option<Float>,
    /// Accuracy on the validation set, if recorded.
    pub validation_accuracy: Option<Float>,
    /// Number of trainable parameters, if known.
    pub num_parameters: Option<usize>,
    /// Arbitrary extra key/value details.
    pub additional_info: HashMap<String, String>,
}
58
/// Descriptive metadata about the dataset used for an explanation.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DatasetMetadata {
    /// Number of samples (rows) in the dataset.
    pub num_samples: usize,
    /// Number of features (columns) in the dataset.
    pub num_features: usize,
    /// Optional dataset name.
    pub name: Option<String>,
    /// Optional per-feature type labels.
    pub feature_types: Option<Vec<String>>,
    /// Optional free-text description of the target variable.
    pub target_info: Option<String>,
    /// Optional per-feature summary statistics.
    pub statistics: Option<DataStatistics>,
}
75
/// Per-feature summary statistics; each `Vec` is parallel to the feature
/// axis (one entry per feature).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct DataStatistics {
    /// Mean of each feature.
    pub feature_means: Vec<Float>,
    /// Standard deviation of each feature.
    pub feature_stds: Vec<Float>,
    /// Minimum of each feature.
    pub feature_mins: Vec<Float>,
    /// Maximum of each feature.
    pub feature_maxs: Vec<Float>,
    /// Count of missing values per feature.
    pub missing_counts: Vec<usize>,
}
90
/// Record of how an explanation was computed, for reproducibility.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ExplanationConfiguration {
    /// Name of the explanation method (defaults to "unknown").
    pub method_name: String,
    /// Method parameters as stringified key/value pairs.
    pub parameters: HashMap<String, String>,
    /// RNG seed, if the method was seeded.
    pub random_seed: Option<u64>,
    /// Wall-clock computation time in milliseconds, if measured.
    pub computation_time_ms: Option<u64>,
    /// Number of samples the method actually used, if tracked.
    pub num_samples_used: Option<usize>,
}
105
/// On-disk format for explanation results. Only `Json` and `Csv` are
/// currently implemented; `Binary` and `Parquet` return an error.
#[derive(Clone, Debug, PartialEq)]
pub enum SerializationFormat {
    /// Pretty-printed JSON (full fidelity).
    Json,
    /// Not yet implemented.
    Binary,
    /// Minimal CSV (feature importances only; lossy).
    Csv,
    /// Not yet implemented.
    Parquet,
}
118
/// Compression applied around the serialized payload. Only `None` is
/// currently implemented; the others return an error.
#[derive(Clone, Debug, PartialEq)]
pub enum CompressionType {
    /// No compression (the only working option today).
    None,
    /// Not yet implemented.
    Gzip,
    /// Not yet implemented.
    Lz4,
    /// Not yet implemented.
    Zstd,
}
131
/// Options controlling how explanation results are written and read.
#[derive(Clone, Debug)]
pub struct SerializationConfig {
    /// Target file format.
    pub format: SerializationFormat,
    /// Compression wrapper around the payload.
    pub compression: CompressionType,
    /// Whether to embed raw input data.
    /// NOTE(review): not consulted by any writer visible in this file.
    pub include_raw_data: bool,
    /// Whether to embed intermediate computation artifacts.
    /// NOTE(review): not consulted by any writer visible in this file.
    pub include_intermediate: bool,
    /// Decimal digits for floats.
    /// NOTE(review): `to_json`/`to_csv` currently ignore this setting.
    pub float_precision: usize,
}
146
147impl Default for SerializationConfig {
148 fn default() -> Self {
149 Self {
150 format: SerializationFormat::Json,
151 compression: CompressionType::None,
152 include_raw_data: false,
153 include_intermediate: false,
154 float_precision: 6,
155 }
156 }
157}
158
159impl SerializableExplanationResult {
160 pub fn new(
162 id: String,
163 method: String,
164 feature_importance: Array1<Float>,
165 feature_names: Option<Vec<String>>,
166 ) -> Self {
167 Self {
168 id,
169 method,
170 timestamp: chrono::Utc::now(),
171 feature_importance: feature_importance.to_vec(),
172 feature_names,
173 shap_values: None,
174 confidence_intervals: None,
175 model_info: ModelMetadata::default(),
176 dataset_info: DatasetMetadata::default(),
177 config: ExplanationConfiguration::default(),
178 metadata: HashMap::new(),
179 }
180 }
181
182 pub fn with_shap_values(mut self, shap_values: Array2<Float>) -> Self {
184 self.shap_values = Some(
185 shap_values
186 .rows()
187 .into_iter()
188 .map(|row| row.to_vec())
189 .collect(),
190 );
191 self
192 }
193
194 pub fn with_confidence_intervals(mut self, intervals: Vec<(Float, Float)>) -> Self {
196 self.confidence_intervals = Some(intervals);
197 self
198 }
199
200 pub fn with_model_info(mut self, model_info: ModelMetadata) -> Self {
202 self.model_info = model_info;
203 self
204 }
205
206 pub fn with_dataset_info(mut self, dataset_info: DatasetMetadata) -> Self {
208 self.dataset_info = dataset_info;
209 self
210 }
211
212 pub fn with_config(mut self, config: ExplanationConfiguration) -> Self {
214 self.config = config;
215 self
216 }
217
218 pub fn with_metadata(mut self, key: String, value: String) -> Self {
220 self.metadata.insert(key, value);
221 self
222 }
223
224 pub fn get_feature_importance(&self) -> Array1<Float> {
226 Array1::from_vec(self.feature_importance.clone())
227 }
228
229 pub fn get_shap_values(&self) -> Option<Array2<Float>> {
231 self.shap_values.as_ref().map(|values| {
232 let rows = values.len();
233 let cols = values.first().map(|row| row.len()).unwrap_or(0);
234 let flat: Vec<Float> = values.iter().flatten().copied().collect();
235 Array2::from_shape_vec((rows, cols), flat).unwrap()
236 })
237 }
238
239 pub fn to_json(&self) -> crate::SklResult<String> {
241 serde_json::to_string_pretty(self)
242 .map_err(|e| SklearsError::InvalidInput(format!("Failed to serialize to JSON: {}", e)))
243 }
244
245 pub fn from_json(json: &str) -> crate::SklResult<Self> {
247 serde_json::from_str(json).map_err(|e| {
248 SklearsError::InvalidInput(format!("Failed to deserialize from JSON: {}", e))
249 })
250 }
251
252 pub fn save_to_file<P: AsRef<Path>>(
254 &self,
255 path: P,
256 config: &SerializationConfig,
257 ) -> crate::SklResult<()> {
258 let content = match config.format {
259 SerializationFormat::Json => self.to_json()?,
260 SerializationFormat::Binary => {
261 return Err(SklearsError::InvalidInput(
262 "Binary format not yet implemented".to_string(),
263 ));
264 }
265 SerializationFormat::Csv => self.to_csv()?,
266 SerializationFormat::Parquet => {
267 return Err(SklearsError::InvalidInput(
268 "Parquet format not yet implemented".to_string(),
269 ));
270 }
271 };
272
273 let final_content = match config.compression {
275 CompressionType::None => content.into_bytes(),
276 CompressionType::Gzip => {
277 return Err(SklearsError::InvalidInput(
278 "Gzip compression not yet implemented".to_string(),
279 ));
280 }
281 CompressionType::Lz4 => {
282 return Err(SklearsError::InvalidInput(
283 "LZ4 compression not yet implemented".to_string(),
284 ));
285 }
286 CompressionType::Zstd => {
287 return Err(SklearsError::InvalidInput(
288 "Zstd compression not yet implemented".to_string(),
289 ));
290 }
291 };
292
293 fs::write(path, final_content)
294 .map_err(|e| SklearsError::InvalidInput(format!("Failed to write file: {}", e)))?;
295
296 Ok(())
297 }
298
299 pub fn load_from_file<P: AsRef<Path>>(
301 path: P,
302 config: &SerializationConfig,
303 ) -> crate::SklResult<Self> {
304 let content = fs::read(path)
305 .map_err(|e| SklearsError::InvalidInput(format!("Failed to read file: {}", e)))?;
306
307 let decompressed_content = match config.compression {
309 CompressionType::None => String::from_utf8(content).map_err(|e| {
310 SklearsError::InvalidInput(format!("Failed to decode UTF-8: {}", e))
311 })?,
312 CompressionType::Gzip => {
313 return Err(SklearsError::InvalidInput(
314 "Gzip decompression not yet implemented".to_string(),
315 ));
316 }
317 CompressionType::Lz4 => {
318 return Err(SklearsError::InvalidInput(
319 "LZ4 decompression not yet implemented".to_string(),
320 ));
321 }
322 CompressionType::Zstd => {
323 return Err(SklearsError::InvalidInput(
324 "Zstd decompression not yet implemented".to_string(),
325 ));
326 }
327 };
328
329 match config.format {
330 SerializationFormat::Json => Self::from_json(&decompressed_content),
331 SerializationFormat::Binary => Err(SklearsError::InvalidInput(
332 "Binary format not yet implemented".to_string(),
333 )),
334 SerializationFormat::Csv => Self::from_csv(&decompressed_content),
335 SerializationFormat::Parquet => Err(SklearsError::InvalidInput(
336 "Parquet format not yet implemented".to_string(),
337 )),
338 }
339 }
340
341 pub fn to_csv(&self) -> crate::SklResult<String> {
343 let mut csv = String::new();
344
345 csv.push_str("feature_index,feature_name,importance\n");
347
348 for (idx, importance) in self.feature_importance.iter().enumerate() {
350 let feature_name = self
351 .feature_names
352 .as_ref()
353 .and_then(|names| names.get(idx))
354 .map(|s| s.as_str())
355 .unwrap_or("unknown");
356
357 csv.push_str(&format!("{},{},{}\n", idx, feature_name, importance));
358 }
359
360 Ok(csv)
361 }
362
363 pub fn from_csv(csv: &str) -> crate::SklResult<Self> {
365 let lines: Vec<&str> = csv.lines().collect();
366 if lines.is_empty() {
367 return Err(SklearsError::InvalidInput("Empty CSV content".to_string()));
368 }
369
370 let mut feature_importance = Vec::new();
371 let mut feature_names = Vec::new();
372
373 for line in lines.iter().skip(1) {
375 let parts: Vec<&str> = line.split(',').collect();
376 if parts.len() >= 3 {
377 let importance: Float = parts[2].parse().map_err(|_| {
378 SklearsError::InvalidInput("Invalid importance value in CSV".to_string())
379 })?;
380 feature_importance.push(importance);
381 feature_names.push(parts[1].to_string());
382 }
383 }
384
385 Ok(Self::new(
386 "csv_import".to_string(),
387 "unknown".to_string(),
388 Array1::from_vec(feature_importance),
389 Some(feature_names),
390 ))
391 }
392
393 pub fn get_summary(&self) -> SerializationSummary {
395 let importance_array = self.get_feature_importance();
396
397 SerializationSummary {
398 method: self.method.clone(),
399 num_features: self.feature_importance.len(),
400 timestamp: self.timestamp,
401 max_importance: importance_array
402 .iter()
403 .cloned()
404 .fold(Float::NEG_INFINITY, Float::max),
405 min_importance: importance_array
406 .iter()
407 .cloned()
408 .fold(Float::INFINITY, Float::min),
409 mean_importance: importance_array.mean().unwrap_or(0.0),
410 std_importance: importance_array.std(0.0),
411 has_shap_values: self.shap_values.is_some(),
412 has_confidence_intervals: self.confidence_intervals.is_some(),
413 }
414 }
415}
416
/// Compact statistics derived from one explanation result
/// (see `SerializableExplanationResult::get_summary`).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializationSummary {
    /// Explanation method name.
    pub method: String,
    /// Number of features in the importance vector.
    pub num_features: usize,
    /// Timestamp copied from the source result.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Largest importance score (-inf when there are no features).
    pub max_importance: Float,
    /// Smallest importance score (+inf when there are no features).
    pub min_importance: Float,
    /// Mean importance (0.0 when there are no features).
    pub mean_importance: Float,
    /// Population standard deviation of the importances.
    pub std_importance: Float,
    /// Whether the source result carried SHAP values.
    pub has_shap_values: bool,
    /// Whether the source result carried confidence intervals.
    pub has_confidence_intervals: bool,
}
439
440impl Default for ModelMetadata {
441 fn default() -> Self {
442 Self {
443 model_type: "unknown".to_string(),
444 version: "1.0.0".to_string(),
445 training_accuracy: None,
446 validation_accuracy: None,
447 num_parameters: None,
448 additional_info: HashMap::new(),
449 }
450 }
451}
452
453impl Default for ExplanationConfiguration {
454 fn default() -> Self {
455 Self {
456 method_name: "unknown".to_string(),
457 parameters: HashMap::new(),
458 random_seed: None,
459 computation_time_ms: None,
460 num_samples_used: None,
461 }
462 }
463}
464
/// A collection of explanation results plus batch-level metadata, meant to
/// be saved to / loaded from a directory of per-result files.
pub struct ExplanationBatch {
    /// The individual explanation results in this batch.
    pub results: Vec<SerializableExplanationResult>,
    /// Batch-level key/value annotations.
    pub metadata: HashMap<String, String>,
    /// UTC timestamp set when the batch object was created (not persisted
    /// to disk by `save_to_directory`).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
474
// `default()` simply delegates to `new()` so the two stay in sync.
impl Default for ExplanationBatch {
    fn default() -> Self {
        Self::new()
    }
}
480
481impl ExplanationBatch {
482 pub fn new() -> Self {
484 Self {
485 results: Vec::new(),
486 metadata: HashMap::new(),
487 created_at: chrono::Utc::now(),
488 }
489 }
490
491 pub fn add_result(&mut self, result: SerializableExplanationResult) {
493 self.results.push(result);
494 }
495
496 pub fn add_metadata(&mut self, key: String, value: String) {
498 self.metadata.insert(key, value);
499 }
500
501 pub fn save_to_directory<P: AsRef<Path>>(
503 &self,
504 directory: P,
505 config: &SerializationConfig,
506 ) -> crate::SklResult<()> {
507 let dir_path = directory.as_ref();
508 fs::create_dir_all(dir_path).map_err(|e| {
509 SklearsError::InvalidInput(format!("Failed to create directory: {}", e))
510 })?;
511
512 for (idx, result) in self.results.iter().enumerate() {
514 let filename = format!("explanation_{:03}.json", idx);
515 let file_path = dir_path.join(filename);
516 result.save_to_file(file_path, config)?;
517 }
518
519 let metadata_path = dir_path.join("batch_metadata.json");
521 let metadata_json = serde_json::to_string_pretty(&self.metadata).map_err(|e| {
522 SklearsError::InvalidInput(format!("Failed to serialize metadata: {}", e))
523 })?;
524 fs::write(metadata_path, metadata_json)
525 .map_err(|e| SklearsError::InvalidInput(format!("Failed to write metadata: {}", e)))?;
526
527 Ok(())
528 }
529
530 pub fn load_from_directory<P: AsRef<Path>>(
532 directory: P,
533 config: &SerializationConfig,
534 ) -> crate::SklResult<Self> {
535 let dir_path = directory.as_ref();
536
537 let mut batch = ExplanationBatch::new();
538
539 let entries = fs::read_dir(dir_path)
541 .map_err(|e| SklearsError::InvalidInput(format!("Failed to read directory: {}", e)))?;
542
543 for entry in entries {
544 let entry = entry.map_err(|e| {
545 SklearsError::InvalidInput(format!("Failed to read directory entry: {}", e))
546 })?;
547
548 let path = entry.path();
549 if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
550 if filename.starts_with("explanation_") && filename.ends_with(".json") {
551 let result = SerializableExplanationResult::load_from_file(&path, config)?;
552 batch.add_result(result);
553 }
554 }
555 }
556
557 let metadata_path = dir_path.join("batch_metadata.json");
559 if metadata_path.exists() {
560 let metadata_content = fs::read_to_string(metadata_path).map_err(|e| {
561 SklearsError::InvalidInput(format!("Failed to read metadata: {}", e))
562 })?;
563
564 let metadata: HashMap<String, String> = serde_json::from_str(&metadata_content)
565 .map_err(|e| {
566 SklearsError::InvalidInput(format!("Failed to parse metadata: {}", e))
567 })?;
568
569 batch.metadata = metadata;
570 }
571
572 Ok(batch)
573 }
574
575 pub fn get_batch_summary(&self) -> BatchSummary {
577 let summaries: Vec<SerializationSummary> = self
578 .results
579 .iter()
580 .map(|result| result.get_summary())
581 .collect();
582
583 let methods: Vec<String> = summaries.iter().map(|s| s.method.clone()).collect();
584 let unique_methods: std::collections::HashSet<String> = methods.iter().cloned().collect();
585
586 BatchSummary {
587 num_results: self.results.len(),
588 methods_used: unique_methods.into_iter().collect(),
589 created_at: self.created_at,
590 total_features: summaries.iter().map(|s| s.num_features).sum(),
591 has_metadata: !self.metadata.is_empty(),
592 }
593 }
594}
595
/// Aggregate statistics over an [`ExplanationBatch`]
/// (see `ExplanationBatch::get_batch_summary`).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct BatchSummary {
    /// Number of results in the batch.
    pub num_results: usize,
    /// Distinct explanation methods present (order unspecified).
    pub methods_used: Vec<String>,
    /// When the batch object was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Sum of feature counts over all results.
    pub total_features: usize,
    /// Whether the batch carries any metadata entries.
    pub has_metadata: bool,
}
610
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use tempfile::tempdir;

    // Constructor stores the mandatory fields verbatim.
    #[test]
    fn test_serializable_explanation_result_creation() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let feature_names = vec![
            "feature1".to_string(),
            "feature2".to_string(),
            "feature3".to_string(),
        ];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance.clone(),
            Some(feature_names.clone()),
        );

        assert_eq!(result.id, "test_id");
        assert_eq!(result.method, "permutation");
        assert_eq!(result.feature_importance, vec![0.5, 0.3, 0.2]);
        assert_eq!(result.feature_names, Some(feature_names));
    }

    // Array2 -> Vec<Vec<_>> -> Array2 SHAP round trip preserves the matrix.
    #[test]
    fn test_shap_values_conversion() {
        let feature_importance = array![0.5, 0.3];
        let shap_values = array![[0.1, 0.2], [0.3, 0.4]];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "shap".to_string(),
            feature_importance,
            None,
        )
        .with_shap_values(shap_values.clone());

        let recovered_shap = result.get_shap_values().unwrap();
        assert_eq!(recovered_shap, shap_values);
    }

    // JSON round trip preserves id, method, and importances.
    #[test]
    fn test_json_serialization() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let json = result.to_json().unwrap();
        let recovered = SerializableExplanationResult::from_json(&json).unwrap();

        assert_eq!(result.id, recovered.id);
        assert_eq!(result.method, recovered.method);
        assert_eq!(result.feature_importance, recovered.feature_importance);
    }

    // CSV round trip preserves importances up to formatting precision.
    #[test]
    fn test_csv_serialization() {
        let feature_importance = array![0.5, 0.3, 0.2];
        let feature_names = vec!["f1".to_string(), "f2".to_string(), "f3".to_string()];

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            Some(feature_names),
        );

        let csv = result.to_csv().unwrap();
        let recovered = SerializableExplanationResult::from_csv(&csv).unwrap();

        assert_eq!(
            result.feature_importance.len(),
            recovered.feature_importance.len()
        );
        for (a, b) in result
            .feature_importance
            .iter()
            .zip(recovered.feature_importance.iter())
        {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-6);
        }
    }

    // Save/load with default config (JSON, no compression) round-trips.
    #[test]
    fn test_file_save_and_load() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test_explanation.json");

        let feature_importance = array![0.5, 0.3, 0.2];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let config = SerializationConfig::default();
        result.save_to_file(&file_path, &config).unwrap();

        let loaded = SerializableExplanationResult::load_from_file(&file_path, &config).unwrap();
        assert_eq!(result.id, loaded.id);
        assert_eq!(result.method, loaded.method);
        assert_eq!(result.feature_importance, loaded.feature_importance);
    }

    // Summary statistics reflect the importance vector and optional fields.
    #[test]
    fn test_explanation_summary() {
        let feature_importance = array![0.5, 0.3, 0.8, 0.1];
        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        );

        let summary = result.get_summary();
        assert_eq!(summary.method, "permutation");
        assert_eq!(summary.num_features, 4);
        assert_eq!(summary.max_importance, 0.8);
        assert_eq!(summary.min_importance, 0.1);
        assert!(!summary.has_shap_values);
        assert!(!summary.has_confidence_intervals);
    }

    // Batch aggregation: counts, deduplicated methods, metadata flag.
    #[test]
    fn test_explanation_batch() {
        let mut batch = ExplanationBatch::new();

        let result1 = SerializableExplanationResult::new(
            "test_1".to_string(),
            "permutation".to_string(),
            array![0.5, 0.3],
            None,
        );

        let result2 = SerializableExplanationResult::new(
            "test_2".to_string(),
            "shap".to_string(),
            array![0.2, 0.8],
            None,
        );

        batch.add_result(result1);
        batch.add_result(result2);
        batch.add_metadata("experiment".to_string(), "test_run".to_string());

        let summary = batch.get_batch_summary();
        assert_eq!(summary.num_results, 2);
        assert!(summary.methods_used.contains(&"permutation".to_string()));
        assert!(summary.methods_used.contains(&"shap".to_string()));
        assert!(summary.has_metadata);
    }

    // Directory round trip: result files plus batch metadata are restored.
    #[test]
    fn test_batch_save_and_load() {
        let temp_dir = tempdir().unwrap();

        let mut batch = ExplanationBatch::new();
        batch.add_result(SerializableExplanationResult::new(
            "test_1".to_string(),
            "permutation".to_string(),
            array![0.5, 0.3],
            None,
        ));
        batch.add_metadata("experiment".to_string(), "test_batch".to_string());

        let config = SerializationConfig::default();
        batch.save_to_directory(temp_dir.path(), &config).unwrap();

        let loaded_batch = ExplanationBatch::load_from_directory(temp_dir.path(), &config).unwrap();
        assert_eq!(loaded_batch.results.len(), 1);
        assert_eq!(
            loaded_batch.metadata.get("experiment"),
            Some(&"test_batch".to_string())
        );
    }

    // Default config: JSON, uncompressed, summary-only, precision 6.
    #[test]
    fn test_serialization_config_default() {
        let config = SerializationConfig::default();
        assert_eq!(config.format, SerializationFormat::Json);
        assert_eq!(config.compression, CompressionType::None);
        assert!(!config.include_raw_data);
        assert_eq!(config.float_precision, 6);
    }

    // ModelMetadata default values.
    #[test]
    fn test_model_metadata_default() {
        let metadata = ModelMetadata::default();
        assert_eq!(metadata.model_type, "unknown");
        assert_eq!(metadata.version, "1.0.0");
        assert!(metadata.training_accuracy.is_none());
    }

    // DatasetMetadata derives Default: zeroed counts, empty options.
    #[test]
    fn test_dataset_metadata_default() {
        let metadata = DatasetMetadata::default();
        assert_eq!(metadata.num_samples, 0);
        assert_eq!(metadata.num_features, 0);
        assert!(metadata.name.is_none());
    }

    // Builder methods attach model info and metadata entries.
    #[test]
    fn test_with_methods() {
        let feature_importance = array![0.5, 0.3];
        let model_info = ModelMetadata {
            model_type: "linear_regression".to_string(),
            ..Default::default()
        };

        let result = SerializableExplanationResult::new(
            "test_id".to_string(),
            "permutation".to_string(),
            feature_importance,
            None,
        )
        .with_model_info(model_info.clone())
        .with_metadata("key1".to_string(), "value1".to_string());

        assert_eq!(result.model_info.model_type, "linear_regression");
        assert_eq!(result.metadata.get("key1"), Some(&"value1".to_string()));
    }
}