1use sklears_core::error::{Result as SklResult, SklearsError};
7use std::collections::{BTreeMap, HashMap};
8use std::fs::{self, File};
9use std::hash::Hash;
10use std::io::{BufReader, BufWriter, Read, Write};
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
/// A point-in-time capture of a pipeline's complete state, identified by a
/// unique id and chained to its predecessor via `parent_id`.
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Unique snapshot identifier (typically `{pipeline_id}_{unix_millis}`).
    pub id: String,
    /// Wall-clock time at which the snapshot was taken.
    pub timestamp: SystemTime,
    /// The captured pipeline state.
    pub state_data: StateData,
    /// Free-form key/value annotations (e.g. rollback provenance).
    pub metadata: HashMap<String, String>,
    /// Monotonically increasing version number across snapshots.
    pub version: u64,
    /// Id of the snapshot this one was derived from, if any.
    pub parent_id: Option<String>,
    /// Integrity checksum of `state_data`; verified on resume.
    pub checksum: String,
}
33
/// The serializable state of a pipeline: configuration, fitted parameters,
/// per-step state, and runtime statistics.
#[derive(Debug, Clone)]
pub struct StateData {
    /// Pipeline-level configuration key/value pairs.
    pub config: HashMap<String, String>,
    /// Named numeric parameter vectors of the fitted model.
    pub model_parameters: HashMap<String, Vec<f64>>,
    /// Input feature names, when known.
    pub feature_names: Option<Vec<String>>,
    /// Per-step state, in pipeline order.
    pub steps_state: Vec<StepState>,
    /// Aggregate execution statistics for the pipeline.
    pub execution_stats: ExecutionStatistics,
    /// Arbitrary opaque payloads keyed by name.
    pub custom_data: HashMap<String, Vec<u8>>,
}
50
/// The state of a single pipeline step.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name (unique within the pipeline — TODO confirm).
    pub name: String,
    /// Kind of step (e.g. transformer/estimator identifier; free-form string).
    pub step_type: String,
    /// Named numeric parameter vectors learned by this step.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Step-level configuration key/value pairs.
    pub config: HashMap<String, String>,
    /// Whether the step has been fitted.
    pub is_fitted: bool,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
