sklears_datasets/
versioning.rs

//! Dataset versioning and provenance tracking
//!
//! This module provides functionality for tracking dataset versions, lineage,
//! and provenance information to ensure reproducibility and auditability.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use std::path::Path;
11use thiserror::Error;
12
13/// Error types for versioning operations
14#[derive(Debug, Error)]
15pub enum VersioningError {
16    #[error("IO error: {0}")]
17    Io(#[from] std::io::Error),
18    #[error("Serialization error: {0}")]
19    Serialization(#[from] serde_json::Error),
20    #[error("Version not found: {0}")]
21    VersionNotFound(String),
22    #[error("Invalid version format: {0}")]
23    InvalidVersion(String),
24    #[error("Checksum mismatch: expected {expected}, got {actual}")]
25    ChecksumMismatch { expected: String, actual: String },
26}
27
28pub type VersioningResult<T> = Result<T, VersioningError>;
29
30/// Semantic version for datasets
31#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
32pub struct DatasetVersion {
33    pub major: u32,
34    pub minor: u32,
35    pub patch: u32,
36    pub prerelease: Option<String>,
37}
38
39impl DatasetVersion {
40    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
41        Self {
42            major,
43            minor,
44            patch,
45            prerelease: None,
46        }
47    }
48
49    pub fn with_prerelease(mut self, prerelease: String) -> Self {
50        self.prerelease = Some(prerelease);
51        self
52    }
53
54    pub fn from_string(s: &str) -> VersioningResult<Self> {
55        let parts: Vec<&str> = s.split('-').collect();
56        let version_parts: Vec<&str> = parts[0].split('.').collect();
57
58        if version_parts.len() != 3 {
59            return Err(VersioningError::InvalidVersion(s.to_string()));
60        }
61
62        let major = version_parts[0]
63            .parse()
64            .map_err(|_| VersioningError::InvalidVersion(s.to_string()))?;
65        let minor = version_parts[1]
66            .parse()
67            .map_err(|_| VersioningError::InvalidVersion(s.to_string()))?;
68        let patch = version_parts[2]
69            .parse()
70            .map_err(|_| VersioningError::InvalidVersion(s.to_string()))?;
71
72        let prerelease = if parts.len() > 1 {
73            Some(parts[1].to_string())
74        } else {
75            None
76        };
77
78        Ok(Self {
79            major,
80            minor,
81            patch,
82            prerelease,
83        })
84    }
85}
86
87impl fmt::Display for DatasetVersion {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        if let Some(ref pre) = self.prerelease {
90            write!(f, "{}.{}.{}-{}", self.major, self.minor, self.patch, pre)
91        } else {
92            write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
93        }
94    }
95}
96
97/// Provenance information tracking dataset lineage
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ProvenanceInfo {
100    /// Unique identifier for this dataset
101    pub dataset_id: String,
102    /// Version of the dataset
103    pub version: DatasetVersion,
104    /// Creation timestamp
105    pub created_at: DateTime<Utc>,
106    /// Last modification timestamp
107    pub modified_at: DateTime<Utc>,
108    /// Creator/author information
109    pub creator: String,
110    /// Description of the dataset
111    pub description: String,
112    /// Source datasets (parent datasets this was derived from)
113    pub sources: Vec<String>,
114    /// Transformation operations applied
115    pub transformations: Vec<TransformationStep>,
116    /// Checksums for data integrity
117    pub checksums: HashMap<String, String>,
118    /// Custom metadata
119    pub metadata: HashMap<String, String>,
120}
121
122impl ProvenanceInfo {
123    pub fn new(dataset_id: String, version: DatasetVersion, creator: String) -> Self {
124        let now = Utc::now();
125        Self {
126            dataset_id,
127            version,
128            created_at: now,
129            modified_at: now,
130            creator,
131            description: String::new(),
132            sources: Vec::new(),
133            transformations: Vec::new(),
134            checksums: HashMap::new(),
135            metadata: HashMap::new(),
136        }
137    }
138
139    pub fn with_description(mut self, description: String) -> Self {
140        self.description = description;
141        self
142    }
143
144    pub fn add_source(&mut self, source_id: String) {
145        self.sources.push(source_id);
146        self.modified_at = Utc::now();
147    }
148
149    pub fn add_transformation(&mut self, transformation: TransformationStep) {
150        self.transformations.push(transformation);
151        self.modified_at = Utc::now();
152    }
153
154    pub fn add_checksum(&mut self, name: String, checksum: String) {
155        self.checksums.insert(name, checksum);
156        self.modified_at = Utc::now();
157    }
158
159    pub fn add_metadata(&mut self, key: String, value: String) {
160        self.metadata.insert(key, value);
161        self.modified_at = Utc::now();
162    }
163
164    /// Save provenance information to a JSON file
165    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> VersioningResult<()> {
166        let json = serde_json::to_string_pretty(self)?;
167        std::fs::write(path, json)?;
168        Ok(())
169    }
170
171    /// Load provenance information from a JSON file
172    pub fn load_from_file<P: AsRef<Path>>(path: P) -> VersioningResult<Self> {
173        let json = std::fs::read_to_string(path)?;
174        let provenance: ProvenanceInfo = serde_json::from_str(&json)?;
175        Ok(provenance)
176    }
177}
178
179/// A single transformation step in the dataset lineage
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct TransformationStep {
182    /// Timestamp of the transformation
183    pub timestamp: DateTime<Utc>,
184    /// Type of transformation (e.g., "normalization", "feature_selection", "sampling")
185    pub transformation_type: String,
186    /// Description of the transformation
187    pub description: String,
188    /// Parameters used in the transformation
189    pub parameters: HashMap<String, String>,
190    /// User/system that performed the transformation
191    pub performed_by: String,
192}
193
194impl TransformationStep {
195    pub fn new(transformation_type: String, description: String, performed_by: String) -> Self {
196        Self {
197            timestamp: Utc::now(),
198            transformation_type,
199            description,
200            parameters: HashMap::new(),
201            performed_by,
202        }
203    }
204
205    pub fn with_parameter(mut self, key: String, value: String) -> Self {
206        self.parameters.insert(key, value);
207        self
208    }
209}
210
211/// Dataset version registry for managing multiple versions
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct VersionRegistry {
214    /// Name of the dataset
215    pub dataset_name: String,
216    /// All registered versions
217    pub versions: HashMap<String, ProvenanceInfo>,
218    /// Current/latest version
219    pub current_version: Option<String>,
220}
221
222impl VersionRegistry {
223    pub fn new(dataset_name: String) -> Self {
224        Self {
225            dataset_name,
226            versions: HashMap::new(),
227            current_version: None,
228        }
229    }
230
231    pub fn register_version(&mut self, provenance: ProvenanceInfo) {
232        let version_str = provenance.version.to_string();
233        self.versions.insert(version_str.clone(), provenance);
234        self.current_version = Some(version_str);
235    }
236
237    pub fn get_version(&self, version: &str) -> Option<&ProvenanceInfo> {
238        self.versions.get(version)
239    }
240
241    pub fn get_current(&self) -> Option<&ProvenanceInfo> {
242        self.current_version
243            .as_ref()
244            .and_then(|v| self.versions.get(v))
245    }
246
247    pub fn set_current(&mut self, version: &str) -> VersioningResult<()> {
248        if !self.versions.contains_key(version) {
249            return Err(VersioningError::VersionNotFound(version.to_string()));
250        }
251        self.current_version = Some(version.to_string());
252        Ok(())
253    }
254
255    /// List all versions in chronological order
256    pub fn list_versions(&self) -> Vec<&ProvenanceInfo> {
257        let mut versions: Vec<&ProvenanceInfo> = self.versions.values().collect();
258        versions.sort_by_key(|p| &p.created_at);
259        versions
260    }
261
262    /// Save version registry to a JSON file
263    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> VersioningResult<()> {
264        let json = serde_json::to_string_pretty(self)?;
265        std::fs::write(path, json)?;
266        Ok(())
267    }
268
269    /// Load version registry from a JSON file
270    pub fn load_from_file<P: AsRef<Path>>(path: P) -> VersioningResult<Self> {
271        let json = std::fs::read_to_string(path)?;
272        let registry: VersionRegistry = serde_json::from_str(&json)?;
273        Ok(registry)
274    }
275}
276
/// Calculate SHA-256 checksum for data verification
///
/// Returns the digest as 64 lowercase hexadecimal characters. The digest is a
/// pure function of the input bytes, so values persisted in provenance files
/// stay verifiable across compiler versions and platforms.
///
/// The implementation follows FIPS 180-4 and uses only the standard library.
/// Whole 64-byte blocks are hashed straight from `data`; only the final padded
/// tail is copied.
pub fn calculate_checksum(data: &[u8]) -> String {
    /// Folds one 64-byte block into the running hash state.
    fn compress(state: &mut [u32; 8], block: &[u8]) {
        // Round constants: fractional parts of the cube roots of the first 64 primes.
        const K: [u32; 64] = [
            0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
            0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
            0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
            0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
            0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
            0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
            0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
            0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
            0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
            0xc67178f2,
        ];

        // Message schedule: 16 big-endian words from the block, then 48 derived words.
        let mut w = [0u32; 64];
        for (word, bytes) in w.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..64 {
            let s0 = w[t - 15].rotate_right(7) ^ w[t - 15].rotate_right(18) ^ (w[t - 15] >> 3);
            let s1 = w[t - 2].rotate_right(17) ^ w[t - 2].rotate_right(19) ^ (w[t - 2] >> 10);
            w[t] = w[t - 16]
                .wrapping_add(s0)
                .wrapping_add(w[t - 7])
                .wrapping_add(s1);
        }

        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state;
        for t in 0..64 {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choice = (e & f) ^ (!e & g);
            let temp1 = h
                .wrapping_add(big_s1)
                .wrapping_add(choice)
                .wrapping_add(K[t])
                .wrapping_add(w[t]);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_s0.wrapping_add(majority);

            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(temp1);
            d = c;
            c = b;
            b = a;
            a = temp1.wrapping_add(temp2);
        }

        let working = [a, b, c, d, e, f, g, h];
        for (slot, value) in state.iter_mut().zip(working.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    // Initial state: fractional parts of the square roots of the first 8 primes.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    let mut blocks = data.chunks_exact(64);
    for block in &mut blocks {
        compress(&mut state, block);
    }

    // Padding: a single 0x80 byte, zeros, then the message length in bits as a
    // big-endian u64 occupying the last 8 bytes of the final block. When fewer
    // than 9 bytes are free after the leftover data, a second block is needed.
    let leftover = blocks.remainder();
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut tail = [0u8; 128];
    tail[..leftover.len()].copy_from_slice(leftover);
    tail[leftover.len()] = 0x80;
    let tail_len = if leftover.len() < 56 { 64 } else { 128 };
    tail[tail_len - 8..tail_len].copy_from_slice(&bit_len.to_be_bytes());
    for block in tail[..tail_len].chunks_exact(64) {
        compress(&mut state, block);
    }

    // Zero-padded so the digest is always exactly 64 hex characters.
    state.iter().map(|word| format!("{:08x}", word)).collect()
}
286
287/// Verify data integrity against a checksum
288pub fn verify_checksum(data: &[u8], expected: &str) -> VersioningResult<()> {
289    let actual = calculate_checksum(data);
290    if actual != expected {
291        Err(VersioningError::ChecksumMismatch {
292            expected: expected.to_string(),
293            actual,
294        })
295    } else {
296        Ok(())
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Display renders plain and pre-release versions.
    #[test]
    fn test_dataset_version() {
        let version = DatasetVersion::new(1, 2, 3);
        assert_eq!(version.to_string(), "1.2.3");

        let version_pre = DatasetVersion::new(1, 2, 3).with_prerelease("alpha".to_string());
        assert_eq!(version_pre.to_string(), "1.2.3-alpha");
    }

    /// Parsing accepts well-formed strings and rejects malformed ones.
    #[test]
    fn test_version_parsing() {
        let version = DatasetVersion::from_string("1.2.3").unwrap();
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
        assert_eq!(version.prerelease, None);

        let version_pre = DatasetVersion::from_string("1.2.3-beta").unwrap();
        assert_eq!(version_pre.prerelease, Some("beta".to_string()));

        assert!(matches!(
            DatasetVersion::from_string("1.2"),
            Err(VersioningError::InvalidVersion(_))
        ));
        assert!(matches!(
            DatasetVersion::from_string("1.x.3"),
            Err(VersioningError::InvalidVersion(_))
        ));
    }

    /// The `add_*` methods store what they are given.
    #[test]
    fn test_provenance_info() {
        let mut provenance = ProvenanceInfo::new(
            "test-dataset".to_string(),
            DatasetVersion::new(1, 0, 0),
            "test-user".to_string(),
        );

        provenance.add_source("source-dataset-1".to_string());
        provenance.add_metadata("key1".to_string(), "value1".to_string());
        provenance.add_checksum("data".to_string(), "abc".to_string());

        assert_eq!(provenance.sources.len(), 1);
        assert_eq!(provenance.metadata.len(), 1);
        assert_eq!(
            provenance.checksums.get("data").map(String::as_str),
            Some("abc")
        );
    }

    /// Chained `with_parameter` calls accumulate parameters.
    #[test]
    fn test_transformation_step() {
        let step = TransformationStep::new(
            "normalization".to_string(),
            "StandardScaler normalization".to_string(),
            "test-user".to_string(),
        )
        .with_parameter("mean".to_string(), "0.0".to_string())
        .with_parameter("std".to_string(), "1.0".to_string());

        assert_eq!(step.parameters.len(), 2);
    }

    /// Registration, lookup and current-version selection.
    #[test]
    fn test_version_registry() {
        let mut registry = VersionRegistry::new("test-dataset".to_string());

        let prov1 = ProvenanceInfo::new(
            "test-dataset".to_string(),
            DatasetVersion::new(1, 0, 0),
            "user1".to_string(),
        );
        registry.register_version(prov1);

        let prov2 = ProvenanceInfo::new(
            "test-dataset".to_string(),
            DatasetVersion::new(1, 1, 0),
            "user1".to_string(),
        );
        registry.register_version(prov2);

        assert_eq!(registry.versions.len(), 2);
        assert!(registry.get_version("1.0.0").is_some());
        assert!(registry.get_version("1.1.0").is_some());
        assert_eq!(registry.list_versions().len(), 2);

        // The most recently registered version becomes current.
        assert_eq!(registry.current_version.as_deref(), Some("1.1.0"));

        // Unknown versions are rejected and leave the selection alone.
        assert!(matches!(
            registry.set_current("9.9.9"),
            Err(VersioningError::VersionNotFound(_))
        ));
        assert_eq!(registry.current_version.as_deref(), Some("1.1.0"));

        registry.set_current("1.0.0").unwrap();
        assert_eq!(
            registry.get_current().map(|p| p.version.to_string()),
            Some("1.0.0".to_string())
        );
    }

    /// A checksum verifies against its own data and nothing else.
    #[test]
    fn test_checksum() {
        let data = b"test data";
        let checksum = calculate_checksum(data);
        assert_eq!(calculate_checksum(data), checksum);
        assert!(verify_checksum(data, &checksum).is_ok());

        assert!(matches!(
            verify_checksum(data, "invalid"),
            Err(VersioningError::ChecksumMismatch { .. })
        ));
    }

    /// A record survives a save/load round trip through a JSON file.
    #[test]
    fn test_provenance_serialization() {
        use std::env::temp_dir;

        let provenance = ProvenanceInfo::new(
            "test-dataset".to_string(),
            DatasetVersion::new(1, 0, 0),
            "test-user".to_string(),
        )
        .with_description("Test dataset".to_string());

        // Include the process id so concurrent test runs do not share a file.
        let temp_path = temp_dir().join(format!("test_provenance_{}.json", std::process::id()));
        let saved = provenance.save_to_file(&temp_path);
        let loaded = ProvenanceInfo::load_from_file(&temp_path);

        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&temp_path).ok();

        saved.unwrap();
        let loaded = loaded.unwrap();
        assert_eq!(loaded.dataset_id, "test-dataset");
        assert_eq!(loaded.version.major, 1);
        assert_eq!(loaded.description, "Test dataset");
    }
}