1use sklears_core::error::{Result as SklResult, SklearsError};
7use std::collections::{BTreeMap, HashMap};
8use std::fs::{self, File};
9use std::hash::Hash;
10use std::io::{BufReader, BufWriter, Read, Write};
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
/// Point-in-time capture of a pipeline's full state.
///
/// Snapshots form a parent chain through `parent_id` and are
/// integrity-checked on restore by comparing `checksum` (see
/// `StateManager::resume_from_checkpoint`).
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Unique identifier; `StateManager` generates `{pipeline_id}_{millis}`.
    pub id: String,
    /// Wall-clock time the snapshot was created.
    pub timestamp: SystemTime,
    /// The captured pipeline state.
    pub state_data: StateData,
    /// Free-form key/value annotations (e.g. rollback provenance).
    pub metadata: HashMap<String, String>,
    /// Monotonically increasing version assigned by the manager.
    pub version: u64,
    /// Id of the snapshot this one was derived from, if any.
    pub parent_id: Option<String>,
    /// Checksum of `state_data`, verified on resume.
    pub checksum: String,
}
33
/// The serializable state of a pipeline: configuration, fitted model
/// parameters, per-step state, and runtime statistics.
#[derive(Debug, Clone)]
pub struct StateData {
    /// Pipeline-level configuration key/value pairs.
    pub config: HashMap<String, String>,
    /// Named parameter vectors of the fitted model.
    pub model_parameters: HashMap<String, Vec<f64>>,
    /// Input feature names, when known.
    pub feature_names: Option<Vec<String>>,
    /// State of each pipeline step, in pipeline order.
    pub steps_state: Vec<StepState>,
    /// Aggregate execution metrics at snapshot time.
    pub execution_stats: ExecutionStatistics,
    /// Opaque user-defined payloads keyed by name.
    pub custom_data: HashMap<String, Vec<u8>>,
}
50
/// Captured state of a single pipeline step.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name within the pipeline.
    pub name: String,
    /// Kind of step (free-form string; semantics defined by the pipeline).
    pub step_type: String,
    /// Fitted numeric parameters keyed by parameter name.
    pub parameters: HashMap<String, Vec<f64>>,
    /// Step-level configuration key/value pairs.
    pub config: HashMap<String, String>,
    /// Whether the step has been fitted.
    pub is_fitted: bool,
    /// Free-form annotations for the step.
    pub metadata: HashMap<String, String>,
}
67
/// Aggregate runtime metrics carried inside a snapshot.
#[derive(Debug, Clone)]
pub struct ExecutionStatistics {
    /// Number of samples seen during training.
    pub training_samples: usize,
    /// Number of prediction calls served.
    pub prediction_requests: usize,
    /// Mean latency of a prediction request.
    pub avg_prediction_time: Duration,
    /// Most recent accuracy score, if one has been computed.
    pub accuracy: Option<f64>,
    /// Memory accounting at snapshot time.
    pub memory_usage: MemoryUsage,
    /// When these statistics were last refreshed.
    pub last_updated: SystemTime,
}
84
/// Simple memory accounting counters, recorded by the caller — this
/// module does not measure memory itself.
/// NOTE(review): units are presumably bytes for the memory fields and
/// event counts for (de)allocations — confirm against the producer.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
    /// High-water mark of memory use.
    pub peak_memory: u64,
    /// Memory in use at recording time.
    pub current_memory: u64,
    /// Total allocation events.
    pub allocations: u64,
    /// Total deallocation events.
    pub deallocations: u64,
}
97
98impl Default for ExecutionStatistics {
99 fn default() -> Self {
100 Self {
101 training_samples: 0,
102 prediction_requests: 0,
103 avg_prediction_time: Duration::ZERO,
104 accuracy: None,
105 memory_usage: MemoryUsage::default(),
106 last_updated: SystemTime::now(),
107 }
108 }
109}
110
/// Where and how snapshots are persisted beyond the in-memory map.
#[derive(Debug, Clone)]
pub enum PersistenceStrategy {
    /// Keep snapshots only in the manager's in-memory map.
    InMemory,
    /// Write one `<id>.snapshot` file per snapshot under `base_path`.
    LocalFileSystem {
        /// Directory receiving snapshot files (created on demand).
        base_path: PathBuf,
        /// Compression flag; accepted but not currently applied on write.
        compression: bool,
    },
    /// Replicate snapshots across remote nodes.
    /// NOTE(review): distributed save/load are unimplemented stubs.
    Distributed {
        /// Node addresses to replicate to.
        nodes: Vec<String>,
        /// Number of replicas to keep.
        replication_factor: usize,
    },
    /// Store snapshots in a database table.
    /// NOTE(review): database save/load are unimplemented stubs.
    Database {
        /// Backend connection string.
        connection_string: String,
        /// Destination table name.
        table_name: String,
    },
    /// Delegate persistence to user-supplied functions.
    Custom {
        /// Called as `save_fn(&snapshot, &snapshot.id)` on save.
        save_fn: fn(&StateSnapshot, &str) -> SklResult<()>,
        /// Called with the snapshot id on load.
        load_fn: fn(&str) -> SklResult<StateSnapshot>,
    },
}
145
/// Policy knobs controlling checkpoint creation and retention.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Interval between automatic checkpoints; `None` disables them.
    pub auto_checkpoint_interval: Option<Duration>,
    /// Maximum snapshots retained in memory and in the version history.
    pub max_checkpoints: usize,
    /// Checkpoint after every state update.
    /// NOTE(review): not consulted anywhere in this module — confirm callers.
    pub checkpoint_on_update: bool,
    /// Checkpoint when an error occurs.
    /// NOTE(review): not consulted anywhere in this module — confirm callers.
    pub checkpoint_on_error: bool,
    /// Compression level (0-9 by convention); currently unused on write.
    pub compression_level: u32,
    /// Store incremental diffs instead of full snapshots.
    /// NOTE(review): not consulted anywhere in this module.
    pub incremental: bool,
}
162
163impl Default for CheckpointConfig {
164 fn default() -> Self {
165 Self {
166 auto_checkpoint_interval: Some(Duration::from_secs(300)), max_checkpoints: 10,
168 checkpoint_on_update: true,
169 checkpoint_on_error: true,
170 compression_level: 6,
171 incremental: false,
172 }
173 }
174}
175
/// Central snapshot store: keeps snapshots in memory (ordered by id via
/// `BTreeMap`), persists them through a `PersistenceStrategy`, and
/// notifies registered listeners on every save.
pub struct StateManager {
    /// Backend used to persist/load snapshots.
    strategy: PersistenceStrategy,
    /// Checkpoint retention and scheduling policy.
    config: CheckpointConfig,
    /// In-memory snapshot cache keyed by snapshot id (sorted order).
    snapshots: Arc<RwLock<BTreeMap<String, StateSnapshot>>>,
    /// Snapshot ids in creation order, bounded by `max_checkpoints`.
    version_history: Arc<RwLock<Vec<String>>>,
    /// Auto-checkpoint worker threads keyed by pipeline id.
    checkpoint_timers: Arc<Mutex<HashMap<String, std::thread::JoinHandle<()>>>>,
    /// Callbacks invoked after each successful save.
    listeners: Arc<RwLock<Vec<Box<dyn Fn(&StateSnapshot) + Send + Sync>>>>,
}
191
192impl StateManager {
193 #[must_use]
195 pub fn new(strategy: PersistenceStrategy, config: CheckpointConfig) -> Self {
196 Self {
197 strategy,
198 config,
199 snapshots: Arc::new(RwLock::new(BTreeMap::new())),
200 version_history: Arc::new(RwLock::new(Vec::new())),
201 checkpoint_timers: Arc::new(Mutex::new(HashMap::new())),
202 listeners: Arc::new(RwLock::new(Vec::new())),
203 }
204 }
205
206 pub fn save_snapshot(&self, snapshot: StateSnapshot) -> SklResult<()> {
208 {
210 let mut snapshots = self.snapshots.write().unwrap();
211 snapshots.insert(snapshot.id.clone(), snapshot.clone());
212
213 if snapshots.len() > self.config.max_checkpoints {
215 if let Some((oldest_id, _)) = snapshots.iter().next() {
216 let oldest_id = oldest_id.clone();
217 snapshots.remove(&oldest_id);
218 }
219 }
220 }
221
222 {
224 let mut history = self.version_history.write().unwrap();
225 history.push(snapshot.id.clone());
226
227 if history.len() > self.config.max_checkpoints {
229 history.remove(0);
230 }
231 }
232
233 match &self.strategy {
235 PersistenceStrategy::InMemory => {
236 }
238 PersistenceStrategy::LocalFileSystem {
239 base_path,
240 compression,
241 } => {
242 self.save_to_filesystem(&snapshot, base_path, *compression)?;
243 }
244 PersistenceStrategy::Distributed {
245 nodes,
246 replication_factor,
247 } => {
248 self.save_to_distributed(&snapshot, nodes, *replication_factor)?;
249 }
250 PersistenceStrategy::Database {
251 connection_string,
252 table_name,
253 } => {
254 self.save_to_database(&snapshot, connection_string, table_name)?;
255 }
256 PersistenceStrategy::Custom { save_fn, .. } => {
257 save_fn(&snapshot, &snapshot.id)?;
258 }
259 }
260
261 self.notify_listeners(&snapshot);
263
264 Ok(())
265 }
266
267 pub fn load_snapshot(&self, snapshot_id: &str) -> SklResult<StateSnapshot> {
269 {
271 let snapshots = self.snapshots.read().unwrap();
272 if let Some(snapshot) = snapshots.get(snapshot_id) {
273 return Ok(snapshot.clone());
274 }
275 }
276
277 match &self.strategy {
279 PersistenceStrategy::InMemory => Err(SklearsError::InvalidInput(format!(
280 "Snapshot {snapshot_id} not found in memory"
281 ))),
282 PersistenceStrategy::LocalFileSystem {
283 base_path,
284 compression: _,
285 } => self.load_from_filesystem(snapshot_id, base_path),
286 PersistenceStrategy::Distributed {
287 nodes,
288 replication_factor: _,
289 } => self.load_from_distributed(snapshot_id, nodes),
290 PersistenceStrategy::Database {
291 connection_string,
292 table_name,
293 } => self.load_from_database(snapshot_id, connection_string, table_name),
294 PersistenceStrategy::Custom { load_fn, .. } => load_fn(snapshot_id),
295 }
296 }
297
298 pub fn create_checkpoint(&self, pipeline_id: &str, state_data: StateData) -> SklResult<String> {
300 let snapshot_id = self.generate_snapshot_id(pipeline_id);
301 let checksum = self.calculate_checksum(&state_data)?;
302
303 let snapshot = StateSnapshot {
304 id: snapshot_id.clone(),
305 timestamp: SystemTime::now(),
306 state_data,
307 metadata: HashMap::new(),
308 version: self.get_next_version(),
309 parent_id: self.get_latest_snapshot_id(pipeline_id),
310 checksum,
311 };
312
313 self.save_snapshot(snapshot)?;
314 Ok(snapshot_id)
315 }
316
317 pub fn resume_from_checkpoint(&self, snapshot_id: &str) -> SklResult<StateData> {
319 let snapshot = self.load_snapshot(snapshot_id)?;
320
321 let calculated_checksum = self.calculate_checksum(&snapshot.state_data)?;
323 if calculated_checksum != snapshot.checksum {
324 return Err(SklearsError::InvalidData {
325 reason: format!("Checksum mismatch for snapshot {snapshot_id}"),
326 });
327 }
328
329 Ok(snapshot.state_data)
330 }
331
332 #[must_use]
334 pub fn list_snapshots(&self) -> Vec<String> {
335 let snapshots = self.snapshots.read().unwrap();
336 snapshots.keys().cloned().collect()
337 }
338
339 #[must_use]
341 pub fn get_version_history(&self) -> Vec<String> {
342 let history = self.version_history.read().unwrap();
343 history.clone()
344 }
345
346 pub fn rollback(&self, target_snapshot_id: &str) -> SklResult<StateData> {
348 let snapshot = self.load_snapshot(target_snapshot_id)?;
349
350 let rollback_id = format!("rollback_{target_snapshot_id}");
352 let rollback_snapshot = StateSnapshot {
353 id: rollback_id,
354 timestamp: SystemTime::now(),
355 state_data: snapshot.state_data.clone(),
356 metadata: {
357 let mut meta = HashMap::new();
358 meta.insert("rollback_from".to_string(), target_snapshot_id.to_string());
359 meta
360 },
361 version: self.get_next_version(),
362 parent_id: Some(target_snapshot_id.to_string()),
363 checksum: snapshot.checksum.clone(),
364 };
365
366 self.save_snapshot(rollback_snapshot)?;
367 Ok(snapshot.state_data)
368 }
369
370 pub fn delete_snapshot(&self, snapshot_id: &str) -> SklResult<()> {
372 {
374 let mut snapshots = self.snapshots.write().unwrap();
375 snapshots.remove(snapshot_id);
376 }
377
378 {
380 let mut history = self.version_history.write().unwrap();
381 history.retain(|id| id != snapshot_id);
382 }
383
384 match &self.strategy {
386 PersistenceStrategy::InMemory => {
387 }
389 PersistenceStrategy::LocalFileSystem { base_path, .. } => {
390 let file_path = base_path.join(format!("{snapshot_id}.snapshot"));
391 if file_path.exists() {
392 fs::remove_file(file_path)?;
393 }
394 }
395 PersistenceStrategy::Distributed { .. } => {
396 }
398 PersistenceStrategy::Database { .. } => {
399 }
401 PersistenceStrategy::Custom { .. } => {
402 }
404 }
405
406 Ok(())
407 }
408
409 pub fn start_auto_checkpoint(
411 &self,
412 pipeline_id: String,
413 state_provider: Arc<dyn Fn() -> SklResult<StateData> + Send + Sync>,
414 ) -> SklResult<()> {
415 if let Some(interval) = self.config.auto_checkpoint_interval {
416 let pipeline_id_clone = pipeline_id.clone();
417 let state_manager = StateManager::new(self.strategy.clone(), self.config.clone());
418
419 let handle = std::thread::spawn(move || loop {
420 std::thread::sleep(interval);
421
422 match state_provider() {
423 Ok(state_data) => {
424 if let Err(e) =
425 state_manager.create_checkpoint(&pipeline_id_clone, state_data)
426 {
427 eprintln!("Auto-checkpoint failed: {e:?}");
428 }
429 }
430 Err(e) => {
431 eprintln!("Failed to get state for auto-checkpoint: {e:?}");
432 }
433 }
434 });
435
436 let mut timers = self.checkpoint_timers.lock().unwrap();
437 timers.insert(pipeline_id, handle);
438 }
439
440 Ok(())
441 }
442
443 pub fn stop_auto_checkpoint(&self, pipeline_id: &str) -> SklResult<()> {
445 let mut timers = self.checkpoint_timers.lock().unwrap();
446 if let Some(handle) = timers.remove(pipeline_id) {
447 }
450 Ok(())
451 }
452
453 pub fn add_listener(&self, listener: Box<dyn Fn(&StateSnapshot) + Send + Sync>) {
455 let mut listeners = self.listeners.write().unwrap();
456 listeners.push(listener);
457 }
458
459 fn save_to_filesystem(
461 &self,
462 snapshot: &StateSnapshot,
463 base_path: &Path,
464 compression: bool,
465 ) -> SklResult<()> {
466 fs::create_dir_all(base_path)?;
468
469 let file_path = base_path.join(format!("{}.snapshot", snapshot.id));
470 let file = File::create(file_path)?;
471 let mut writer = BufWriter::new(file);
472
473 let json_data = self.serialize_snapshot(snapshot)?;
475
476 if compression {
477 writer.write_all(json_data.as_bytes())?;
479 } else {
480 writer.write_all(json_data.as_bytes())?;
481 }
482
483 writer.flush()?;
484 Ok(())
485 }
486
487 fn load_from_filesystem(
489 &self,
490 snapshot_id: &str,
491 base_path: &Path,
492 ) -> SklResult<StateSnapshot> {
493 let file_path = base_path.join(format!("{snapshot_id}.snapshot"));
494
495 if !file_path.exists() {
496 return Err(SklearsError::InvalidInput(format!(
497 "Snapshot file {} not found",
498 file_path.display()
499 )));
500 }
501
502 let file = File::open(file_path)?;
503 let mut reader = BufReader::new(file);
504 let mut contents = String::new();
505 reader.read_to_string(&mut contents)?;
506
507 self.deserialize_snapshot(&contents)
508 }
509
510 fn save_to_distributed(
512 &self,
513 _snapshot: &StateSnapshot,
514 _nodes: &[String],
515 _replication_factor: usize,
516 ) -> SklResult<()> {
517 Ok(())
523 }
524
525 fn load_from_distributed(
527 &self,
528 _snapshot_id: &str,
529 _nodes: &[String],
530 ) -> SklResult<StateSnapshot> {
531 Err(SklearsError::InvalidInput(
533 "Distributed loading not implemented".to_string(),
534 ))
535 }
536
537 fn save_to_database(
539 &self,
540 _snapshot: &StateSnapshot,
541 _connection_string: &str,
542 _table_name: &str,
543 ) -> SklResult<()> {
544 Ok(())
547 }
548
549 fn load_from_database(
551 &self,
552 _snapshot_id: &str,
553 _connection_string: &str,
554 _table_name: &str,
555 ) -> SklResult<StateSnapshot> {
556 Err(SklearsError::InvalidInput(
558 "Database loading not implemented".to_string(),
559 ))
560 }
561
562 fn serialize_snapshot(&self, snapshot: &StateSnapshot) -> SklResult<String> {
564 Ok(format!(
567 r#"{{
568 "id": "{}",
569 "timestamp": {},
570 "version": {},
571 "checksum": "{}"
572 }}"#,
573 snapshot.id,
574 snapshot
575 .timestamp
576 .duration_since(UNIX_EPOCH)
577 .unwrap()
578 .as_secs(),
579 snapshot.version,
580 snapshot.checksum
581 ))
582 }
583
584 fn deserialize_snapshot(&self, _json_data: &str) -> SklResult<StateSnapshot> {
586 Ok(StateSnapshot {
589 id: "dummy".to_string(),
590 timestamp: SystemTime::now(),
591 state_data: StateData {
592 config: HashMap::new(),
593 model_parameters: HashMap::new(),
594 feature_names: None,
595 steps_state: Vec::new(),
596 execution_stats: ExecutionStatistics::default(),
597 custom_data: HashMap::new(),
598 },
599 metadata: HashMap::new(),
600 version: 1,
601 parent_id: None,
602 checksum: "dummy_checksum".to_string(),
603 })
604 }
605
606 fn generate_snapshot_id(&self, pipeline_id: &str) -> String {
608 let timestamp = SystemTime::now()
609 .duration_since(UNIX_EPOCH)
610 .unwrap()
611 .as_millis();
612 format!("{pipeline_id}_{timestamp}")
613 }
614
615 fn calculate_checksum(&self, state_data: &StateData) -> SklResult<String> {
617 use std::collections::hash_map::DefaultHasher;
620 use std::hash::Hasher;
621
622 let mut hasher = DefaultHasher::new();
623 state_data.config.len().hash(&mut hasher);
624 state_data.model_parameters.len().hash(&mut hasher);
625 state_data.steps_state.len().hash(&mut hasher);
626
627 Ok(format!("checksum_{}", hasher.finish()))
628 }
629
630 fn get_next_version(&self) -> u64 {
632 let snapshots = self.snapshots.read().unwrap();
633 snapshots.values().map(|s| s.version).max().unwrap_or(0) + 1
634 }
635
636 fn get_latest_snapshot_id(&self, pipeline_id: &str) -> Option<String> {
638 let snapshots = self.snapshots.read().unwrap();
639 snapshots
640 .values()
641 .filter(|s| s.id.starts_with(pipeline_id))
642 .max_by_key(|s| s.timestamp)
643 .map(|s| s.id.clone())
644 }
645
646 fn notify_listeners(&self, snapshot: &StateSnapshot) {
648 let listeners = self.listeners.read().unwrap();
649 for listener in listeners.iter() {
650 listener(snapshot);
651 }
652 }
653}
654
/// Synchronizes snapshots between a local `StateManager` and a set of
/// remote managers, with configurable conflict resolution.
pub struct StateSynchronizer {
    /// Manager that receives snapshots pulled from remotes.
    local_state: Arc<StateManager>,
    /// Remote managers to exchange snapshots with.
    remote_states: Vec<Arc<StateManager>>,
    /// Sync behavior (direction, retries, batching).
    config: SyncConfig,
    /// Policy applied when the same snapshot differs on both sides.
    conflict_resolution: ConflictResolution,
}
666
/// Tuning knobs for snapshot synchronization.
#[derive(Debug, Clone)]
pub struct SyncConfig {
    /// Intended delay between sync passes.
    /// NOTE(review): not consulted by `synchronize` itself — presumably
    /// used by an external scheduler; confirm against callers.
    pub sync_interval: Duration,
    /// Also push local-only snapshots to remotes (not just pull).
    pub bidirectional: bool,
    /// Enable conflict detection during sync.
    /// NOTE(review): not currently consulted; detection is a stub.
    pub conflict_detection: bool,
    /// Batch snapshot transfers. NOTE(review): not currently consulted.
    pub batch_sync: bool,
    /// Maximum retries per failed transfer.
    /// NOTE(review): not currently consulted.
    pub max_retries: usize,
}
681
682impl Default for SyncConfig {
683 fn default() -> Self {
684 Self {
685 sync_interval: Duration::from_secs(30),
686 bidirectional: true,
687 conflict_detection: true,
688 batch_sync: false,
689 max_retries: 3,
690 }
691 }
692}
693
/// Policy for resolving two conflicting versions of a snapshot.
#[derive(Debug, Clone)]
pub enum ConflictResolution {
    /// The snapshot with the later `timestamp` wins.
    LatestWins,
    /// The snapshot with the larger `version` wins.
    HighestVersionWins,
    /// Refuse to auto-resolve; `resolve_conflict` returns an error.
    Manual,
    /// Delegate resolution to a user-supplied function.
    Custom(fn(&StateSnapshot, &StateSnapshot) -> StateSnapshot),
}
706
707impl StateSynchronizer {
708 #[must_use]
710 pub fn new(
711 local_state: Arc<StateManager>,
712 config: SyncConfig,
713 conflict_resolution: ConflictResolution,
714 ) -> Self {
715 Self {
716 local_state,
717 remote_states: Vec::new(),
718 config,
719 conflict_resolution,
720 }
721 }
722
723 pub fn add_remote(&mut self, remote_state: Arc<StateManager>) {
725 self.remote_states.push(remote_state);
726 }
727
728 pub fn synchronize(&self) -> SklResult<SyncResult> {
730 let mut result = SyncResult {
731 synced_snapshots: 0,
732 conflicts_resolved: 0,
733 errors: Vec::new(),
734 };
735
736 for remote in &self.remote_states {
737 match self.sync_with_remote(remote) {
738 Ok(sync_stats) => {
739 result.synced_snapshots += sync_stats.synced_snapshots;
740 result.conflicts_resolved += sync_stats.conflicts_resolved;
741 }
742 Err(e) => {
743 result.errors.push(format!("Sync error: {e:?}"));
744 }
745 }
746 }
747
748 Ok(result)
749 }
750
751 fn sync_with_remote(&self, remote: &Arc<StateManager>) -> SklResult<SyncResult> {
753 let mut result = SyncResult {
754 synced_snapshots: 0,
755 conflicts_resolved: 0,
756 errors: Vec::new(),
757 };
758
759 let local_snapshots = self.local_state.list_snapshots();
761 let remote_snapshots = remote.list_snapshots();
762
763 for remote_id in &remote_snapshots {
765 if !local_snapshots.contains(remote_id) {
766 match remote.load_snapshot(remote_id) {
768 Ok(remote_snapshot) => {
769 if let Some(local_snapshot) =
771 self.find_conflicting_snapshot(&remote_snapshot)
772 {
773 let resolved =
774 self.resolve_conflict(&local_snapshot, &remote_snapshot)?;
775 self.local_state.save_snapshot(resolved)?;
776 result.conflicts_resolved += 1;
777 } else {
778 self.local_state.save_snapshot(remote_snapshot)?;
779 result.synced_snapshots += 1;
780 }
781 }
782 Err(e) => {
783 result
784 .errors
785 .push(format!("Failed to load remote snapshot {remote_id}: {e:?}"));
786 }
787 }
788 }
789 }
790
791 if self.config.bidirectional {
793 for local_id in &local_snapshots {
794 if !remote_snapshots.contains(local_id) {
795 match self.local_state.load_snapshot(local_id) {
796 Ok(local_snapshot) => {
797 remote.save_snapshot(local_snapshot)?;
798 result.synced_snapshots += 1;
799 }
800 Err(e) => {
801 result
802 .errors
803 .push(format!("Failed to sync local snapshot {local_id}: {e:?}"));
804 }
805 }
806 }
807 }
808 }
809
810 Ok(result)
811 }
812
813 fn find_conflicting_snapshot(&self, remote_snapshot: &StateSnapshot) -> Option<StateSnapshot> {
815 None
818 }
819
820 fn resolve_conflict(
822 &self,
823 local: &StateSnapshot,
824 remote: &StateSnapshot,
825 ) -> SklResult<StateSnapshot> {
826 match &self.conflict_resolution {
827 ConflictResolution::LatestWins => {
828 if remote.timestamp > local.timestamp {
829 Ok(remote.clone())
830 } else {
831 Ok(local.clone())
832 }
833 }
834 ConflictResolution::HighestVersionWins => {
835 if remote.version > local.version {
836 Ok(remote.clone())
837 } else {
838 Ok(local.clone())
839 }
840 }
841 ConflictResolution::Manual => Err(SklearsError::InvalidData {
842 reason: "Manual conflict resolution required".to_string(),
843 }),
844 ConflictResolution::Custom(resolve_fn) => Ok(resolve_fn(local, remote)),
845 }
846 }
847}
848
/// Outcome summary of a synchronization pass.
#[derive(Debug, Clone)]
pub struct SyncResult {
    /// Snapshots copied between stores without conflict.
    pub synced_snapshots: usize,
    /// Conflicts resolved via the configured policy.
    pub conflicts_resolved: usize,
    /// Human-readable descriptions of non-fatal per-snapshot failures.
    pub errors: Vec<String>,
}
859
/// Git-like version control layered over a `StateManager`: named
/// branches whose heads point at snapshot ids, plus tags.
pub struct PipelineVersionControl {
    /// Underlying snapshot store.
    state_manager: Arc<StateManager>,
    /// All branches by name; "main" is created at construction.
    branches: Arc<RwLock<HashMap<String, Branch>>>,
    /// Name of the branch that `commit` currently targets.
    current_branch: Arc<RwLock<String>>,
    /// Tag name -> snapshot id.
    tags: Arc<RwLock<HashMap<String, String>>>, }