67
/// Aggregate runtime statistics captured alongside pipeline state.
#[derive(Debug, Clone)]
pub struct ExecutionStatistics {
    /// Number of samples seen during training.
    pub training_samples: usize,
    /// Number of prediction requests served.
    pub prediction_requests: usize,
    /// Average latency of a prediction request.
    pub avg_prediction_time: Duration,
    /// Most recent accuracy measurement, if one has been computed.
    pub accuracy: Option<f64>,
    /// Memory usage counters.
    pub memory_usage: MemoryUsage,
    /// When these statistics were last updated.
    pub last_updated: SystemTime,
}
84
/// Memory usage counters (units presumably bytes for the memory fields —
/// TODO confirm against the producer of these values).
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
    /// High-water mark of memory usage.
    pub peak_memory: u64,
    /// Memory currently in use.
    pub current_memory: u64,
    /// Total number of allocations observed.
    pub allocations: u64,
    /// Total number of deallocations observed.
    pub deallocations: u64,
}
97
98impl Default for ExecutionStatistics {
99 fn default() -> Self {
100 Self {
101 training_samples: 0,
102 prediction_requests: 0,
103 avg_prediction_time: Duration::ZERO,
104 accuracy: None,
105 memory_usage: MemoryUsage::default(),
106 last_updated: SystemTime::now(),
107 }
108 }
109}
110
/// Where and how snapshots are persisted beyond the in-process cache.
#[derive(Debug, Clone)]
pub enum PersistenceStrategy {
    /// Keep snapshots only in the in-process cache; nothing is written out.
    InMemory,
    /// Write each snapshot to `<base_path>/<id>.snapshot` on local disk.
    LocalFileSystem {
        /// Directory that receives snapshot files (created on demand).
        base_path: PathBuf,
        /// Whether to compress payloads (currently not applied on write).
        compression: bool,
    },
    /// Intended replication across remote nodes (persistence is currently
    /// stubbed: saves are no-ops and loads fail).
    Distributed {
        /// Addresses of the participating nodes.
        nodes: Vec<String>,
        /// Number of replicas to keep.
        replication_factor: usize,
    },
    /// Intended database-backed persistence (currently stubbed: saves are
    /// no-ops and loads fail).
    Database {
        /// Backend connection string.
        connection_string: String,
        /// Table receiving the snapshots.
        table_name: String,
    },
    /// User-supplied save/load function pointers.
    Custom {
        /// Called as `save_fn(&snapshot, &snapshot.id)` on save.
        save_fn: fn(&StateSnapshot, &str) -> SklResult<()>,
        /// Called with the snapshot id on load.
        load_fn: fn(&str) -> SklResult<StateSnapshot>,
    },
}
145
/// Policy controlling when checkpoints are taken and how many are kept.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Interval between automatic checkpoints; `None` disables them.
    pub auto_checkpoint_interval: Option<Duration>,
    /// Maximum number of snapshots retained in cache and history.
    pub max_checkpoints: usize,
    /// Checkpoint whenever state is updated (not enforced by `StateManager`
    /// itself — presumably honored by callers; TODO confirm).
    pub checkpoint_on_update: bool,
    /// Checkpoint when an error occurs (likewise caller-enforced).
    pub checkpoint_on_error: bool,
    /// Compression level to use when compression is enabled.
    pub compression_level: u32,
    /// Whether checkpoints should be incremental rather than full.
    pub incremental: bool,
}
162
163impl Default for CheckpointConfig {
164 fn default() -> Self {
165 Self {
166 auto_checkpoint_interval: Some(Duration::from_secs(300)), max_checkpoints: 10,
168 checkpoint_on_update: true,
169 checkpoint_on_error: true,
170 compression_level: 6,
171 incremental: false,
172 }
173 }
174}
175
/// Coordinates snapshot creation, persistence, retrieval, and lifecycle
/// management for pipeline state.
pub struct StateManager {
    /// Backend used to persist snapshots beyond the in-process cache.
    strategy: PersistenceStrategy,
    /// Checkpointing policy (retention, auto-checkpoint interval, ...).
    config: CheckpointConfig,
    /// In-process snapshot cache, keyed (and ordered) by snapshot id.
    snapshots: Arc<RwLock<BTreeMap<String, StateSnapshot>>>,
    /// Snapshot ids in save order, capped at `config.max_checkpoints`.
    version_history: Arc<RwLock<Vec<String>>>,
    /// Background auto-checkpoint threads, keyed by pipeline id.
    checkpoint_timers: Arc<Mutex<HashMap<String, std::thread::JoinHandle<()>>>>,
    /// Callbacks invoked after each successful `save_snapshot`.
    listeners: Arc<RwLock<Vec<Box<dyn Fn(&StateSnapshot) + Send + Sync>>>>,
}
191
192impl StateManager {
193 #[must_use]
195 pub fn new(strategy: PersistenceStrategy, config: CheckpointConfig) -> Self {
196 Self {
197 strategy,
198 config,
199 snapshots: Arc::new(RwLock::new(BTreeMap::new())),
200 version_history: Arc::new(RwLock::new(Vec::new())),
201 checkpoint_timers: Arc::new(Mutex::new(HashMap::new())),
202 listeners: Arc::new(RwLock::new(Vec::new())),
203 }
204 }
205
206 pub fn save_snapshot(&self, snapshot: StateSnapshot) -> SklResult<()> {
208 {
210 let mut snapshots = self.snapshots.write().unwrap_or_else(|e| e.into_inner());
211 snapshots.insert(snapshot.id.clone(), snapshot.clone());
212
213 if snapshots.len() > self.config.max_checkpoints {
215 if let Some((oldest_id, _)) = snapshots.iter().next() {
216 let oldest_id = oldest_id.clone();
217 snapshots.remove(&oldest_id);
218 }
219 }
220 }
221
222 {
224 let mut history = self
225 .version_history
226 .write()
227 .unwrap_or_else(|e| e.into_inner());
228 history.push(snapshot.id.clone());
229
230 if history.len() > self.config.max_checkpoints {
232 history.remove(0);
233 }
234 }
235
236 match &self.strategy {
238 PersistenceStrategy::InMemory => {
239 }
241 PersistenceStrategy::LocalFileSystem {
242 base_path,
243 compression,
244 } => {
245 self.save_to_filesystem(&snapshot, base_path, *compression)?;
246 }
247 PersistenceStrategy::Distributed {
248 nodes,
249 replication_factor,
250 } => {
251 self.save_to_distributed(&snapshot, nodes, *replication_factor)?;
252 }
253 PersistenceStrategy::Database {
254 connection_string,
255 table_name,
256 } => {
257 self.save_to_database(&snapshot, connection_string, table_name)?;
258 }
259 PersistenceStrategy::Custom { save_fn, .. } => {
260 save_fn(&snapshot, &snapshot.id)?;
261 }
262 }
263
264 self.notify_listeners(&snapshot);
266
267 Ok(())
268 }
269
270 pub fn load_snapshot(&self, snapshot_id: &str) -> SklResult<StateSnapshot> {
272 {
274 let snapshots = self.snapshots.read().unwrap_or_else(|e| e.into_inner());
275 if let Some(snapshot) = snapshots.get(snapshot_id) {
276 return Ok(snapshot.clone());
277 }
278 }
279
280 match &self.strategy {
282 PersistenceStrategy::InMemory => Err(SklearsError::InvalidInput(format!(
283 "Snapshot {snapshot_id} not found in memory"
284 ))),
285 PersistenceStrategy::LocalFileSystem {
286 base_path,
287 compression: _,
288 } => self.load_from_filesystem(snapshot_id, base_path),
289 PersistenceStrategy::Distributed {
290 nodes,
291 replication_factor: _,
292 } => self.load_from_distributed(snapshot_id, nodes),
293 PersistenceStrategy::Database {
294 connection_string,
295 table_name,
296 } => self.load_from_database(snapshot_id, connection_string, table_name),
297 PersistenceStrategy::Custom { load_fn, .. } => load_fn(snapshot_id),
298 }
299 }
300
301 pub fn create_checkpoint(&self, pipeline_id: &str, state_data: StateData) -> SklResult<String> {
303 let snapshot_id = self.generate_snapshot_id(pipeline_id);
304 let checksum = self.calculate_checksum(&state_data)?;
305
306 let snapshot = StateSnapshot {
307 id: snapshot_id.clone(),
308 timestamp: SystemTime::now(),
309 state_data,
310 metadata: HashMap::new(),
311 version: self.get_next_version(),
312 parent_id: self.get_latest_snapshot_id(pipeline_id),
313 checksum,
314 };
315
316 self.save_snapshot(snapshot)?;
317 Ok(snapshot_id)
318 }
319
320 pub fn resume_from_checkpoint(&self, snapshot_id: &str) -> SklResult<StateData> {
322 let snapshot = self.load_snapshot(snapshot_id)?;
323
324 let calculated_checksum = self.calculate_checksum(&snapshot.state_data)?;
326 if calculated_checksum != snapshot.checksum {
327 return Err(SklearsError::InvalidData {
328 reason: format!("Checksum mismatch for snapshot {snapshot_id}"),
329 });
330 }
331
332 Ok(snapshot.state_data)
333 }
334
335 #[must_use]
337 pub fn list_snapshots(&self) -> Vec<String> {
338 let snapshots = self.snapshots.read().unwrap_or_else(|e| e.into_inner());
339 snapshots.keys().cloned().collect()
340 }
341
342 #[must_use]
344 pub fn get_version_history(&self) -> Vec<String> {
345 let history = self
346 .version_history
347 .read()
348 .unwrap_or_else(|e| e.into_inner());
349 history.clone()
350 }
351
352 pub fn rollback(&self, target_snapshot_id: &str) -> SklResult<StateData> {
354 let snapshot = self.load_snapshot(target_snapshot_id)?;
355
356 let rollback_id = format!("rollback_{target_snapshot_id}");
358 let rollback_snapshot = StateSnapshot {
359 id: rollback_id,
360 timestamp: SystemTime::now(),
361 state_data: snapshot.state_data.clone(),
362 metadata: {
363 let mut meta = HashMap::new();
364 meta.insert("rollback_from".to_string(), target_snapshot_id.to_string());
365 meta
366 },
367 version: self.get_next_version(),
368 parent_id: Some(target_snapshot_id.to_string()),
369 checksum: snapshot.checksum.clone(),
370 };
371
372 self.save_snapshot(rollback_snapshot)?;
373 Ok(snapshot.state_data)
374 }
375
376 pub fn delete_snapshot(&self, snapshot_id: &str) -> SklResult<()> {
378 {
380 let mut snapshots = self.snapshots.write().unwrap_or_else(|e| e.into_inner());
381 snapshots.remove(snapshot_id);
382 }
383
384 {
386 let mut history = self
387 .version_history
388 .write()
389 .unwrap_or_else(|e| e.into_inner());
390 history.retain(|id| id != snapshot_id);
391 }
392
393 match &self.strategy {
395 PersistenceStrategy::InMemory => {
396 }
398 PersistenceStrategy::LocalFileSystem { base_path, .. } => {
399 let file_path = base_path.join(format!("{snapshot_id}.snapshot"));
400 if file_path.exists() {
401 fs::remove_file(file_path)?;
402 }
403 }
404 PersistenceStrategy::Distributed { .. } => {
405 }
407 PersistenceStrategy::Database { .. } => {
408 }
410 PersistenceStrategy::Custom { .. } => {
411 }
413 }
414
415 Ok(())
416 }
417
418 pub fn start_auto_checkpoint(
420 &self,
421 pipeline_id: String,
422 state_provider: Arc<dyn Fn() -> SklResult<StateData> + Send + Sync>,
423 ) -> SklResult<()> {
424 if let Some(interval) = self.config.auto_checkpoint_interval {
425 let pipeline_id_clone = pipeline_id.clone();
426 let state_manager = StateManager::new(self.strategy.clone(), self.config.clone());
427
428 let handle = std::thread::spawn(move || loop {
429 std::thread::sleep(interval);
430
431 match state_provider() {
432 Ok(state_data) => {
433 if let Err(e) =
434 state_manager.create_checkpoint(&pipeline_id_clone, state_data)
435 {
436 eprintln!("Auto-checkpoint failed: {e:?}");
437 }
438 }
439 Err(e) => {
440 eprintln!("Failed to get state for auto-checkpoint: {e:?}");
441 }
442 }
443 });
444
445 let mut timers = self
446 .checkpoint_timers
447 .lock()
448 .unwrap_or_else(|e| e.into_inner());
449 timers.insert(pipeline_id, handle);
450 }
451
452 Ok(())
453 }
454
455 pub fn stop_auto_checkpoint(&self, pipeline_id: &str) -> SklResult<()> {
457 let mut timers = self
458 .checkpoint_timers
459 .lock()
460 .unwrap_or_else(|e| e.into_inner());
461 if let Some(handle) = timers.remove(pipeline_id) {
462 }
465 Ok(())
466 }
467
468 pub fn add_listener(&self, listener: Box<dyn Fn(&StateSnapshot) + Send + Sync>) {
470 let mut listeners = self.listeners.write().unwrap_or_else(|e| e.into_inner());
471 listeners.push(listener);
472 }
473
474 fn save_to_filesystem(
476 &self,
477 snapshot: &StateSnapshot,
478 base_path: &Path,
479 compression: bool,
480 ) -> SklResult<()> {
481 fs::create_dir_all(base_path)?;
483
484 let file_path = base_path.join(format!("{}.snapshot", snapshot.id));
485 let file = File::create(file_path)?;
486 let mut writer = BufWriter::new(file);
487
488 let json_data = self.serialize_snapshot(snapshot)?;
490
491 if compression {
492 writer.write_all(json_data.as_bytes())?;
494 } else {
495 writer.write_all(json_data.as_bytes())?;
496 }
497
498 writer.flush()?;
499 Ok(())
500 }
501
502 fn load_from_filesystem(
504 &self,
505 snapshot_id: &str,
506 base_path: &Path,
507 ) -> SklResult<StateSnapshot> {
508 let file_path = base_path.join(format!("{snapshot_id}.snapshot"));
509
510 if !file_path.exists() {
511 return Err(SklearsError::InvalidInput(format!(
512 "Snapshot file {} not found",
513 file_path.display()
514 )));
515 }
516
517 let file = File::open(file_path)?;
518 let mut reader = BufReader::new(file);
519 let mut contents = String::new();
520 reader.read_to_string(&mut contents)?;
521
522 self.deserialize_snapshot(&contents)
523 }
524
525 fn save_to_distributed(
527 &self,
528 _snapshot: &StateSnapshot,
529 _nodes: &[String],
530 _replication_factor: usize,
531 ) -> SklResult<()> {
532 Ok(())
538 }
539
540 fn load_from_distributed(
542 &self,
543 _snapshot_id: &str,
544 _nodes: &[String],
545 ) -> SklResult<StateSnapshot> {
546 Err(SklearsError::InvalidInput(
548 "Distributed loading not implemented".to_string(),
549 ))
550 }
551
552 fn save_to_database(
554 &self,
555 _snapshot: &StateSnapshot,
556 _connection_string: &str,
557 _table_name: &str,
558 ) -> SklResult<()> {
559 Ok(())
562 }
563
564 fn load_from_database(
566 &self,
567 _snapshot_id: &str,
568 _connection_string: &str,
569 _table_name: &str,
570 ) -> SklResult<StateSnapshot> {
571 Err(SklearsError::InvalidInput(
573 "Database loading not implemented".to_string(),
574 ))
575 }
576
577 fn serialize_snapshot(&self, snapshot: &StateSnapshot) -> SklResult<String> {
579 Ok(format!(
582 r#"{{
583 "id": "{}",
584 "timestamp": {},
585 "version": {},
586 "checksum": "{}"
587 }}"#,
588 snapshot.id,
589 snapshot
590 .timestamp
591 .duration_since(UNIX_EPOCH)
592 .unwrap_or_default()
593 .as_secs(),
594 snapshot.version,
595 snapshot.checksum
596 ))
597 }
598
599 fn deserialize_snapshot(&self, _json_data: &str) -> SklResult<StateSnapshot> {
601 Ok(StateSnapshot {
604 id: "dummy".to_string(),
605 timestamp: SystemTime::now(),
606 state_data: StateData {
607 config: HashMap::new(),
608 model_parameters: HashMap::new(),
609 feature_names: None,
610 steps_state: Vec::new(),
611 execution_stats: ExecutionStatistics::default(),
612 custom_data: HashMap::new(),
613 },
614 metadata: HashMap::new(),
615 version: 1,
616 parent_id: None,
617 checksum: "dummy_checksum".to_string(),
618 })
619 }
620
621 fn generate_snapshot_id(&self, pipeline_id: &str) -> String {
623 let timestamp = SystemTime::now()
624 .duration_since(UNIX_EPOCH)
625 .unwrap_or_default()
626 .as_millis();
627 format!("{pipeline_id}_{timestamp}")
628 }
629
630 fn calculate_checksum(&self, state_data: &StateData) -> SklResult<String> {
632 use std::collections::hash_map::DefaultHasher;
635 use std::hash::Hasher;
636
637 let mut hasher = DefaultHasher::new();
638 state_data.config.len().hash(&mut hasher);
639 state_data.model_parameters.len().hash(&mut hasher);
640 state_data.steps_state.len().hash(&mut hasher);
641
642 Ok(format!("checksum_{}", hasher.finish()))
643 }
644
645 fn get_next_version(&self) -> u64 {
647 let snapshots = self.snapshots.read().unwrap_or_else(|e| e.into_inner());
648 snapshots.values().map(|s| s.version).max().unwrap_or(0) + 1
649 }
650
651 fn get_latest_snapshot_id(&self, pipeline_id: &str) -> Option<String> {
653 let snapshots = self.snapshots.read().unwrap_or_else(|e| e.into_inner());
654 snapshots
655 .values()
656 .filter(|s| s.id.starts_with(pipeline_id))
657 .max_by_key(|s| s.timestamp)
658 .map(|s| s.id.clone())
659 }
660
661 fn notify_listeners(&self, snapshot: &StateSnapshot) {
663 let listeners = self.listeners.read().unwrap_or_else(|e| e.into_inner());
664 for listener in listeners.iter() {
665 listener(snapshot);
666 }
667 }
668}
669
/// Synchronizes snapshots between a local `StateManager` and one or more
/// remote managers.
pub struct StateSynchronizer {
    /// The manager whose state is treated as local.
    local_state: Arc<StateManager>,
    /// Remote managers to exchange snapshots with.
    remote_states: Vec<Arc<StateManager>>,
    /// Synchronization policy (currently only `bidirectional` is consulted
    /// by `sync_with_remote`).
    config: SyncConfig,
    /// How to resolve a local/remote snapshot conflict.
    conflict_resolution: ConflictResolution,
}
681
/// Policy knobs for snapshot synchronization.
#[derive(Debug, Clone)]
pub struct SyncConfig {
    /// Desired interval between synchronization rounds.
    pub sync_interval: Duration,
    /// Push local-only snapshots to remotes as well as pulling remote ones.
    pub bidirectional: bool,
    /// Enable conflict detection (not yet consulted — TODO confirm).
    pub conflict_detection: bool,
    /// Batch snapshots per sync round (not yet consulted — TODO confirm).
    pub batch_sync: bool,
    /// Maximum retries per sync operation (not yet consulted — TODO confirm).
    pub max_retries: usize,
}
696
697impl Default for SyncConfig {
698 fn default() -> Self {
699 Self {
700 sync_interval: Duration::from_secs(30),
701 bidirectional: true,
702 conflict_detection: true,
703 batch_sync: false,
704 max_retries: 3,
705 }
706 }
707}
708
/// Strategy for resolving a conflict between a local and a remote snapshot.
#[derive(Debug, Clone)]
pub enum ConflictResolution {
    /// The snapshot with the later `timestamp` wins.
    LatestWins,
    /// The snapshot with the larger `version` wins.
    HighestVersionWins,
    /// Refuse to auto-resolve; `resolve_conflict` returns an error.
    Manual,
    /// User-supplied resolver called as `f(local, remote)`.
    Custom(fn(&StateSnapshot, &StateSnapshot) -> StateSnapshot),
}
721
722impl StateSynchronizer {
723 #[must_use]
725 pub fn new(
726 local_state: Arc<StateManager>,
727 config: SyncConfig,
728 conflict_resolution: ConflictResolution,
729 ) -> Self {
730 Self {
731 local_state,
732 remote_states: Vec::new(),
733 config,
734 conflict_resolution,
735 }
736 }
737
738 pub fn add_remote(&mut self, remote_state: Arc<StateManager>) {
740 self.remote_states.push(remote_state);
741 }
742
743 pub fn synchronize(&self) -> SklResult<SyncResult> {
745 let mut result = SyncResult {
746 synced_snapshots: 0,
747 conflicts_resolved: 0,
748 errors: Vec::new(),
749 };
750
751 for remote in &self.remote_states {
752 match self.sync_with_remote(remote) {
753 Ok(sync_stats) => {
754 result.synced_snapshots += sync_stats.synced_snapshots;
755 result.conflicts_resolved += sync_stats.conflicts_resolved;
756 }
757 Err(e) => {
758 result.errors.push(format!("Sync error: {e:?}"));
759 }
760 }
761 }
762
763 Ok(result)
764 }
765
766 fn sync_with_remote(&self, remote: &Arc<StateManager>) -> SklResult<SyncResult> {
768 let mut result = SyncResult {
769 synced_snapshots: 0,
770 conflicts_resolved: 0,
771 errors: Vec::new(),
772 };
773
774 let local_snapshots = self.local_state.list_snapshots();
776 let remote_snapshots = remote.list_snapshots();
777
778 for remote_id in &remote_snapshots {
780 if !local_snapshots.contains(remote_id) {
781 match remote.load_snapshot(remote_id) {
783 Ok(remote_snapshot) => {
784 if let Some(local_snapshot) =
786 self.find_conflicting_snapshot(&remote_snapshot)
787 {
788 let resolved =
789 self.resolve_conflict(&local_snapshot, &remote_snapshot)?;
790 self.local_state.save_snapshot(resolved)?;
791 result.conflicts_resolved += 1;
792 } else {
793 self.local_state.save_snapshot(remote_snapshot)?;
794 result.synced_snapshots += 1;
795 }
796 }
797 Err(e) => {
798 result
799 .errors
800 .push(format!("Failed to load remote snapshot {remote_id}: {e:?}"));
801 }
802 }
803 }
804 }
805
806 if self.config.bidirectional {
808 for local_id in &local_snapshots {
809 if !remote_snapshots.contains(local_id) {
810 match self.local_state.load_snapshot(local_id) {
811 Ok(local_snapshot) => {
812 remote.save_snapshot(local_snapshot)?;
813 result.synced_snapshots += 1;
814 }
815 Err(e) => {
816 result
817 .errors
818 .push(format!("Failed to sync local snapshot {local_id}: {e:?}"));
819 }
820 }
821 }
822 }
823 }
824
825 Ok(result)
826 }
827
828 fn find_conflicting_snapshot(&self, remote_snapshot: &StateSnapshot) -> Option<StateSnapshot> {
830 None
833 }
834
835 fn resolve_conflict(
837 &self,
838 local: &StateSnapshot,
839 remote: &StateSnapshot,
840 ) -> SklResult<StateSnapshot> {
841 match &self.conflict_resolution {
842 ConflictResolution::LatestWins => {
843 if remote.timestamp > local.timestamp {
844 Ok(remote.clone())
845 } else {
846 Ok(local.clone())
847 }
848 }
849 ConflictResolution::HighestVersionWins => {
850 if remote.version > local.version {
851 Ok(remote.clone())
852 } else {
853 Ok(local.clone())
854 }
855 }
856 ConflictResolution::Manual => Err(SklearsError::InvalidData {
857 reason: "Manual conflict resolution required".to_string(),
858 }),
859 ConflictResolution::Custom(resolve_fn) => Ok(resolve_fn(local, remote)),
860 }
861 }
862}
863
/// Outcome of a synchronization run.
#[derive(Debug, Clone)]
pub struct SyncResult {
    /// Number of snapshots copied between local and remote stores.
    pub synced_snapshots: usize,
    /// Number of conflicts that were auto-resolved.
    pub conflicts_resolved: usize,
    /// Human-readable descriptions of non-fatal failures.
    pub errors: Vec<String>,
}
874
/// Git-like branching and tagging layer on top of a `StateManager`.
pub struct PipelineVersionControl {
    /// Underlying snapshot store (branch heads and tags reference its ids).
    state_manager: Arc<StateManager>,
    /// All known branches, keyed by name; a "main" branch exists from `new`.
    branches: Arc<RwLock<HashMap<String, Branch>>>,
    /// Name of the branch commits currently target.
    current_branch: Arc<RwLock<String>>,
    /// Tag name -> snapshot id.
    tags: Arc<RwLock<HashMap<String, String>>>,
}
886
/// A named line of development pointing at a head snapshot.
#[derive(Debug, Clone)]
pub struct Branch {
    /// Branch name (unique within a `PipelineVersionControl`).
    pub name: String,
    /// Snapshot id of the branch head; `None` until the first commit.
    pub head: Option<String>,
    /// When the branch was created.
    pub created_at: SystemTime,
    /// Free-form annotations (e.g. last commit message/time).
    pub metadata: HashMap<String, String>,
}
899
900impl PipelineVersionControl {
901 #[must_use]
903 pub fn new(state_manager: Arc<StateManager>) -> Self {
904 let mut branches = HashMap::new();
905 branches.insert(
906 "main".to_string(),
907 Branch {
908 name: "main".to_string(),
909 head: None,
910 created_at: SystemTime::now(),
911 metadata: HashMap::new(),
912 },
913 );
914
915 Self {
916 state_manager,
917 branches: Arc::new(RwLock::new(branches)),
918 current_branch: Arc::new(RwLock::new("main".to_string())),
919 tags: Arc::new(RwLock::new(HashMap::new())),
920 }
921 }
922
923 pub fn create_branch(&self, branch_name: &str, from_snapshot: Option<&str>) -> SklResult<()> {
925 let mut branches = self.branches.write().unwrap_or_else(|e| e.into_inner());
926
927 if branches.contains_key(branch_name) {
928 return Err(SklearsError::InvalidInput(format!(
929 "Branch {branch_name} already exists"
930 )));
931 }
932
933 let branch = Branch {
934 name: branch_name.to_string(),
935 head: from_snapshot.map(std::string::ToString::to_string),
936 created_at: SystemTime::now(),
937 metadata: HashMap::new(),
938 };
939
940 branches.insert(branch_name.to_string(), branch);
941 Ok(())
942 }
943
944 pub fn checkout_branch(&self, branch_name: &str) -> SklResult<()> {
946 let branches = self.branches.read().unwrap_or_else(|e| e.into_inner());
947
948 if !branches.contains_key(branch_name) {
949 return Err(SklearsError::InvalidInput(format!(
950 "Branch {branch_name} does not exist"
951 )));
952 }
953
954 let mut current = self
955 .current_branch
956 .write()
957 .unwrap_or_else(|e| e.into_inner());
958 *current = branch_name.to_string();
959 Ok(())
960 }
961
962 pub fn commit(&self, snapshot_id: &str, message: &str) -> SklResult<()> {
964 let current_branch_name = {
965 let current = self
966 .current_branch
967 .read()
968 .unwrap_or_else(|e| e.into_inner());
969 current.clone()
970 };
971
972 let mut branches = self.branches.write().unwrap_or_else(|e| e.into_inner());
973 if let Some(branch) = branches.get_mut(¤t_branch_name) {
974 branch.head = Some(snapshot_id.to_string());
975 branch
976 .metadata
977 .insert("last_commit_message".to_string(), message.to_string());
978 branch.metadata.insert(
979 "last_commit_time".to_string(),
980 SystemTime::now()
981 .duration_since(UNIX_EPOCH)
982 .unwrap_or_default()
983 .as_secs()
984 .to_string(),
985 );
986 }
987
988 Ok(())
989 }
990
991 pub fn create_tag(&self, tag_name: &str, snapshot_id: &str) -> SklResult<()> {
993 let mut tags = self.tags.write().unwrap_or_else(|e| e.into_inner());
994 tags.insert(tag_name.to_string(), snapshot_id.to_string());
995 Ok(())
996 }
997
998 #[must_use]
1000 pub fn get_tag(&self, tag_name: &str) -> Option<String> {
1001 let tags = self.tags.read().unwrap_or_else(|e| e.into_inner());
1002 tags.get(tag_name).cloned()
1003 }
1004
1005 #[must_use]
1007 pub fn list_branches(&self) -> Vec<String> {
1008 let branches = self.branches.read().unwrap_or_else(|e| e.into_inner());
1009 branches.keys().cloned().collect()
1010 }
1011
1012 #[must_use]
1014 pub fn list_tags(&self) -> HashMap<String, String> {
1015 let tags = self.tags.read().unwrap_or_else(|e| e.into_inner());
1016 tags.clone()
1017 }
1018
1019 #[must_use]
1021 pub fn current_branch(&self) -> String {
1022 let current = self
1023 .current_branch
1024 .read()
1025 .unwrap_or_else(|e| e.into_inner());
1026 current.clone()
1027 }
1028}
1029
#[cfg(test)]
mod tests {
    // Fixes: removed the unnecessary `#[allow(non_snake_case)]` (all test
    // names are snake_case) and the unused `use std::env;`. Replaced
    // `.unwrap_or_default()` on fallible operations with `.expect(...)` so
    // a setup failure fails the test at its source instead of masking it.
    use super::*;

