1use std::collections::HashMap;
43use std::fs;
44use std::path::{Path, PathBuf};
45use std::time::{SystemTime, UNIX_EPOCH};
46
47use serde::{Deserialize, Serialize};
48
49use super::session_group::{LearningPhase, SessionGroupId};
50use super::{EpisodeTransitions, NgramStats, SelectionPerformance};
51use crate::online_stats::ActionStats;
52
/// Version number stamped into the `version` field of snapshots built by
/// [`LearningSnapshot::empty`].
pub const SNAPSHOT_VERSION: u32 = 1;

/// Numeric timestamp used by the time-range queries in this module.
///
/// Session identifiers generated by `LearningStore` encode seconds since the
/// Unix epoch, and `SessionId::timestamp` parses them back into this type.
pub type Timestamp = u64;
58
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
61pub struct SessionId(pub String);
62
63impl SessionId {
64 pub fn timestamp(&self) -> Option<Timestamp> {
66 self.0.parse().ok()
67 }
68}
69
70#[derive(Debug, Clone, PartialEq, Eq, Hash)]
76pub enum SnapshotKey {
77 Global,
79 Scenario(String),
81 Session {
83 scenario: String,
84 session_id: SessionId,
85 },
86}
87
88impl std::fmt::Display for SnapshotKey {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 match self {
91 Self::Global => write!(f, "global"),
92 Self::Scenario(s) => write!(f, "scenario:{}", s),
93 Self::Session {
94 scenario,
95 session_id,
96 } => {
97 write!(f, "session:{}:{}", scenario, session_id.0)
98 }
99 }
100 }
101}
102
103pub trait SnapshotStorage {
105 type Error: std::error::Error + Send + Sync + 'static;
107
108 fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error>;
110
111 fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error>;
113
114 fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
116
117 fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
119}
120
121pub trait TimeSeriesQuery {
123 type Error: std::error::Error + Send + Sync + 'static;
125
126 fn query_range(
128 &self,
129 scenario: &str,
130 from: Timestamp,
131 to: Timestamp,
132 ) -> Result<Vec<LearningSnapshot>, Self::Error>;
133
134 fn query_latest(
136 &self,
137 scenario: &str,
138 limit: usize,
139 ) -> Result<Vec<LearningSnapshot>, Self::Error>;
140
141 fn query_since(
143 &self,
144 scenario: &str,
145 since: Timestamp,
146 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
147 self.query_range(scenario, since, u64::MAX)
148 }
149
150 fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error>;
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct LearningSnapshot {
161 pub version: u32,
163 pub metadata: SnapshotMetadata,
165 pub episode_transitions: EpisodeTransitions,
167 pub ngram_stats: NgramStats,
169 pub selection_performance: SelectionPerformance,
171 pub contextual_stats: HashMap<(String, String), ActionStats>,
173 pub action_stats: HashMap<String, ActionStats>,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct SnapshotMetadata {
180 pub scenario_name: Option<String>,
182 pub task_description: Option<String>,
184 pub created_at: u64,
186 pub session_count: u32,
188 pub total_episodes: u32,
190 pub total_actions: u32,
192 #[serde(default, skip_serializing_if = "Option::is_none")]
194 pub phase: Option<LearningPhase>,
195 #[serde(default, skip_serializing_if = "Option::is_none")]
197 pub group_id: Option<SessionGroupId>,
198}
199
200impl Default for SnapshotMetadata {
201 fn default() -> Self {
202 Self {
203 scenario_name: None,
204 task_description: None,
205 created_at: SystemTime::now()
206 .duration_since(UNIX_EPOCH)
207 .map(|d| d.as_secs())
208 .unwrap_or(0),
209 session_count: 1,
210 total_episodes: 0,
211 total_actions: 0,
212 phase: None,
213 group_id: None,
214 }
215 }
216}
217
218impl SnapshotMetadata {
219 pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
221 self.scenario_name = Some(name.into());
222 self
223 }
224
225 pub fn with_task(mut self, desc: impl Into<String>) -> Self {
227 self.task_description = Some(desc.into());
228 self
229 }
230
231 pub fn with_phase(mut self, phase: LearningPhase) -> Self {
233 self.phase = Some(phase);
234 self
235 }
236
237 pub fn with_group_id(mut self, group_id: SessionGroupId) -> Self {
239 self.group_id = Some(group_id);
240 self
241 }
242}
243
244impl LearningSnapshot {
245 pub fn empty() -> Self {
247 Self {
248 version: SNAPSHOT_VERSION,
249 metadata: SnapshotMetadata::default(),
250 episode_transitions: EpisodeTransitions::default(),
251 ngram_stats: NgramStats::default(),
252 selection_performance: SelectionPerformance::default(),
253 contextual_stats: HashMap::new(),
254 action_stats: HashMap::new(),
255 }
256 }
257
258 pub fn with_metadata(mut self, metadata: SnapshotMetadata) -> Self {
260 self.metadata = metadata;
261 self
262 }
263}
264
/// How `merge_snapshots` weights the counters of each input snapshot.
#[derive(Debug, Clone, Copy, Default)]
pub enum MergeStrategy {
    /// Every snapshot counts with weight 1.0 (plain summation).
    #[default]
    Additive,
    /// Snapshots are weighted by `0.5^(age / half_life_sessions)`, where age is
    /// the distance from the last snapshot in the input slice.
    TimeDecay {
        /// Number of sessions after which a snapshot's weight is halved.
        half_life_sessions: u32,
    },
    /// Snapshots are weighted by `1 + success_episodes / total_episodes`
    /// (weight 1.0 when the snapshot reports no episodes).
    SuccessWeighted,
}
283
/// Snapshot storage that keeps each snapshot as a JSON file under a base
/// directory.
#[derive(Clone)]
pub struct FileSystemStorage {
    /// Root directory beneath which all snapshot files are placed.
    base_dir: PathBuf,
}
293
294impl FileSystemStorage {
295 pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
297 let base_dir = base_dir.as_ref().to_path_buf();
298 fs::create_dir_all(&base_dir)?;
299 Ok(Self { base_dir })
300 }
301
302 pub fn base_dir(&self) -> &Path {
304 &self.base_dir
305 }
306
307 fn key_to_path(&self, key: &SnapshotKey) -> PathBuf {
308 match key {
309 SnapshotKey::Global => self.base_dir.join("global_stats.json"),
310 SnapshotKey::Scenario(scenario) => self
311 .base_dir
312 .join("scenarios")
313 .join(scenario)
314 .join("stats.json"),
315 SnapshotKey::Session {
316 scenario,
317 session_id,
318 } => self
319 .base_dir
320 .join("scenarios")
321 .join(scenario)
322 .join("sessions")
323 .join(&session_id.0)
324 .join("stats.json"),
325 }
326 }
327
328 fn sessions_dir(&self, scenario: &str) -> PathBuf {
329 self.base_dir
330 .join("scenarios")
331 .join(scenario)
332 .join("sessions")
333 }
334}
335
336impl SnapshotStorage for FileSystemStorage {
337 type Error = std::io::Error;
338
339 fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error> {
340 let path = self.key_to_path(key);
341 if let Some(parent) = path.parent() {
342 fs::create_dir_all(parent)?;
343 }
344 snapshot.save_json(&path)
345 }
346
347 fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error> {
348 let path = self.key_to_path(key);
349 if !path.exists() {
350 return Ok(None);
351 }
352 LearningSnapshot::load_json(&path).map(Some)
353 }
354
355 fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
356 let path = self.key_to_path(key);
357 if path.exists() {
358 fs::remove_file(&path)?;
359 Ok(true)
360 } else {
361 Ok(false)
362 }
363 }
364
365 fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
366 Ok(self.key_to_path(key).exists())
367 }
368}
369
370impl TimeSeriesQuery for FileSystemStorage {
371 type Error = std::io::Error;
372
373 fn query_range(
374 &self,
375 scenario: &str,
376 from: Timestamp,
377 to: Timestamp,
378 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
379 let sessions = self.list_sessions(scenario)?;
380 let mut results = Vec::new();
381
382 for session_id in sessions {
383 if let Some(ts) = session_id.timestamp() {
384 if ts >= from && ts <= to {
385 let key = SnapshotKey::Session {
386 scenario: scenario.to_string(),
387 session_id,
388 };
389 if let Some(snapshot) = self.load(&key)? {
390 results.push(snapshot);
391 }
392 }
393 }
394 }
395 Ok(results)
396 }
397
398 fn query_latest(
399 &self,
400 scenario: &str,
401 limit: usize,
402 ) -> Result<Vec<LearningSnapshot>, Self::Error> {
403 let mut sessions = self.list_sessions(scenario)?;
404 sessions.sort_by(|a, b| b.0.cmp(&a.0));
406 sessions.truncate(limit);
407
408 let mut results = Vec::new();
409 for session_id in sessions {
410 let key = SnapshotKey::Session {
411 scenario: scenario.to_string(),
412 session_id,
413 };
414 if let Some(snapshot) = self.load(&key)? {
415 results.push(snapshot);
416 }
417 }
418 Ok(results)
419 }
420
421 fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error> {
422 let sessions_dir = self.sessions_dir(scenario);
423 if !sessions_dir.exists() {
424 return Ok(Vec::new());
425 }
426
427 let mut sessions = Vec::new();
428 for entry in fs::read_dir(sessions_dir)? {
429 let entry = entry?;
430 if entry.file_type()?.is_dir() {
431 if let Some(name) = entry.file_name().to_str() {
432 sessions.push(SessionId(name.to_string()));
433 }
434 }
435 }
436 sessions.sort_by(|a, b| a.0.cmp(&b.0));
437 Ok(sessions)
438 }
439}
440
441#[derive(Clone)]
449pub struct LearningStore {
450 storage: FileSystemStorage,
451}
452
453impl LearningStore {
454 pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
456 let storage = FileSystemStorage::new(base_dir)?;
457 Ok(Self { storage })
458 }
459
460 pub fn default_path() -> PathBuf {
462 dirs::data_dir()
463 .unwrap_or_else(|| PathBuf::from("."))
464 .join("swarm-engine")
465 .join("learning")
466 }
467
468 pub fn storage(&self) -> &FileSystemStorage {
470 &self.storage
471 }
472
473 pub fn load_global(&self) -> std::io::Result<LearningSnapshot> {
479 self.storage.load(&SnapshotKey::Global)?.ok_or_else(|| {
480 std::io::Error::new(std::io::ErrorKind::NotFound, "global stats not found")
481 })
482 }
483
484 pub fn save_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
486 self.storage.save(&SnapshotKey::Global, snapshot)
487 }
488
489 pub fn load_scenario(&self, scenario: &str) -> std::io::Result<LearningSnapshot> {
491 self.storage
492 .load(&SnapshotKey::Scenario(scenario.to_string()))?
493 .ok_or_else(|| {
494 std::io::Error::new(std::io::ErrorKind::NotFound, "scenario stats not found")
495 })
496 }
497
498 pub fn save_scenario(
500 &self,
501 scenario: &str,
502 snapshot: &LearningSnapshot,
503 ) -> std::io::Result<()> {
504 self.storage
505 .save(&SnapshotKey::Scenario(scenario.to_string()), snapshot)
506 }
507
508 pub fn save_session(
510 &self,
511 scenario: &str,
512 snapshot: &LearningSnapshot,
513 ) -> std::io::Result<SessionId> {
514 let session_id = self.generate_session_id();
516
517 let key = SnapshotKey::Session {
519 scenario: scenario.to_string(),
520 session_id: session_id.clone(),
521 };
522 self.storage.save(&key, snapshot)?;
523
524 let meta_path = self.storage.key_to_path(&key).with_file_name("meta.json");
526 let meta_json = serde_json::to_string_pretty(&snapshot.metadata)?;
527 fs::write(meta_path, meta_json)?;
528
529 self.merge_into_scenario(scenario, snapshot)?;
531
532 self.merge_into_global(snapshot)?;
534
535 Ok(session_id)
536 }
537
538 pub fn list_sessions(&self, scenario: &str) -> std::io::Result<Vec<SessionId>> {
540 self.storage.list_sessions(scenario)
541 }
542
543 pub fn load_session(
545 &self,
546 scenario: &str,
547 session_id: &SessionId,
548 ) -> std::io::Result<LearningSnapshot> {
549 let key = SnapshotKey::Session {
550 scenario: scenario.to_string(),
551 session_id: session_id.clone(),
552 };
553 self.storage
554 .load(&key)?
555 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "session not found"))
556 }
557
558 pub fn query_range(
564 &self,
565 scenario: &str,
566 from: Timestamp,
567 to: Timestamp,
568 ) -> std::io::Result<Vec<LearningSnapshot>> {
569 self.storage.query_range(scenario, from, to)
570 }
571
572 pub fn query_latest(
574 &self,
575 scenario: &str,
576 limit: usize,
577 ) -> std::io::Result<Vec<LearningSnapshot>> {
578 self.storage.query_latest(scenario, limit)
579 }
580
581 pub fn merge(
587 &self,
588 snapshots: &[LearningSnapshot],
589 strategy: MergeStrategy,
590 ) -> LearningSnapshot {
591 merge_snapshots(snapshots, strategy)
592 }
593
594 fn generate_session_id(&self) -> SessionId {
595 let timestamp = SystemTime::now()
596 .duration_since(UNIX_EPOCH)
597 .map(|d| d.as_secs())
598 .unwrap_or(0);
599 SessionId(format!("{:010}", timestamp))
600 }
601
602 fn merge_into_scenario(
603 &self,
604 scenario: &str,
605 snapshot: &LearningSnapshot,
606 ) -> std::io::Result<()> {
607 let existing = self
608 .storage
609 .load(&SnapshotKey::Scenario(scenario.to_string()))?;
610 let merged = match existing {
611 Some(existing) => {
612 merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
613 }
614 None => snapshot.clone(),
615 };
616 self.save_scenario(scenario, &merged)
617 }
618
619 fn merge_into_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
620 let existing = self.storage.load(&SnapshotKey::Global)?;
621 let merged = match existing {
622 Some(existing) => {
623 merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
624 }
625 None => snapshot.clone(),
626 };
627 self.save_global(&merged)
628 }
629
630 pub fn load_offline_model(
636 &self,
637 scenario: &str,
638 ) -> std::io::Result<super::offline::OfflineModel> {
639 let path = self.offline_model_path(scenario);
640 if !path.exists() {
641 return Err(std::io::Error::new(
642 std::io::ErrorKind::NotFound,
643 "offline model not found",
644 ));
645 }
646 let json = fs::read_to_string(&path)?;
647 serde_json::from_str(&json)
648 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
649 }
650
651 pub fn save_offline_model(
653 &self,
654 scenario: &str,
655 model: &super::offline::OfflineModel,
656 ) -> std::io::Result<()> {
657 let path = self.offline_model_path(scenario);
658 if let Some(parent) = path.parent() {
659 fs::create_dir_all(parent)?;
660 }
661 let json = serde_json::to_string_pretty(model)?;
662 fs::write(path, json)
663 }
664
665 pub fn run_offline_learning(
667 &self,
668 scenario: &str,
669 session_limit: usize,
670 ) -> std::io::Result<super::offline::OfflineModel> {
671 let snapshots = self.query_latest(scenario, session_limit)?;
672 if snapshots.is_empty() {
673 return Err(std::io::Error::new(
674 std::io::ErrorKind::NotFound,
675 "no sessions found for offline learning",
676 ));
677 }
678
679 let analyzer = super::offline::OfflineAnalyzer::new(&snapshots);
680 let model = analyzer.analyze();
681
682 self.save_offline_model(scenario, &model)?;
683
684 Ok(model)
685 }
686
687 fn offline_model_path(&self, scenario: &str) -> PathBuf {
688 self.storage
689 .base_dir()
690 .join("scenarios")
691 .join(scenario)
692 .join("offline_model.json")
693 }
694}
695
696pub fn merge_snapshots(
702 snapshots: &[LearningSnapshot],
703 strategy: MergeStrategy,
704) -> LearningSnapshot {
705 if snapshots.is_empty() {
706 return LearningSnapshot::empty();
707 }
708 if snapshots.len() == 1 {
709 return snapshots[0].clone();
710 }
711
712 let mut result = LearningSnapshot::empty();
713
714 let weights: Vec<f64> = match strategy {
716 MergeStrategy::Additive => vec![1.0; snapshots.len()],
717 MergeStrategy::TimeDecay { half_life_sessions } => {
718 let half_life = half_life_sessions as f64;
719 snapshots
720 .iter()
721 .enumerate()
722 .map(|(i, _)| {
723 let age = (snapshots.len() - 1 - i) as f64;
724 0.5_f64.powf(age / half_life)
725 })
726 .collect()
727 }
728 MergeStrategy::SuccessWeighted => snapshots
729 .iter()
730 .map(|s| {
731 let total = s.metadata.total_episodes as f64;
732 let success = s.episode_transitions.success_episodes as f64;
733 if total == 0.0 {
734 1.0
735 } else {
736 1.0 + success / total
737 }
738 })
739 .collect(),
740 };
741
742 result.metadata = SnapshotMetadata {
744 scenario_name: snapshots
745 .last()
746 .and_then(|s| s.metadata.scenario_name.clone()),
747 task_description: snapshots
748 .last()
749 .and_then(|s| s.metadata.task_description.clone()),
750 created_at: SystemTime::now()
751 .duration_since(UNIX_EPOCH)
752 .map(|d| d.as_secs())
753 .unwrap_or(0),
754 session_count: snapshots.iter().map(|s| s.metadata.session_count).sum(),
755 total_episodes: snapshots.iter().map(|s| s.metadata.total_episodes).sum(),
756 total_actions: snapshots.iter().map(|s| s.metadata.total_actions).sum(),
757 phase: None,
759 group_id: None,
760 };
761
762 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
764 for (key, &count) in &snapshot.episode_transitions.success_transitions {
765 let weighted_count = (count as f64 * weight).round() as u32;
766 *result
767 .episode_transitions
768 .success_transitions
769 .entry(key.clone())
770 .or_default() += weighted_count;
771 }
772 for (key, &count) in &snapshot.episode_transitions.failure_transitions {
773 let weighted_count = (count as f64 * weight).round() as u32;
774 *result
775 .episode_transitions
776 .failure_transitions
777 .entry(key.clone())
778 .or_default() += weighted_count;
779 }
780 result.episode_transitions.success_episodes +=
781 (snapshot.episode_transitions.success_episodes as f64 * weight).round() as u32;
782 result.episode_transitions.failure_episodes +=
783 (snapshot.episode_transitions.failure_episodes as f64 * weight).round() as u32;
784 }
785
786 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
788 for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
789 let entry = result
790 .ngram_stats
791 .trigrams
792 .entry(key.clone())
793 .or_insert((0, 0));
794 entry.0 += (success as f64 * weight).round() as u32;
795 entry.1 += (failure as f64 * weight).round() as u32;
796 }
797 for (key, &(success, failure)) in &snapshot.ngram_stats.quadgrams {
798 let entry = result
799 .ngram_stats
800 .quadgrams
801 .entry(key.clone())
802 .or_insert((0, 0));
803 entry.0 += (success as f64 * weight).round() as u32;
804 entry.1 += (failure as f64 * weight).round() as u32;
805 }
806 }
807
808 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
810 for (key, stats) in &snapshot.contextual_stats {
811 let entry = result.contextual_stats.entry(key.clone()).or_default();
812 entry.visits += (stats.visits as f64 * weight).round() as u32;
813 entry.successes += (stats.successes as f64 * weight).round() as u32;
814 entry.failures += (stats.failures as f64 * weight).round() as u32;
815 }
816 }
817
818 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
820 for (key, stats) in &snapshot.action_stats {
821 let entry = result.action_stats.entry(key.clone()).or_default();
822 entry.visits += (stats.visits as f64 * weight).round() as u32;
823 entry.successes += (stats.successes as f64 * weight).round() as u32;
824 entry.failures += (stats.failures as f64 * weight).round() as u32;
825 }
826 }
827
828 for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
830 for (strat, stats) in &snapshot.selection_performance.strategy_stats {
831 let entry = result
832 .selection_performance
833 .strategy_stats
834 .entry(strat.clone())
835 .or_default();
836 entry.visits += (stats.visits as f64 * weight).round() as u32;
837 entry.successes += (stats.successes as f64 * weight).round() as u32;
838 entry.failures += (stats.failures as f64 * weight).round() as u32;
839 entry.episodes_success += (stats.episodes_success as f64 * weight).round() as u32;
840 entry.episodes_failure += (stats.episodes_failure as f64 * weight).round() as u32;
841 }
842 }
843
844 result
845}
846
847impl LearningSnapshot {
852 pub fn to_json(&self) -> serde_json::Result<String> {
854 serde_json::to_string_pretty(self)
855 }
856
857 pub fn from_json(json: &str) -> serde_json::Result<Self> {
859 serde_json::from_str(json)
860 }
861
862 pub fn save_json(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
864 let json = self.to_json()?;
865 fs::write(path, json)
866 }
867
868 pub fn load_json(path: impl AsRef<Path>) -> std::io::Result<Self> {
870 let json = fs::read_to_string(path)?;
871 Self::from_json(&json).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
872 }
873}
874
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Transition key shared by the merge tests.
    fn a_to_b() -> (String, String) {
        ("A".to_string(), "B".to_string())
    }

    /// Otherwise-empty snapshot holding `count` successful A -> B transitions.
    fn snapshot_with_a_to_b(count: u32) -> LearningSnapshot {
        let mut snapshot = LearningSnapshot::empty();
        snapshot
            .episode_transitions
            .success_transitions
            .insert(a_to_b(), count);
        snapshot
    }

    /// A snapshot survives a JSON round trip with version and scenario intact.
    #[test]
    fn test_snapshot_serialization() {
        let metadata = SnapshotMetadata::default().with_scenario("test");
        let original = LearningSnapshot::empty().with_metadata(metadata);

        let text = original.to_json().unwrap();
        let restored = LearningSnapshot::from_json(&text).unwrap();

        assert_eq!(restored.version, SNAPSHOT_VERSION);
        assert_eq!(restored.metadata.scenario_name, Some("test".to_string()));
    }

    /// A scenario snapshot written through the store can be read back.
    #[test]
    fn test_learning_store_save_load() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        let metadata = SnapshotMetadata::default().with_scenario("troubleshooting");
        let snapshot = LearningSnapshot::empty().with_metadata(metadata);
        store.save_scenario("troubleshooting", &snapshot).unwrap();

        let restored = store.load_scenario("troubleshooting").unwrap();
        assert_eq!(
            restored.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
    }

    /// Additive merging sums transition counts and episode totals.
    #[test]
    fn test_merge_additive() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        let mut first = snapshot_with_a_to_b(5);
        first.metadata.total_episodes = 10;
        let mut second = snapshot_with_a_to_b(3);
        second.metadata.total_episodes = 5;

        let merged = store.merge(&[first, second], MergeStrategy::Additive);

        let merged_count = merged
            .episode_transitions
            .success_transitions
            .get(&a_to_b());
        assert_eq!(merged_count, Some(&8));
        assert_eq!(merged.metadata.total_episodes, 15);
    }

    /// With a half-life of one session the older snapshot counts half.
    #[test]
    fn test_merge_time_decay() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        let older = snapshot_with_a_to_b(100);
        let newer = snapshot_with_a_to_b(100);

        let strategy = MergeStrategy::TimeDecay {
            half_life_sessions: 1,
        };
        let merged = store.merge(&[older, newer], strategy);

        let count = merged
            .episode_transitions
            .success_transitions
            .get(&a_to_b())
            .unwrap();
        assert_eq!(*count, 150);
    }

    /// Saving a session yields a non-empty id that is then listed.
    #[test]
    fn test_session_management() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        let metadata = SnapshotMetadata::default().with_scenario("test");
        let snapshot = LearningSnapshot::empty().with_metadata(metadata);

        let session_id = store.save_session("test", &snapshot).unwrap();
        assert!(!session_id.0.is_empty());

        let sessions = store.list_sessions("test").unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0], session_id);
    }
}