871
/// A named line of development pointing at a snapshot.
#[derive(Debug, Clone)]
pub struct Branch {
    /// Branch name (unique within the version-control instance).
    pub name: String,
    /// Snapshot id the branch currently points at; `None` before the
    /// first commit.
    pub head: Option<String>,
    /// When the branch was created.
    pub created_at: SystemTime,
    /// Free-form annotations (e.g. last commit message/time).
    pub metadata: HashMap<String, String>,
}
884
885impl PipelineVersionControl {
886 #[must_use]
888 pub fn new(state_manager: Arc<StateManager>) -> Self {
889 let mut branches = HashMap::new();
890 branches.insert(
891 "main".to_string(),
892 Branch {
893 name: "main".to_string(),
894 head: None,
895 created_at: SystemTime::now(),
896 metadata: HashMap::new(),
897 },
898 );
899
900 Self {
901 state_manager,
902 branches: Arc::new(RwLock::new(branches)),
903 current_branch: Arc::new(RwLock::new("main".to_string())),
904 tags: Arc::new(RwLock::new(HashMap::new())),
905 }
906 }
907
908 pub fn create_branch(&self, branch_name: &str, from_snapshot: Option<&str>) -> SklResult<()> {
910 let mut branches = self.branches.write().unwrap();
911
912 if branches.contains_key(branch_name) {
913 return Err(SklearsError::InvalidInput(format!(
914 "Branch {branch_name} already exists"
915 )));
916 }
917
918 let branch = Branch {
919 name: branch_name.to_string(),
920 head: from_snapshot.map(std::string::ToString::to_string),
921 created_at: SystemTime::now(),
922 metadata: HashMap::new(),
923 };
924
925 branches.insert(branch_name.to_string(), branch);
926 Ok(())
927 }
928
929 pub fn checkout_branch(&self, branch_name: &str) -> SklResult<()> {
931 let branches = self.branches.read().unwrap();
932
933 if !branches.contains_key(branch_name) {
934 return Err(SklearsError::InvalidInput(format!(
935 "Branch {branch_name} does not exist"
936 )));
937 }
938
939 let mut current = self.current_branch.write().unwrap();
940 *current = branch_name.to_string();
941 Ok(())
942 }
943
944 pub fn commit(&self, snapshot_id: &str, message: &str) -> SklResult<()> {
946 let current_branch_name = {
947 let current = self.current_branch.read().unwrap();
948 current.clone()
949 };
950
951 let mut branches = self.branches.write().unwrap();
952 if let Some(branch) = branches.get_mut(¤t_branch_name) {
953 branch.head = Some(snapshot_id.to_string());
954 branch
955 .metadata
956 .insert("last_commit_message".to_string(), message.to_string());
957 branch.metadata.insert(
958 "last_commit_time".to_string(),
959 SystemTime::now()
960 .duration_since(UNIX_EPOCH)
961 .unwrap()
962 .as_secs()
963 .to_string(),
964 );
965 }
966
967 Ok(())
968 }
969
970 pub fn create_tag(&self, tag_name: &str, snapshot_id: &str) -> SklResult<()> {
972 let mut tags = self.tags.write().unwrap();
973 tags.insert(tag_name.to_string(), snapshot_id.to_string());
974 Ok(())
975 }
976
977 #[must_use]
979 pub fn get_tag(&self, tag_name: &str) -> Option<String> {
980 let tags = self.tags.read().unwrap();
981 tags.get(tag_name).cloned()
982 }
983
984 #[must_use]
986 pub fn list_branches(&self) -> Vec<String> {
987 let branches = self.branches.read().unwrap();
988 branches.keys().cloned().collect()
989 }
990
991 #[must_use]
993 pub fn list_tags(&self) -> HashMap<String, String> {
994 let tags = self.tags.read().unwrap();
995 tags.clone()
996 }
997
998 #[must_use]
1000 pub fn current_branch(&self) -> String {
1001 let current = self.current_branch.read().unwrap();
1002 current.clone()
1003 }
1004}
1005
#[cfg(test)]
mod tests {
    // Removed `#[allow(non_snake_case)]` (no non-snake-case names here)
    // and the unused `use std::env;` import.
    use super::*;