    /// A hand-built snapshot round-trips its id and version.
    #[test]
    fn test_state_snapshot_creation() {
        let snapshot = StateSnapshot {
            id: "test_snapshot".to_string(),
            timestamp: SystemTime::now(),
            state_data: StateData {
                config: HashMap::new(),
                model_parameters: HashMap::new(),
                feature_names: None,
                steps_state: Vec::new(),
                execution_stats: ExecutionStatistics::default(),
                custom_data: HashMap::new(),
            },
            metadata: HashMap::new(),
            version: 1,
            parent_id: None,
            checksum: "test_checksum".to_string(),
        };

        assert_eq!(snapshot.id, "test_snapshot");
        assert_eq!(snapshot.version, 1);
    }

    /// Checkpoint/resume works end-to-end with the in-memory strategy.
    #[test]
    fn test_state_manager_memory() {
        let strategy = PersistenceStrategy::InMemory;
        let config = CheckpointConfig::default();
        let manager = StateManager::new(strategy, config);

        let state_data = StateData {
            config: HashMap::new(),
            model_parameters: HashMap::new(),
            feature_names: None,
            steps_state: Vec::new(),
            execution_stats: ExecutionStatistics::default(),
            custom_data: HashMap::new(),
        };

        let checkpoint_id = manager
            .create_checkpoint("test_pipeline", state_data)
            .expect("checkpoint creation should succeed");
        assert!(checkpoint_id.starts_with("test_pipeline"));

        let loaded_state = manager
            .resume_from_checkpoint(&checkpoint_id)
            .expect("resume should succeed");
        assert_eq!(loaded_state.config.len(), 0);
    }

