1use std::collections::HashMap;
43use std::fs;
44use std::path::{Path, PathBuf};
45use std::time::{SystemTime, UNIX_EPOCH};
46
47use serde::{Deserialize, Serialize};
48
49use super::{EpisodeTransitions, NgramStats, SelectionPerformance};
50use crate::online_stats::ActionStats;
51
/// Version tag stamped into every [`LearningSnapshot`] built by this module.
pub const SNAPSHOT_VERSION: u32 = 1_u32;

/// Point in time used by session ids and time-range queries.
///
/// Session ids generated by [`LearningStore`] encode whole seconds since the
/// Unix epoch, so values are normally interpreted in that unit.
pub type Timestamp = u64;
57
58#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct SessionId(pub String);
61
62impl SessionId {
63 pub fn timestamp(&self) -> Option<Timestamp> {
65 self.0.parse().ok()
66 }
67}
68
69#[derive(Debug, Clone, PartialEq, Eq, Hash)]
75pub enum SnapshotKey {
76 Global,
78 Scenario(String),
80 Session {
82 scenario: String,
83 session_id: SessionId,
84 },
85}
86
87impl std::fmt::Display for SnapshotKey {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 Self::Global => write!(f, "global"),
91 Self::Scenario(s) => write!(f, "scenario:{}", s),
92 Self::Session {
93 scenario,
94 session_id,
95 } => {
96 write!(f, "session:{}:{}", scenario, session_id.0)
97 }
98 }
99 }
100}
101
102pub trait SnapshotStorage {
104 type Error: std::error::Error + Send + Sync + 'static;
106
107 fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error>;
109
110 fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error>;
112
113 fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
115
116 fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
118}
119
120pub trait TimeSeriesQuery {
122 type Error: std::error::Error + Send + Sync + 'static;
124
125 fn query_range(
127 &self,
128 scenario: &str,
129 from: Timestamp,
130 to: Timestamp,
131 ) -> Result<Vec<LearningSnapshot>, Self::Error>;
132
133 fn query_latest(
135 &self,
136 scenario: &str,
137 limit: usize,
138 ) -> Result<Vec<LearningSnapshot>, Self::Error>;
139
140 fn query_since(
142 &self,
143 scenario: &str,
144 since: Timestamp,
145 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
146 self.query_range(scenario, since, u64::MAX)
147 }
148
149 fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error>;
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct LearningSnapshot {
160 pub version: u32,
162 pub metadata: SnapshotMetadata,
164 pub episode_transitions: EpisodeTransitions,
166 pub ngram_stats: NgramStats,
168 pub selection_performance: SelectionPerformance,
170 pub contextual_stats: HashMap<(String, String), ActionStats>,
172 pub action_stats: HashMap<String, ActionStats>,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct SnapshotMetadata {
179 pub scenario_name: Option<String>,
181 pub task_description: Option<String>,
183 pub created_at: u64,
185 pub session_count: u32,
187 pub total_episodes: u32,
189 pub total_actions: u32,
191}
192
193impl Default for SnapshotMetadata {
194 fn default() -> Self {
195 Self {
196 scenario_name: None,
197 task_description: None,
198 created_at: SystemTime::now()
199 .duration_since(UNIX_EPOCH)
200 .map(|d| d.as_secs())
201 .unwrap_or(0),
202 session_count: 1,
203 total_episodes: 0,
204 total_actions: 0,
205 }
206 }
207}
208
209impl SnapshotMetadata {
210 pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
212 self.scenario_name = Some(name.into());
213 self
214 }
215
216 pub fn with_task(mut self, desc: impl Into<String>) -> Self {
218 self.task_description = Some(desc.into());
219 self
220 }
221}
222
223impl LearningSnapshot {
224 pub fn empty() -> Self {
226 Self {
227 version: SNAPSHOT_VERSION,
228 metadata: SnapshotMetadata::default(),
229 episode_transitions: EpisodeTransitions::default(),
230 ngram_stats: NgramStats::default(),
231 selection_performance: SelectionPerformance::default(),
232 contextual_stats: HashMap::new(),
233 action_stats: HashMap::new(),
234 }
235 }
236
237 pub fn with_metadata(mut self, metadata: SnapshotMetadata) -> Self {
239 self.metadata = metadata;
240 self
241 }
242}
243
/// How [`merge_snapshots`] weights the counts of each input snapshot.
#[derive(Clone, Copy, Debug, Default)]
pub enum MergeStrategy {
    /// Every snapshot has weight 1, so counts are simply summed (the default).
    #[default]
    Additive,
    /// Position-based exponential decay. The last (newest) snapshot has weight 1,
    /// and a snapshot `age` positions older has weight `0.5^(age / half_life_sessions)`.
    TimeDecay {
        /// Number of sessions after which a snapshot's weight halves.
        half_life_sessions: u32,
    },
    /// Weight `1 + success_episodes / total_episodes`, or 1 when `total_episodes` is 0.
    SuccessWeighted,
}
262
263pub struct FileSystemStorage {
269 base_dir: PathBuf,
270}
271
272impl FileSystemStorage {
273 pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
275 let base_dir = base_dir.as_ref().to_path_buf();
276 fs::create_dir_all(&base_dir)?;
277 Ok(Self { base_dir })
278 }
279
280 pub fn base_dir(&self) -> &Path {
282 &self.base_dir
283 }
284
285 fn key_to_path(&self, key: &SnapshotKey) -> PathBuf {
286 match key {
287 SnapshotKey::Global => self.base_dir.join("global_stats.json"),
288 SnapshotKey::Scenario(scenario) => self
289 .base_dir
290 .join("scenarios")
291 .join(scenario)
292 .join("stats.json"),
293 SnapshotKey::Session {
294 scenario,
295 session_id,
296 } => self
297 .base_dir
298 .join("scenarios")
299 .join(scenario)
300 .join("sessions")
301 .join(&session_id.0)
302 .join("stats.json"),
303 }
304 }
305
306 fn sessions_dir(&self, scenario: &str) -> PathBuf {
307 self.base_dir
308 .join("scenarios")
309 .join(scenario)
310 .join("sessions")
311 }
312}
313
314impl SnapshotStorage for FileSystemStorage {
315 type Error = std::io::Error;
316
317 fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error> {
318 let path = self.key_to_path(key);
319 if let Some(parent) = path.parent() {
320 fs::create_dir_all(parent)?;
321 }
322 snapshot.save_json(&path)
323 }
324
325 fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error> {
326 let path = self.key_to_path(key);
327 if !path.exists() {
328 return Ok(None);
329 }
330 LearningSnapshot::load_json(&path).map(Some)
331 }
332
333 fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
334 let path = self.key_to_path(key);
335 if path.exists() {
336 fs::remove_file(&path)?;
337 Ok(true)
338 } else {
339 Ok(false)
340 }
341 }
342
343 fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
344 Ok(self.key_to_path(key).exists())
345 }
346}
347
348impl TimeSeriesQuery for FileSystemStorage {
349 type Error = std::io::Error;
350
351 fn query_range(
352 &self,
353 scenario: &str,
354 from: Timestamp,
355 to: Timestamp,
356 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
357 let sessions = self.list_sessions(scenario)?;
358 let mut results = Vec::new();
359
360 for session_id in sessions {
361 if let Some(ts) = session_id.timestamp() {
362 if ts >= from && ts <= to {
363 let key = SnapshotKey::Session {
364 scenario: scenario.to_string(),
365 session_id,
366 };
367 if let Some(snapshot) = self.load(&key)? {
368 results.push(snapshot);
369 }
370 }
371 }
372 }
373 Ok(results)
374 }
375
376 fn query_latest(
377 &self,
378 scenario: &str,
379 limit: usize,
380 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
381 let mut sessions = self.list_sessions(scenario)?;
382 sessions.sort_by(|a, b| b.0.cmp(&a.0));
384 sessions.truncate(limit);
385
386 let mut results = Vec::new();
387 for session_id in sessions {
388 let key = SnapshotKey::Session {
389 scenario: scenario.to_string(),
390 session_id,
391 };
392 if let Some(snapshot) = self.load(&key)? {
393 results.push(snapshot);
394 }
395 }
396 Ok(results)
397 }
398
399 fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error> {
400 let sessions_dir = self.sessions_dir(scenario);
401 if !sessions_dir.exists() {
402 return Ok(Vec::new());
403 }
404
405 let mut sessions = Vec::new();
406 for entry in fs::read_dir(sessions_dir)? {
407 let entry = entry?;
408 if entry.file_type()?.is_dir() {
409 if let Some(name) = entry.file_name().to_str() {
410 sessions.push(SessionId(name.to_string()));
411 }
412 }
413 }
414 sessions.sort_by(|a, b| a.0.cmp(&b.0));
415 Ok(sessions)
416 }
417}
418
419pub struct LearningStore {
427 storage: FileSystemStorage,
428}
429
430impl LearningStore {
431 pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
433 let storage = FileSystemStorage::new(base_dir)?;
434 Ok(Self { storage })
435 }
436
437 pub fn default_path() -> PathBuf {
439 dirs::data_dir()
440 .unwrap_or_else(|| PathBuf::from("."))
441 .join("swarm-engine")
442 .join("learning")
443 }
444
445 pub fn storage(&self) -> &FileSystemStorage {
447 &self.storage
448 }
449
450 pub fn load_global(&self) -> std::io::Result<LearningSnapshot> {
456 self.storage.load(&SnapshotKey::Global)?.ok_or_else(|| {
457 std::io::Error::new(std::io::ErrorKind::NotFound, "global stats not found")
458 })
459 }
460
461 pub fn save_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
463 self.storage.save(&SnapshotKey::Global, snapshot)
464 }
465
466 pub fn load_scenario(&self, scenario: &str) -> std::io::Result<LearningSnapshot> {
468 self.storage
469 .load(&SnapshotKey::Scenario(scenario.to_string()))?
470 .ok_or_else(|| {
471 std::io::Error::new(std::io::ErrorKind::NotFound, "scenario stats not found")
472 })
473 }
474
475 pub fn save_scenario(
477 &self,
478 scenario: &str,
479 snapshot: &LearningSnapshot,
480 ) -> std::io::Result<()> {
481 self.storage
482 .save(&SnapshotKey::Scenario(scenario.to_string()), snapshot)
483 }
484
485 pub fn save_session(
487 &self,
488 scenario: &str,
489 snapshot: &LearningSnapshot,
490 ) -> std::io::Result<SessionId> {
491 let session_id = self.generate_session_id();
493
494 let key = SnapshotKey::Session {
496 scenario: scenario.to_string(),
497 session_id: session_id.clone(),
498 };
499 self.storage.save(&key, snapshot)?;
500
501 let meta_path = self.storage.key_to_path(&key).with_file_name("meta.json");
503 let meta_json = serde_json::to_string_pretty(&snapshot.metadata)?;
504 fs::write(meta_path, meta_json)?;
505
506 self.merge_into_scenario(scenario, snapshot)?;
508
509 self.merge_into_global(snapshot)?;
511
512 Ok(session_id)
513 }
514
515 pub fn list_sessions(&self, scenario: &str) -> std::io::Result<Vec<SessionId>> {
517 self.storage.list_sessions(scenario)
518 }
519
520 pub fn load_session(
522 &self,
523 scenario: &str,
524 session_id: &SessionId,
525 ) -> std::io::Result<LearningSnapshot> {
526 let key = SnapshotKey::Session {
527 scenario: scenario.to_string(),
528 session_id: session_id.clone(),
529 };
530 self.storage
531 .load(&key)?
532 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "session not found"))
533 }
534
535 pub fn query_range(
541 &self,
542 scenario: &str,
543 from: Timestamp,
544 to: Timestamp,
545 ) -> std::io::Result<Vec<LearningSnapshot>> {
546 self.storage.query_range(scenario, from, to)
547 }
548
549 pub fn query_latest(
551 &self,
552 scenario: &str,
553 limit: usize,
554 ) -> std::io::Result<Vec<LearningSnapshot>> {
555 self.storage.query_latest(scenario, limit)
556 }
557
558 pub fn merge(
564 &self,
565 snapshots: &[LearningSnapshot],
566 strategy: MergeStrategy,
567 ) -> LearningSnapshot {
568 merge_snapshots(snapshots, strategy)
569 }
570
571 fn generate_session_id(&self) -> SessionId {
572 let timestamp = SystemTime::now()
573 .duration_since(UNIX_EPOCH)
574 .map(|d| d.as_secs())
575 .unwrap_or(0);
576 SessionId(format!("{:010}", timestamp))
577 }
578
579 fn merge_into_scenario(
580 &self,
581 scenario: &str,
582 snapshot: &LearningSnapshot,
583 ) -> std::io::Result<()> {
584 let existing = self
585 .storage
586 .load(&SnapshotKey::Scenario(scenario.to_string()))?;
587 let merged = match existing {
588 Some(existing) => {
589 merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
590 }
591 None => snapshot.clone(),
592 };
593 self.save_scenario(scenario, &merged)
594 }
595
596 fn merge_into_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
597 let existing = self.storage.load(&SnapshotKey::Global)?;
598 let merged = match existing {
599 Some(existing) => {
600 merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
601 }
602 None => snapshot.clone(),
603 };
604 self.save_global(&merged)
605 }
606
607 pub fn load_offline_model(
613 &self,
614 scenario: &str,
615 ) -> std::io::Result<super::offline::OfflineModel> {
616 let path = self.offline_model_path(scenario);
617 if !path.exists() {
618 return Err(std::io::Error::new(
619 std::io::ErrorKind::NotFound,
620 "offline model not found",
621 ));
622 }
623 let json = fs::read_to_string(&path)?;
624 serde_json::from_str(&json)
625 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
626 }
627
628 pub fn save_offline_model(
630 &self,
631 scenario: &str,
632 model: &super::offline::OfflineModel,
633 ) -> std::io::Result<()> {
634 let path = self.offline_model_path(scenario);
635 if let Some(parent) = path.parent() {
636 fs::create_dir_all(parent)?;
637 }
638 let json = serde_json::to_string_pretty(model)?;
639 fs::write(path, json)
640 }
641
642 pub fn run_offline_learning(
644 &self,
645 scenario: &str,
646 session_limit: usize,
647 ) -> std::io::Result<super::offline::OfflineModel> {
648 let snapshots = self.query_latest(scenario, session_limit)?;
649 if snapshots.is_empty() {
650 return Err(std::io::Error::new(
651 std::io::ErrorKind::NotFound,
652 "no sessions found for offline learning",
653 ));
654 }
655
656 let analyzer = super::offline::OfflineAnalyzer::new(&snapshots);
657 let model = analyzer.analyze();
658
659 self.save_offline_model(scenario, &model)?;
660
661 Ok(model)
662 }
663
664 fn offline_model_path(&self, scenario: &str) -> PathBuf {
665 self.storage
666 .base_dir()
667 .join("scenarios")
668 .join(scenario)
669 .join("offline_model.json")
670 }
671}
672
673pub fn merge_snapshots(
679 snapshots: &[LearningSnapshot],
680 strategy: MergeStrategy,
681) -> LearningSnapshot {
682 if snapshots.is_empty() {
683 return LearningSnapshot::empty();
684 }
685 if snapshots.len() == 1 {
686 return snapshots[0].clone();
687 }
688
689 let mut result = LearningSnapshot::empty();
690
691 let weights: Vec<f64> = match strategy {
693 MergeStrategy::Additive => vec![1.0; snapshots.len()],
694 MergeStrategy::TimeDecay { half_life_sessions } => {
695 let half_life = half_life_sessions as f64;
696 snapshots
697 .iter()
698 .enumerate()
699 .map(|(i, _)| {
700 let age = (snapshots.len() - 1 - i) as f64;
701 0.5_f64.powf(age / half_life)
702 })
703 .collect()
704 }
705 MergeStrategy::SuccessWeighted => snapshots
706 .iter()
707 .map(|s| {
708 let total = s.metadata.total_episodes as f64;
709 let success = s.episode_transitions.success_episodes as f64;
710 if total == 0.0 {
711 1.0
712 } else {
713 1.0 + success / total
714 }
715 })
716 .collect(),
717 };
718
719 result.metadata = SnapshotMetadata {
721 scenario_name: snapshots
722 .last()
723 .and_then(|s| s.metadata.scenario_name.clone()),
724 task_description: snapshots
725 .last()
726 .and_then(|s| s.metadata.task_description.clone()),
727 created_at: SystemTime::now()
728 .duration_since(UNIX_EPOCH)
729 .map(|d| d.as_secs())
730 .unwrap_or(0),
731 session_count: snapshots.iter().map(|s| s.metadata.session_count).sum(),
732 total_episodes: snapshots.iter().map(|s| s.metadata.total_episodes).sum(),
733 total_actions: snapshots.iter().map(|s| s.metadata.total_actions).sum(),
734 };
735
736 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
738 for (key, &count) in &snapshot.episode_transitions.success_transitions {
739 let weighted_count = (count as f64 * weight).round() as u32;
740 *result
741 .episode_transitions
742 .success_transitions
743 .entry(key.clone())
744 .or_default() += weighted_count;
745 }
746 for (key, &count) in &snapshot.episode_transitions.failure_transitions {
747 let weighted_count = (count as f64 * weight).round() as u32;
748 *result
749 .episode_transitions
750 .failure_transitions
751 .entry(key.clone())
752 .or_default() += weighted_count;
753 }
754 result.episode_transitions.success_episodes +=
755 (snapshot.episode_transitions.success_episodes as f64 * weight).round() as u32;
756 result.episode_transitions.failure_episodes +=
757 (snapshot.episode_transitions.failure_episodes as f64 * weight).round() as u32;
758 }
759
760 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
762 for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
763 let entry = result
764 .ngram_stats
765 .trigrams
766 .entry(key.clone())
767 .or_insert((0, 0));
768 entry.0 += (success as f64 * weight).round() as u32;
769 entry.1 += (failure as f64 * weight).round() as u32;
770 }
771 for (key, &(success, failure)) in &snapshot.ngram_stats.quadgrams {
772 let entry = result
773 .ngram_stats
774 .quadgrams
775 .entry(key.clone())
776 .or_insert((0, 0));
777 entry.0 += (success as f64 * weight).round() as u32;
778 entry.1 += (failure as f64 * weight).round() as u32;
779 }
780 }
781
782 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
784 for (key, stats) in &snapshot.contextual_stats {
785 let entry = result.contextual_stats.entry(key.clone()).or_default();
786 entry.visits += (stats.visits as f64 * weight).round() as u32;
787 entry.successes += (stats.successes as f64 * weight).round() as u32;
788 entry.failures += (stats.failures as f64 * weight).round() as u32;
789 }
790 }
791
792 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
794 for (key, stats) in &snapshot.action_stats {
795 let entry = result.action_stats.entry(key.clone()).or_default();
796 entry.visits += (stats.visits as f64 * weight).round() as u32;
797 entry.successes += (stats.successes as f64 * weight).round() as u32;
798 entry.failures += (stats.failures as f64 * weight).round() as u32;
799 }
800 }
801
802 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
804 for (strat, stats) in &snapshot.selection_performance.strategy_stats {
805 let entry = result
806 .selection_performance
807 .strategy_stats
808 .entry(strat.clone())
809 .or_default();
810 entry.visits += (stats.visits as f64 * weight).round() as u32;
811 entry.successes += (stats.successes as f64 * weight).round() as u32;
812 entry.failures += (stats.failures as f64 * weight).round() as u32;
813 entry.episodes_success += (stats.episodes_success as f64 * weight).round() as u32;
814 entry.episodes_failure += (stats.episodes_failure as f64 * weight).round() as u32;
815 }
816 }
817
818 result
819}
820
821impl LearningSnapshot {
826 pub fn to_json(&self) -> serde_json::Result<String> {
828 serde_json::to_string_pretty(self)
829 }
830
831 pub fn from_json(json: &str) -> serde_json::Result<Self> {
833 serde_json::from_str(json)
834 }
835
836 pub fn save_json(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
838 let json = self.to_json()?;
839 fs::write(path, json)
840 }
841
842 pub fn load_json(path: impl AsRef<Path>) -> std::io::Result<Self> {
844 let json = fs::read_to_string(path)?;
845 Self::from_json(&json).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// The `("A", "B")` transition key shared by the merge tests.
    fn ab_key() -> (String, String) {
        ("A".to_string(), "B".to_string())
    }

    /// An otherwise-empty snapshot recording `count` successful `A -> B` transitions.
    fn snapshot_with_ab(count: u32) -> LearningSnapshot {
        let mut snapshot = LearningSnapshot::empty();
        snapshot
            .episode_transitions
            .success_transitions
            .insert(ab_key(), count);
        snapshot
    }

    /// A store in a fresh temporary directory.
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the store is used.
    fn temp_store() -> (tempfile::TempDir, LearningStore) {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        (dir, store)
    }

    #[test]
    fn test_snapshot_serialization() {
        let original = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("test"));

        let round_tripped = LearningSnapshot::from_json(&original.to_json().unwrap()).unwrap();

        assert_eq!(round_tripped.version, SNAPSHOT_VERSION);
        assert_eq!(round_tripped.metadata.scenario_name.as_deref(), Some("test"));
    }

    #[test]
    fn test_learning_store_save_load() {
        let (_dir, store) = temp_store();
        let scenario = "troubleshooting";
        let snapshot = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario(scenario));

        store.save_scenario(scenario, &snapshot).unwrap();
        let loaded = store.load_scenario(scenario).unwrap();

        assert_eq!(loaded.metadata.scenario_name.as_deref(), Some(scenario));
    }

    #[test]
    fn test_merge_additive() {
        let (_dir, store) = temp_store();

        let mut older = snapshot_with_ab(5);
        older.metadata.total_episodes = 10;
        let mut newer = snapshot_with_ab(3);
        newer.metadata.total_episodes = 5;

        let merged = store.merge(&[older, newer], MergeStrategy::Additive);

        assert_eq!(
            merged.episode_transitions.success_transitions.get(&ab_key()),
            Some(&8)
        );
        assert_eq!(merged.metadata.total_episodes, 15);
    }

    #[test]
    fn test_merge_time_decay() {
        let (_dir, store) = temp_store();

        // With a half-life of one session the older snapshot counts half:
        // 100 * 0.5 + 100 * 1.0 = 150.
        let merged = store.merge(
            &[snapshot_with_ab(100), snapshot_with_ab(100)],
            MergeStrategy::TimeDecay {
                half_life_sessions: 1,
            },
        );

        assert_eq!(
            merged.episode_transitions.success_transitions[&ab_key()],
            150
        );
    }

    #[test]
    fn test_session_management() {
        let (_dir, store) = temp_store();
        let snapshot = LearningSnapshot {
            version: SNAPSHOT_VERSION,
            metadata: SnapshotMetadata::default().with_scenario("test"),
            ..LearningSnapshot::empty()
        };

        let session_id = store.save_session("test", &snapshot).unwrap();
        assert!(!session_id.0.is_empty());

        let sessions = store.list_sessions("test").unwrap();
        assert_eq!(sessions, vec![session_id]);
    }
}