    /// A hand-built snapshot round-trips its identifying fields.
    #[test]
    fn test_state_snapshot_creation() {
        let snapshot = StateSnapshot {
            id: "test_snapshot".to_string(),
            timestamp: SystemTime::now(),
            state_data: StateData {
                config: HashMap::new(),
                model_parameters: HashMap::new(),
                feature_names: None,
                steps_state: Vec::new(),
                execution_stats: ExecutionStatistics::default(),
                custom_data: HashMap::new(),
            },
            metadata: HashMap::new(),
            version: 1,
            parent_id: None,
            checksum: "test_checksum".to_string(),
        };

        assert_eq!(snapshot.id, "test_snapshot");
        assert_eq!(snapshot.version, 1);
    }

    /// In-memory checkpoint create/resume round-trip.
    #[test]
    fn test_state_manager_memory() {
        let strategy = PersistenceStrategy::InMemory;
        let config = CheckpointConfig::default();
        let manager = StateManager::new(strategy, config);

        let state_data = StateData {
            config: HashMap::new(),
            model_parameters: HashMap::new(),
            feature_names: None,
            steps_state: Vec::new(),
            execution_stats: ExecutionStatistics::default(),
            custom_data: HashMap::new(),
        };

        let checkpoint_id = manager
            .create_checkpoint("test_pipeline", state_data)
            .unwrap();
        assert!(checkpoint_id.starts_with("test_pipeline"));

        let loaded_state = manager.resume_from_checkpoint(&checkpoint_id).unwrap();
        assert_eq!(loaded_state.config.len(), 0);
    }

