1use crate::{Dataset, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::time::{SystemTime, UNIX_EPOCH};
11use tenflowers_core::{Tensor, TensorError};
12
/// Identifier of a dataset version (currently a `v_`-prefixed hex UUID string).
pub type VersionId = String;
15
/// Immutable descriptive record stored alongside each dataset snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionMetadata {
    /// Unique identifier of this version.
    pub version_id: VersionId,
    /// Version this snapshot was derived from, if any.
    pub parent_version: Option<VersionId>,
    /// Creation time, seconds since the UNIX epoch.
    pub timestamp: u64,
    /// Human-readable description supplied at snapshot time.
    pub description: String,
    /// Free-form tags used for filtering (see `get_versions_by_tag`).
    pub tags: Vec<String>,
    /// Arbitrary user-supplied key/value metadata.
    pub custom_metadata: HashMap<String, String>,
    /// Hex checksum derived from the dataset length and first-sample shapes.
    pub checksum: String,
    /// Size/shape summary of the snapshotted dataset.
    pub size_info: DatasetSizeInfo,
}
36
/// Shape and size summary captured when a snapshot is created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSizeInfo {
    /// Number of samples in the dataset.
    pub sample_count: usize,
    /// Shape of a single feature tensor (taken from sample 0).
    pub feature_shape: Vec<usize>,
    /// Shape of a single label tensor (taken from sample 0).
    pub label_shape: Vec<usize>,
    /// Estimated total size in bytes (assumes f32-sized elements).
    pub size_bytes: u64,
}
49
/// Provenance entry linking a version to its sources, children and the
/// transformations recorded against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetLineage {
    /// Metadata of the version this lineage entry describes.
    pub version: VersionMetadata,
    /// Transformations recorded for this version, in insertion order.
    pub transformations: Vec<TransformationRecord>,
    /// Versions this one was derived from (currently at most the parent).
    pub source_versions: Vec<VersionId>,
    /// Versions created with this one as their parent.
    pub child_versions: Vec<VersionId>,
}
62
/// A single recorded transformation step in a dataset's lineage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformationRecord {
    /// Kind of transformation (e.g. "scale").
    pub transform_type: String,
    /// Transformation parameters as string key/value pairs.
    pub parameters: HashMap<String, String>,
    /// Time the transformation was recorded, seconds since the UNIX epoch.
    pub timestamp: u64,
    /// Human-readable description of the step.
    pub description: String,
}
75
/// Manages dataset snapshots on disk: creation, loading, tagging and
/// lineage tracking. State is persisted under `base_path`.
#[derive(Debug)]
pub struct DatasetVersionManager {
    /// Root directory; each version gets a subdirectory named by its id.
    base_path: PathBuf,
    /// In-memory lineage graph, mirrored to `lineage.json` on disk.
    lineage_graph: HashMap<VersionId, DatasetLineage>,
    /// Most recently created version, if any snapshot has been taken.
    current_version: Option<VersionId>,
}
86
87impl DatasetVersionManager {
88 pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
90 let base_path = base_path.as_ref().to_path_buf();
91
92 if !base_path.exists() {
94 std::fs::create_dir_all(&base_path).map_err(|e| {
95 TensorError::invalid_argument(format!("Failed to create version directory: {e}"))
96 })?;
97 }
98
99 let mut manager = Self {
100 base_path,
101 lineage_graph: HashMap::new(),
102 current_version: None,
103 };
104
105 manager.load_lineage_graph()?;
107
108 Ok(manager)
109 }
110
111 pub fn create_snapshot<T>(
113 &mut self,
114 dataset: &dyn Dataset<T>,
115 description: String,
116 tags: Vec<String>,
117 parent_version: Option<VersionId>,
118 ) -> Result<VersionId>
119 where
120 T: Clone + Default + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
121 {
122 let version_id = self.generate_version_id();
123 let timestamp = SystemTime::now()
124 .duration_since(UNIX_EPOCH)
125 .expect("system time before UNIX_EPOCH")
126 .as_secs();
127
128 let size_info = self.calculate_size_info(dataset)?;
130
131 let checksum = self.calculate_checksum(dataset)?;
133
134 let metadata = VersionMetadata {
136 version_id: version_id.clone(),
137 parent_version: parent_version.clone(),
138 timestamp,
139 description,
140 tags,
141 custom_metadata: HashMap::new(),
142 checksum,
143 size_info,
144 };
145
146 let version_dir = self.base_path.join(&version_id);
148 std::fs::create_dir_all(&version_dir).map_err(|e| {
149 TensorError::invalid_argument(format!("Failed to create version directory: {e}"))
150 })?;
151
152 self.save_dataset_samples(dataset, &version_dir)?;
154
155 self.save_metadata(&metadata, &version_dir)?;
157
158 let lineage = DatasetLineage {
160 version: metadata,
161 transformations: Vec::new(),
162 source_versions: if let Some(parent) = &parent_version {
163 vec![parent.clone()]
164 } else {
165 Vec::new()
166 },
167 child_versions: Vec::new(),
168 };
169
170 self.lineage_graph.insert(version_id.clone(), lineage);
171
172 if let Some(parent) = &parent_version {
174 if let Some(parent_lineage) = self.lineage_graph.get_mut(parent) {
175 parent_lineage.child_versions.push(version_id.clone());
176 }
177 }
178
179 self.current_version = Some(version_id.clone());
180 self.save_lineage_graph()?;
181
182 Ok(version_id)
183 }
184
185 pub fn load_snapshot<T>(&self, version_id: &str) -> Result<VersionedDataset<T>>
187 where
188 T: Clone + Default + serde::de::DeserializeOwned + Send + Sync + 'static,
189 {
190 let version_dir = self.base_path.join(version_id);
191
192 if !version_dir.exists() {
193 return Err(TensorError::invalid_argument(format!(
194 "Version {version_id} not found"
195 )));
196 }
197
198 let metadata = self.load_metadata(&version_dir)?;
200
201 let samples = self.load_dataset_samples(&version_dir)?;
203
204 Ok(VersionedDataset { metadata, samples })
205 }
206
207 pub fn get_lineage(&self, version_id: &str) -> Option<&DatasetLineage> {
209 self.lineage_graph.get(version_id)
210 }
211
212 pub fn add_transformation(
214 &mut self,
215 version_id: &str,
216 transform_type: String,
217 parameters: HashMap<String, String>,
218 description: String,
219 ) -> Result<()> {
220 let timestamp = SystemTime::now()
221 .duration_since(UNIX_EPOCH)
222 .expect("system time before UNIX_EPOCH")
223 .as_secs();
224
225 let transformation = TransformationRecord {
226 transform_type,
227 parameters,
228 timestamp,
229 description,
230 };
231
232 if let Some(lineage) = self.lineage_graph.get_mut(version_id) {
233 lineage.transformations.push(transformation);
234 self.save_lineage_graph()?;
235 } else {
236 return Err(TensorError::invalid_argument(format!(
237 "Version {version_id} not found"
238 )));
239 }
240
241 Ok(())
242 }
243
244 pub fn list_versions(&self) -> Vec<&VersionMetadata> {
246 self.lineage_graph
247 .values()
248 .map(|lineage| &lineage.version)
249 .collect()
250 }
251
252 pub fn get_versions_by_tag(&self, tag: &str) -> Vec<&VersionMetadata> {
254 self.lineage_graph
255 .values()
256 .filter(|lineage| lineage.version.tags.contains(&tag.to_string()))
257 .map(|lineage| &lineage.version)
258 .collect()
259 }
260
261 pub fn get_lineage_tree(&self, version_id: &str) -> Option<LineageTree> {
263 self.lineage_graph
264 .get(version_id)
265 .map(|lineage| self.build_lineage_tree(&lineage.version))
266 }
267
268 fn build_lineage_tree(&self, version: &VersionMetadata) -> LineageTree {
269 let children = version.version_id.clone();
270 let child_trees = if let Some(lineage) = self.lineage_graph.get(&children) {
271 lineage
272 .child_versions
273 .iter()
274 .filter_map(|child_id| {
275 self.lineage_graph
276 .get(child_id)
277 .map(|child_lineage| self.build_lineage_tree(&child_lineage.version))
278 })
279 .collect()
280 } else {
281 Vec::new()
282 };
283
284 LineageTree {
285 version: version.clone(),
286 children: child_trees,
287 }
288 }
289
290 fn generate_version_id(&self) -> VersionId {
291 format!("v_{}", uuid::Uuid::new_v4().to_string().replace('-', ""))
292 }
293
294 fn calculate_size_info<T>(&self, dataset: &dyn Dataset<T>) -> Result<DatasetSizeInfo>
295 where
296 T: Clone + Default + Send + Sync + 'static,
297 {
298 let sample_count = dataset.len();
299
300 if sample_count == 0 {
301 return Ok(DatasetSizeInfo {
302 sample_count: 0,
303 feature_shape: vec![0],
304 label_shape: vec![0],
305 size_bytes: 0,
306 });
307 }
308
309 let (features, labels) = dataset.get(0)?;
311 let feature_shape = features.shape().dims().to_vec();
312 let label_shape = labels.shape().dims().to_vec();
313
314 let feature_size = feature_shape.iter().product::<usize>();
316 let label_size = label_shape.iter().product::<usize>();
317 let estimated_bytes_per_sample = (feature_size + label_size) * std::mem::size_of::<f32>();
318 let size_bytes = (sample_count * estimated_bytes_per_sample) as u64;
319
320 Ok(DatasetSizeInfo {
321 sample_count,
322 feature_shape,
323 label_shape,
324 size_bytes,
325 })
326 }
327
328 fn calculate_checksum<T>(&self, dataset: &dyn Dataset<T>) -> Result<String>
329 where
330 T: Clone + Default + Send + Sync + 'static,
331 {
332 let len = dataset.len();
334 if len == 0 {
335 return Ok("empty_dataset".to_string());
336 }
337
338 let (first_features, first_labels) = dataset.get(0)?;
339 let mut checksum_value = 0u64;
340
341 checksum_value = checksum_value.wrapping_mul(31).wrapping_add(len as u64);
343
344 for &dim in first_features.shape().dims() {
346 checksum_value = checksum_value.wrapping_mul(31).wrapping_add(dim as u64);
347 }
348
349 for &dim in first_labels.shape().dims() {
351 checksum_value = checksum_value.wrapping_mul(31).wrapping_add(dim as u64);
352 }
353
354 let features_hash = format!("{:?}", first_features.shape().dims()).len() as u64;
356 let labels_hash = format!("{:?}", first_labels.shape().dims()).len() as u64;
357
358 checksum_value = checksum_value.wrapping_mul(31).wrapping_add(features_hash);
359 checksum_value = checksum_value.wrapping_mul(31).wrapping_add(labels_hash);
360
361 Ok(format!("{checksum_value:016x}"))
362 }
363
364 fn save_dataset_samples<T>(&self, dataset: &dyn Dataset<T>, version_dir: &Path) -> Result<()>
365 where
366 T: Clone + Default + serde::Serialize + Send + Sync + 'static,
367 {
368 let samples_file = version_dir.join("samples.json");
369 let mut samples = Vec::new();
370
371 for i in 0..dataset.len() {
372 let (features, labels) = dataset.get(i)?;
373
374 let features_data = if let Some(slice) = features.as_slice() {
376 slice.to_vec()
377 } else {
378 vec![features.get(&[]).unwrap_or(T::default())]
379 };
380
381 let labels_data = if let Some(slice) = labels.as_slice() {
382 slice.to_vec()
383 } else {
384 vec![labels.get(&[]).unwrap_or(T::default())]
385 };
386
387 samples.push(serde_json::json!({
388 "features": features_data,
389 "labels": labels_data,
390 "feature_shape": features.shape().dims(),
391 "label_shape": labels.shape().dims(),
392 }));
393 }
394
395 let json_data = serde_json::to_string_pretty(&samples).map_err(|e| {
396 TensorError::invalid_argument(format!("Failed to serialize samples: {e}"))
397 })?;
398
399 std::fs::write(samples_file, json_data).map_err(|e| {
400 TensorError::invalid_argument(format!("Failed to write samples file: {e}"))
401 })?;
402
403 Ok(())
404 }
405
406 fn load_dataset_samples<T>(&self, version_dir: &Path) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
407 where
408 T: Clone + Default + serde::de::DeserializeOwned + Send + Sync + 'static,
409 {
410 let samples_file = version_dir.join("samples.json");
411 let json_data = std::fs::read_to_string(samples_file).map_err(|e| {
412 TensorError::invalid_argument(format!("Failed to read samples file: {e}"))
413 })?;
414
415 let json_samples: Vec<serde_json::Value> =
416 serde_json::from_str(&json_data).map_err(|e| {
417 TensorError::invalid_argument(format!("Failed to parse samples JSON: {e}"))
418 })?;
419
420 let mut samples = Vec::new();
421 for sample in json_samples {
422 let features_data: Vec<T> = serde_json::from_value(sample["features"].clone())
423 .map_err(|e| {
424 TensorError::invalid_argument(format!("Failed to parse features: {e}"))
425 })?;
426
427 let labels_data: Vec<T> =
428 serde_json::from_value(sample["labels"].clone()).map_err(|e| {
429 TensorError::invalid_argument(format!("Failed to parse labels: {e}"))
430 })?;
431
432 let feature_shape: Vec<usize> = serde_json::from_value(sample["feature_shape"].clone())
433 .map_err(|e| {
434 TensorError::invalid_argument(format!("Failed to parse feature shape: {e}"))
435 })?;
436
437 let label_shape: Vec<usize> = serde_json::from_value(sample["label_shape"].clone())
438 .map_err(|e| {
439 TensorError::invalid_argument(format!("Failed to parse label shape: {e}"))
440 })?;
441
442 let features_tensor = if feature_shape.is_empty() || feature_shape == vec![0] {
443 Tensor::from_scalar(features_data.into_iter().next().unwrap_or_default())
444 } else {
445 Tensor::from_vec(features_data, &feature_shape)?
446 };
447
448 let labels_tensor = if label_shape.is_empty() || label_shape == vec![0] {
449 Tensor::from_scalar(labels_data.into_iter().next().unwrap_or_default())
450 } else {
451 Tensor::from_vec(labels_data, &label_shape)?
452 };
453
454 samples.push((features_tensor, labels_tensor));
455 }
456
457 Ok(samples)
458 }
459
460 fn save_metadata(&self, metadata: &VersionMetadata, version_dir: &Path) -> Result<()> {
461 let metadata_file = version_dir.join("metadata.json");
462 let json_data = serde_json::to_string_pretty(metadata).map_err(|e| {
463 TensorError::invalid_argument(format!("Failed to serialize metadata: {e}"))
464 })?;
465
466 std::fs::write(metadata_file, json_data).map_err(|e| {
467 TensorError::invalid_argument(format!("Failed to write metadata file: {e}"))
468 })?;
469
470 Ok(())
471 }
472
473 fn load_metadata(&self, version_dir: &Path) -> Result<VersionMetadata> {
474 let metadata_file = version_dir.join("metadata.json");
475 let json_data = std::fs::read_to_string(metadata_file).map_err(|e| {
476 TensorError::invalid_argument(format!("Failed to read metadata file: {e}"))
477 })?;
478
479 serde_json::from_str(&json_data).map_err(|e| {
480 TensorError::invalid_argument(format!("Failed to parse metadata JSON: {e}"))
481 })
482 }
483
484 fn save_lineage_graph(&self) -> Result<()> {
485 let lineage_file = self.base_path.join("lineage.json");
486 let json_data = serde_json::to_string_pretty(&self.lineage_graph).map_err(|e| {
487 TensorError::invalid_argument(format!("Failed to serialize lineage graph: {e}"))
488 })?;
489
490 std::fs::write(lineage_file, json_data).map_err(|e| {
491 TensorError::invalid_argument(format!("Failed to write lineage file: {e}"))
492 })?;
493
494 Ok(())
495 }
496
497 fn load_lineage_graph(&mut self) -> Result<()> {
498 let lineage_file = self.base_path.join("lineage.json");
499
500 if !lineage_file.exists() {
501 return Ok(()); }
503
504 let json_data = std::fs::read_to_string(lineage_file).map_err(|e| {
505 TensorError::invalid_argument(format!("Failed to read lineage file: {e}"))
506 })?;
507
508 self.lineage_graph = serde_json::from_str(&json_data).map_err(|e| {
509 TensorError::invalid_argument(format!("Failed to parse lineage JSON: {e}"))
510 })?;
511
512 Ok(())
513 }
514}
515
/// Version metadata plus the recursively built trees of its descendants.
#[derive(Debug, Clone)]
pub struct LineageTree {
    /// Metadata of the version at this node.
    pub version: VersionMetadata,
    /// Subtrees for each known child version.
    pub children: Vec<LineageTree>,
}
522
/// An in-memory dataset reconstructed from a stored snapshot, together
/// with the metadata it was saved with.
#[derive(Debug)]
pub struct VersionedDataset<T> {
    // Metadata loaded from the snapshot's metadata.json.
    metadata: VersionMetadata,
    // (features, labels) pairs reconstructed from samples.json.
    samples: Vec<(Tensor<T>, Tensor<T>)>,
}
529
530impl<T> VersionedDataset<T>
531where
532 T: Clone + Default + Send + Sync + 'static,
533{
534 pub fn metadata(&self) -> &VersionMetadata {
536 &self.metadata
537 }
538
539 pub fn version_id(&self) -> &str {
541 &self.metadata.version_id
542 }
543}
544
545impl<T> Dataset<T> for VersionedDataset<T>
546where
547 T: Clone + Default + Send + Sync + 'static,
548{
549 fn len(&self) -> usize {
550 self.samples.len()
551 }
552
553 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
554 if index >= self.samples.len() {
555 return Err(TensorError::invalid_argument(format!(
556 "Index {} out of bounds for dataset of length {}",
557 index,
558 self.samples.len()
559 )));
560 }
561
562 Ok(self.samples[index].clone())
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tempfile::TempDir;

    // A new manager on an empty directory should start with no versions.
    #[test]
    fn test_version_manager_creation() {
        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let manager =
            DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");

        assert!(temp_dir.path().exists());
        assert_eq!(manager.list_versions().len(), 0);
    }

    // Round-trip: snapshot a small dataset and load it back, checking
    // sample count, version id and the first sample's contents.
    #[test]
    fn test_create_and_load_snapshot() {
        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let mut manager =
            DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");

        // 2x2 features with 2 labels -> a dataset of 2 samples.
        let features_data = vec![1.0, 2.0, 3.0, 4.0];
        let labels_data = vec![0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let version_id = manager
            .create_snapshot(
                &dataset,
                "Test snapshot".to_string(),
                vec!["test".to_string()],
                None,
            )
            .expect("test: operation should succeed");

        assert!(!version_id.is_empty());
        assert_eq!(manager.list_versions().len(), 1);

        let loaded_dataset = manager
            .load_snapshot::<f32>(&version_id)
            .expect("test: operation should succeed");
        assert_eq!(loaded_dataset.len(), 2);
        assert_eq!(loaded_dataset.version_id(), &version_id);

        // First sample: first feature row and first (scalar) label.
        let (features, labels) = loaded_dataset.get(0).expect("index should be in bounds");
        let features_slice = features.as_slice().expect("tensor should be contiguous");
        assert_eq!(features_slice, &[1.0, 2.0]);
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
    }

    // Parent/child linking, transformation records and lineage-tree
    // construction across two versions.
    #[test]
    fn test_lineage_tracking() {
        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let mut manager =
            DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");

        let features_data1 = vec![1.0, 2.0];
        let labels_data1 = vec![0.0];
        let features1 = Tensor::from_vec(features_data1, &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels1 =
            Tensor::from_vec(labels_data1, &[1]).expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let version1 = manager
            .create_snapshot(
                &dataset1,
                "Initial version".to_string(),
                vec!["v1".to_string()],
                None,
            )
            .expect("test: operation should succeed");

        let features_data2 = vec![2.0, 4.0];
        let labels_data2 = vec![1.0];
        let features2 = Tensor::from_vec(features_data2, &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels2 =
            Tensor::from_vec(labels_data2, &[1]).expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        // Second snapshot derived from the first.
        let version2 = manager
            .create_snapshot(
                &dataset2,
                "Scaled version".to_string(),
                vec!["v2".to_string()],
                Some(version1.clone()),
            )
            .expect("test: operation should succeed");

        let mut params = HashMap::new();
        params.insert("scale_factor".to_string(), "2.0".to_string());

        manager
            .add_transformation(
                &version2,
                "scale".to_string(),
                params,
                "Scale features by 2".to_string(),
            )
            .expect("test: operation should succeed");

        // v2's lineage should reference v1 and carry the recorded transform.
        let lineage = manager
            .get_lineage(&version2)
            .expect("test: operation should succeed");
        assert_eq!(lineage.source_versions, vec![version1.clone()]);
        assert_eq!(lineage.transformations.len(), 1);
        assert_eq!(lineage.transformations[0].transform_type, "scale");

        // Tree rooted at v1 should have v2 as its single child.
        let tree = manager
            .get_lineage_tree(&version1)
            .expect("test: operation should succeed");
        assert_eq!(tree.version.version_id, version1);
        assert_eq!(tree.children.len(), 1);
        assert_eq!(tree.children[0].version.version_id, version2);
    }

    // Tag-based filtering across three versions with overlapping tags.
    #[test]
    fn test_version_filtering() {
        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let mut manager =
            DatasetVersionManager::new(temp_dir.path()).expect("test: operation should succeed");

        let features_data = vec![1.0];
        let labels_data = vec![0.0];
        let features =
            Tensor::from_vec(features_data, &[1, 1]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[1]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let _version1 = manager
            .create_snapshot(
                &dataset,
                "Version 1".to_string(),
                vec!["production".to_string()],
                None,
            )
            .expect("test: operation should succeed");

        let _version2 = manager
            .create_snapshot(
                &dataset,
                "Version 2".to_string(),
                vec!["development".to_string()],
                None,
            )
            .expect("test: operation should succeed");

        let _version3 = manager
            .create_snapshot(
                &dataset,
                "Version 3".to_string(),
                vec!["production".to_string(), "validated".to_string()],
                None,
            )
            .expect("test: operation should succeed");

        let prod_versions = manager.get_versions_by_tag("production");
        assert_eq!(prod_versions.len(), 2);

        let dev_versions = manager.get_versions_by_tag("development");
        assert_eq!(dev_versions.len(), 1);

        let validated_versions = manager.get_versions_by_tag("validated");
        assert_eq!(validated_versions.len(), 1);
    }
}
748}