    /// Branch creation/checkout and tagging behave as expected.
    #[test]
    fn test_version_control() {
        let strategy = PersistenceStrategy::InMemory;
        let config = CheckpointConfig::default();
        let state_manager = Arc::new(StateManager::new(strategy, config));
        let vc = PipelineVersionControl::new(state_manager);

        assert_eq!(vc.current_branch(), "main");

        vc.create_branch("feature", None)
            .expect("branch creation should succeed");
        vc.checkout_branch("feature")
            .expect("checkout should succeed");
        assert_eq!(vc.current_branch(), "feature");

        vc.create_tag("v1.0", "snapshot_123")
            .expect("tag creation should succeed");
        assert_eq!(vc.get_tag("v1.0"), Some("snapshot_123".to_string()));
    }

    /// A fully specified config keeps the values it was given.
    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig {
            auto_checkpoint_interval: Some(Duration::from_secs(60)),
            max_checkpoints: 5,
            checkpoint_on_update: true,
            checkpoint_on_error: false,
            compression_level: 9,
            incremental: true,
        };

        assert_eq!(config.max_checkpoints, 5);
        assert_eq!(config.compression_level, 9);
        assert!(config.incremental);
    }

    /// Statistics fields can be overridden on top of the defaults.
    #[test]
    fn test_execution_statistics() {
        // Struct-update syntax avoids the mutate-after-default pattern
        // (clippy::field_reassign_with_default).
        let stats = ExecutionStatistics {
            training_samples: 1000,
            prediction_requests: 50,
            accuracy: Some(0.95),
            ..ExecutionStatistics::default()
        };

        assert_eq!(stats.training_samples, 1000);
        assert_eq!(stats.prediction_requests, 50);
        assert_eq!(stats.accuracy, Some(0.95));
    }
}