    /// Branch creation/checkout and tag lookup behave like git-lite.
    #[test]
    fn test_version_control() {
        let strategy = PersistenceStrategy::InMemory;
        let config = CheckpointConfig::default();
        let state_manager = Arc::new(StateManager::new(strategy, config));
        let vc = PipelineVersionControl::new(state_manager);

        assert_eq!(vc.current_branch(), "main");

        vc.create_branch("feature", None).unwrap();
        vc.checkout_branch("feature").unwrap();
        assert_eq!(vc.current_branch(), "feature");

        vc.create_tag("v1.0", "snapshot_123").unwrap();
        assert_eq!(vc.get_tag("v1.0"), Some("snapshot_123".to_string()));
    }

    /// Explicitly constructed config keeps its field values.
    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig {
            auto_checkpoint_interval: Some(Duration::from_secs(60)),
            max_checkpoints: 5,
            checkpoint_on_update: true,
            checkpoint_on_error: false,
            compression_level: 9,
            incremental: true,
        };

        assert_eq!(config.max_checkpoints, 5);
        assert_eq!(config.compression_level, 9);
        assert!(config.incremental);
    }

    /// Struct-update syntax instead of default-then-reassign
    /// (fixes clippy::field_reassign_with_default).
    #[test]
    fn test_execution_statistics() {
        let stats = ExecutionStatistics {
            training_samples: 1000,
            prediction_requests: 50,
            accuracy: Some(0.95),
            ..ExecutionStatistics::default()
        };

        assert_eq!(stats.training_samples, 1000);
        assert_eq!(stats.prediction_requests, 50);
        assert_eq!(stats.accuracy, Some(0.95));
    }
}