Skip to main content

swarm_engine_core/learn/
snapshot.rs

1//! Learning Snapshot - 学習データの永続化システム
2//!
3//! Multi-Session 間で統計を共有・マージするための永続化層。
4//!
5//! # アーキテクチャ
6//!
7//! ```text
8//! ┌─────────────────────────────────────────────────────────────┐
9//! │                    LearningStore                            │
10//! │                   (Facade / Entry Point)                    │
11//! ├─────────────────────────────────────────────────────────────┤
12//! │                                                             │
13//! │  ┌─────────────────┐    ┌─────────────────────────────┐    │
14//! │  │ SnapshotStorage │    │     TimeSeriesQuery         │    │
15//! │  │     (trait)     │    │        (trait)              │    │
16//! │  ├─────────────────┤    ├─────────────────────────────┤    │
17//! │  │ save(key, snap) │    │ query_range(from, to)       │    │
18//! │  │ load(key)       │    │ query_latest(n)             │    │
19//! │  │ delete(key)     │    │ list_sessions(scenario)     │    │
20//! │  │ exists(key)     │    └─────────────────────────────┘    │
21//! │  └─────────────────┘                                       │
22//! │           │                         │                      │
23//! │           └────────────┬────────────┘                      │
24//! │                        ▼                                   │
25//! │              ┌─────────────────┐                           │
26//! │              │FileSystemStorage│                           │
27//! │              │  (implements)   │                           │
28//! │              └─────────────────┘                           │
29//! └─────────────────────────────────────────────────────────────┘
30//!
//! <data_dir>/swarm-engine/learning/   (default: LearningStore::default_path, data_dir = dirs::data_dir())
//! ├── global_stats.json
//! └── scenarios/
//!     └── troubleshooting/
//!         ├── stats.json
//!         ├── offline_model.json
//!         └── sessions/
//!             └── {timestamp}/
//!                 ├── meta.json
//!                 └── stats.json
40//! ```
41
42use std::collections::HashMap;
43use std::fs;
44use std::path::{Path, PathBuf};
45use std::time::{SystemTime, UNIX_EPOCH};
46
47use serde::{Deserialize, Serialize};
48
49use super::{EpisodeTransitions, NgramStats, SelectionPerformance};
50use crate::online_stats::ActionStats;
51
/// Format version written into every snapshot so readers can detect incompatible data.
pub const SNAPSHOT_VERSION: u32 = 1;

/// A point in time, expressed as whole seconds since the Unix epoch.
pub type Timestamp = u64;
57
58/// セッションID(タイムスタンプベース)
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct SessionId(pub String);
61
62impl SessionId {
63    /// タイムスタンプから SessionId を取得
64    pub fn timestamp(&self) -> Option<Timestamp> {
65        self.0.parse().ok()
66    }
67}
68
69// ============================================================================
70// Storage Traits - 永続化の抽象化
71// ============================================================================
72
73/// スナップショットのキー(階層的)
74#[derive(Debug, Clone, PartialEq, Eq, Hash)]
75pub enum SnapshotKey {
76    /// グローバル統計
77    Global,
78    /// シナリオ別統計
79    Scenario(String),
80    /// セッション別統計
81    Session {
82        scenario: String,
83        session_id: SessionId,
84    },
85}
86
87impl std::fmt::Display for SnapshotKey {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        match self {
90            Self::Global => write!(f, "global"),
91            Self::Scenario(s) => write!(f, "scenario:{}", s),
92            Self::Session {
93                scenario,
94                session_id,
95            } => {
96                write!(f, "session:{}:{}", scenario, session_id.0)
97            }
98        }
99    }
100}
101
102/// スナップショットの永続化(CRUD)
103pub trait SnapshotStorage {
104    /// エラー型
105    type Error: std::error::Error + Send + Sync + 'static;
106
107    /// スナップショットを保存
108    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error>;
109
110    /// スナップショットをロード
111    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error>;
112
113    /// スナップショットを削除
114    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
115
116    /// スナップショットの存在確認
117    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
118}
119
120/// 時系列クエリ
121pub trait TimeSeriesQuery {
122    /// エラー型
123    type Error: std::error::Error + Send + Sync + 'static;
124
125    /// 時間範囲でセッションを取得
126    fn query_range(
127        &self,
128        scenario: &str,
129        from: Timestamp,
130        to: Timestamp,
131    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
132
133    /// 最新 N 件のセッションを取得
134    fn query_latest(
135        &self,
136        scenario: &str,
137        limit: usize,
138    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
139
140    /// 指定時刻以降のセッションを取得
141    fn query_since(
142        &self,
143        scenario: &str,
144        since: Timestamp,
145    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
146        self.query_range(scenario, since, u64::MAX)
147    }
148
149    /// シナリオのセッション一覧を取得
150    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error>;
151}
152
153// ============================================================================
154// LearningSnapshot - 永続化単位
155// ============================================================================
156
157/// 学習データのスナップショット(永続化単位)
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct LearningSnapshot {
160    /// フォーマットバージョン(互換性用)
161    pub version: u32,
162    /// メタデータ
163    pub metadata: SnapshotMetadata,
164    /// エピソード遷移統計(2-gram)
165    pub episode_transitions: EpisodeTransitions,
166    /// N-gram 統計(3-gram, 4-gram)
167    pub ngram_stats: NgramStats,
168    /// Selection 戦略効果
169    pub selection_performance: SelectionPerformance,
170    /// コンテキスト条件付き統計
171    pub contextual_stats: HashMap<(String, String), ActionStats>,
172    /// アクション別統計
173    pub action_stats: HashMap<String, ActionStats>,
174}
175
176/// スナップショットのメタデータ
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct SnapshotMetadata {
179    /// シナリオ名
180    pub scenario_name: Option<String>,
181    /// タスクの説明
182    pub task_description: Option<String>,
183    /// 作成日時(Unix timestamp)
184    pub created_at: u64,
185    /// マージされたセッション数
186    pub session_count: u32,
187    /// 総エピソード数
188    pub total_episodes: u32,
189    /// 総アクション数
190    pub total_actions: u32,
191}
192
193impl Default for SnapshotMetadata {
194    fn default() -> Self {
195        Self {
196            scenario_name: None,
197            task_description: None,
198            created_at: SystemTime::now()
199                .duration_since(UNIX_EPOCH)
200                .map(|d| d.as_secs())
201                .unwrap_or(0),
202            session_count: 1,
203            total_episodes: 0,
204            total_actions: 0,
205        }
206    }
207}
208
209impl SnapshotMetadata {
210    /// シナリオ名を設定
211    pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
212        self.scenario_name = Some(name.into());
213        self
214    }
215
216    /// タスク説明を設定
217    pub fn with_task(mut self, desc: impl Into<String>) -> Self {
218        self.task_description = Some(desc.into());
219        self
220    }
221}
222
223impl LearningSnapshot {
224    /// 空のスナップショットを作成
225    pub fn empty() -> Self {
226        Self {
227            version: SNAPSHOT_VERSION,
228            metadata: SnapshotMetadata::default(),
229            episode_transitions: EpisodeTransitions::default(),
230            ngram_stats: NgramStats::default(),
231            selection_performance: SelectionPerformance::default(),
232            contextual_stats: HashMap::new(),
233            action_stats: HashMap::new(),
234        }
235    }
236
237    /// メタデータを設定
238    pub fn with_metadata(mut self, metadata: SnapshotMetadata) -> Self {
239        self.metadata = metadata;
240        self
241    }
242}
243
244// ============================================================================
245// MergeStrategy - マージ戦略
246// ============================================================================
247
/// How `merge_snapshots` weights each input snapshot.
#[derive(Debug, Clone, Copy, Default)]
pub enum MergeStrategy {
    /// Every snapshot gets weight 1, i.e. plain summation. This is the default.
    #[default]
    Additive,
    /// Older snapshots are down-weighted exponentially: one that sits
    /// `half_life_sessions` positions before the newest counts half as much.
    TimeDecay {
        /// Half-life, measured in sessions (positions in the input slice).
        half_life_sessions: u32,
    },
    /// Snapshots get weight `1 + success ratio`, favouring successful sessions.
    SuccessWeighted,
}
262
263// ============================================================================
264// FileSystemStorage - ファイルシステムベースの実装
265// ============================================================================
266
267/// ファイルシステムベースのストレージ実装
268pub struct FileSystemStorage {
269    base_dir: PathBuf,
270}
271
272impl FileSystemStorage {
273    /// 新しい FileSystemStorage を作成
274    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
275        let base_dir = base_dir.as_ref().to_path_buf();
276        fs::create_dir_all(&base_dir)?;
277        Ok(Self { base_dir })
278    }
279
280    /// ベースディレクトリを取得
281    pub fn base_dir(&self) -> &Path {
282        &self.base_dir
283    }
284
285    fn key_to_path(&self, key: &SnapshotKey) -> PathBuf {
286        match key {
287            SnapshotKey::Global => self.base_dir.join("global_stats.json"),
288            SnapshotKey::Scenario(scenario) => self
289                .base_dir
290                .join("scenarios")
291                .join(scenario)
292                .join("stats.json"),
293            SnapshotKey::Session {
294                scenario,
295                session_id,
296            } => self
297                .base_dir
298                .join("scenarios")
299                .join(scenario)
300                .join("sessions")
301                .join(&session_id.0)
302                .join("stats.json"),
303        }
304    }
305
306    fn sessions_dir(&self, scenario: &str) -> PathBuf {
307        self.base_dir
308            .join("scenarios")
309            .join(scenario)
310            .join("sessions")
311    }
312}
313
314impl SnapshotStorage for FileSystemStorage {
315    type Error = std::io::Error;
316
317    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error> {
318        let path = self.key_to_path(key);
319        if let Some(parent) = path.parent() {
320            fs::create_dir_all(parent)?;
321        }
322        snapshot.save_json(&path)
323    }
324
325    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error> {
326        let path = self.key_to_path(key);
327        if !path.exists() {
328            return Ok(None);
329        }
330        LearningSnapshot::load_json(&path).map(Some)
331    }
332
333    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
334        let path = self.key_to_path(key);
335        if path.exists() {
336            fs::remove_file(&path)?;
337            Ok(true)
338        } else {
339            Ok(false)
340        }
341    }
342
343    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
344        Ok(self.key_to_path(key).exists())
345    }
346}
347
348impl TimeSeriesQuery for FileSystemStorage {
349    type Error = std::io::Error;
350
351    fn query_range(
352        &self,
353        scenario: &str,
354        from: Timestamp,
355        to: Timestamp,
356    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
357        let sessions = self.list_sessions(scenario)?;
358        let mut results = Vec::new();
359
360        for session_id in sessions {
361            if let Some(ts) = session_id.timestamp() {
362                if ts >= from && ts <= to {
363                    let key = SnapshotKey::Session {
364                        scenario: scenario.to_string(),
365                        session_id,
366                    };
367                    if let Some(snapshot) = self.load(&key)? {
368                        results.push(snapshot);
369                    }
370                }
371            }
372        }
373        Ok(results)
374    }
375
376    fn query_latest(
377        &self,
378        scenario: &str,
379        limit: usize,
380    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
381        let mut sessions = self.list_sessions(scenario)?;
382        // 降順ソート(最新が先頭)
383        sessions.sort_by(|a, b| b.0.cmp(&a.0));
384        sessions.truncate(limit);
385
386        let mut results = Vec::new();
387        for session_id in sessions {
388            let key = SnapshotKey::Session {
389                scenario: scenario.to_string(),
390                session_id,
391            };
392            if let Some(snapshot) = self.load(&key)? {
393                results.push(snapshot);
394            }
395        }
396        Ok(results)
397    }
398
399    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error> {
400        let sessions_dir = self.sessions_dir(scenario);
401        if !sessions_dir.exists() {
402            return Ok(Vec::new());
403        }
404
405        let mut sessions = Vec::new();
406        for entry in fs::read_dir(sessions_dir)? {
407            let entry = entry?;
408            if entry.file_type()?.is_dir() {
409                if let Some(name) = entry.file_name().to_str() {
410                    sessions.push(SessionId(name.to_string()));
411                }
412            }
413        }
414        sessions.sort_by(|a, b| a.0.cmp(&b.0));
415        Ok(sessions)
416    }
417}
418
419// ============================================================================
420// LearningStore - Facade
421// ============================================================================
422
423/// 学習データの永続化マネージャ(Facade)
424///
425/// 内部で FileSystemStorage を使用。高レベル API を提供。
426pub struct LearningStore {
427    storage: FileSystemStorage,
428}
429
430impl LearningStore {
431    /// 新しい LearningStore を作成
432    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
433        let storage = FileSystemStorage::new(base_dir)?;
434        Ok(Self { storage })
435    }
436
437    /// デフォルトのパスで作成
438    pub fn default_path() -> PathBuf {
439        dirs::data_dir()
440            .unwrap_or_else(|| PathBuf::from("."))
441            .join("swarm-engine")
442            .join("learning")
443    }
444
445    /// 内部ストレージへの参照を取得
446    pub fn storage(&self) -> &FileSystemStorage {
447        &self.storage
448    }
449
450    // ========================================================================
451    // High-Level API(後方互換性)
452    // ========================================================================
453
454    /// グローバル統計をロード
455    pub fn load_global(&self) -> std::io::Result<LearningSnapshot> {
456        self.storage.load(&SnapshotKey::Global)?.ok_or_else(|| {
457            std::io::Error::new(std::io::ErrorKind::NotFound, "global stats not found")
458        })
459    }
460
461    /// グローバル統計を保存
462    pub fn save_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
463        self.storage.save(&SnapshotKey::Global, snapshot)
464    }
465
466    /// シナリオ別統計をロード
467    pub fn load_scenario(&self, scenario: &str) -> std::io::Result<LearningSnapshot> {
468        self.storage
469            .load(&SnapshotKey::Scenario(scenario.to_string()))?
470            .ok_or_else(|| {
471                std::io::Error::new(std::io::ErrorKind::NotFound, "scenario stats not found")
472            })
473    }
474
475    /// シナリオ別統計を保存
476    pub fn save_scenario(
477        &self,
478        scenario: &str,
479        snapshot: &LearningSnapshot,
480    ) -> std::io::Result<()> {
481        self.storage
482            .save(&SnapshotKey::Scenario(scenario.to_string()), snapshot)
483    }
484
485    /// セッション統計を保存(自動マージ)
486    pub fn save_session(
487        &self,
488        scenario: &str,
489        snapshot: &LearningSnapshot,
490    ) -> std::io::Result<SessionId> {
491        // セッションIDを生成
492        let session_id = self.generate_session_id();
493
494        // セッション統計を保存
495        let key = SnapshotKey::Session {
496            scenario: scenario.to_string(),
497            session_id: session_id.clone(),
498        };
499        self.storage.save(&key, snapshot)?;
500
501        // メタデータを別途 JSON で保存(可読性用)
502        let meta_path = self.storage.key_to_path(&key).with_file_name("meta.json");
503        let meta_json = serde_json::to_string_pretty(&snapshot.metadata)?;
504        fs::write(meta_path, meta_json)?;
505
506        // シナリオ別統計にマージ
507        self.merge_into_scenario(scenario, snapshot)?;
508
509        // グローバル統計にマージ
510        self.merge_into_global(snapshot)?;
511
512        Ok(session_id)
513    }
514
515    /// シナリオのセッション一覧を取得
516    pub fn list_sessions(&self, scenario: &str) -> std::io::Result<Vec<SessionId>> {
517        self.storage.list_sessions(scenario)
518    }
519
520    /// 特定セッションの統計をロード
521    pub fn load_session(
522        &self,
523        scenario: &str,
524        session_id: &SessionId,
525    ) -> std::io::Result<LearningSnapshot> {
526        let key = SnapshotKey::Session {
527            scenario: scenario.to_string(),
528            session_id: session_id.clone(),
529        };
530        self.storage
531            .load(&key)?
532            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "session not found"))
533    }
534
535    // ========================================================================
536    // Time Series Query(新規 API)
537    // ========================================================================
538
539    /// 時間範囲でセッションを取得
540    pub fn query_range(
541        &self,
542        scenario: &str,
543        from: Timestamp,
544        to: Timestamp,
545    ) -> std::io::Result<Vec<LearningSnapshot>> {
546        self.storage.query_range(scenario, from, to)
547    }
548
549    /// 最新 N 件のセッションを取得
550    pub fn query_latest(
551        &self,
552        scenario: &str,
553        limit: usize,
554    ) -> std::io::Result<Vec<LearningSnapshot>> {
555        self.storage.query_latest(scenario, limit)
556    }
557
558    // ========================================================================
559    // Merge Logic
560    // ========================================================================
561
562    /// 複数のスナップショットをマージ
563    pub fn merge(
564        &self,
565        snapshots: &[LearningSnapshot],
566        strategy: MergeStrategy,
567    ) -> LearningSnapshot {
568        merge_snapshots(snapshots, strategy)
569    }
570
571    fn generate_session_id(&self) -> SessionId {
572        let timestamp = SystemTime::now()
573            .duration_since(UNIX_EPOCH)
574            .map(|d| d.as_secs())
575            .unwrap_or(0);
576        SessionId(format!("{:010}", timestamp))
577    }
578
579    fn merge_into_scenario(
580        &self,
581        scenario: &str,
582        snapshot: &LearningSnapshot,
583    ) -> std::io::Result<()> {
584        let existing = self
585            .storage
586            .load(&SnapshotKey::Scenario(scenario.to_string()))?;
587        let merged = match existing {
588            Some(existing) => {
589                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
590            }
591            None => snapshot.clone(),
592        };
593        self.save_scenario(scenario, &merged)
594    }
595
596    fn merge_into_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
597        let existing = self.storage.load(&SnapshotKey::Global)?;
598        let merged = match existing {
599            Some(existing) => {
600                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
601            }
602            None => snapshot.clone(),
603        };
604        self.save_global(&merged)
605    }
606
607    // ========================================================================
608    // Offline Model API
609    // ========================================================================
610
611    /// Offline Model をロード
612    pub fn load_offline_model(
613        &self,
614        scenario: &str,
615    ) -> std::io::Result<super::offline::OfflineModel> {
616        let path = self.offline_model_path(scenario);
617        if !path.exists() {
618            return Err(std::io::Error::new(
619                std::io::ErrorKind::NotFound,
620                "offline model not found",
621            ));
622        }
623        let json = fs::read_to_string(&path)?;
624        serde_json::from_str(&json)
625            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
626    }
627
628    /// Offline Model を保存
629    pub fn save_offline_model(
630        &self,
631        scenario: &str,
632        model: &super::offline::OfflineModel,
633    ) -> std::io::Result<()> {
634        let path = self.offline_model_path(scenario);
635        if let Some(parent) = path.parent() {
636            fs::create_dir_all(parent)?;
637        }
638        let json = serde_json::to_string_pretty(model)?;
639        fs::write(path, json)
640    }
641
642    /// Offline 学習を実行
643    pub fn run_offline_learning(
644        &self,
645        scenario: &str,
646        session_limit: usize,
647    ) -> std::io::Result<super::offline::OfflineModel> {
648        let snapshots = self.query_latest(scenario, session_limit)?;
649        if snapshots.is_empty() {
650            return Err(std::io::Error::new(
651                std::io::ErrorKind::NotFound,
652                "no sessions found for offline learning",
653            ));
654        }
655
656        let analyzer = super::offline::OfflineAnalyzer::new(&snapshots);
657        let model = analyzer.analyze();
658
659        self.save_offline_model(scenario, &model)?;
660
661        Ok(model)
662    }
663
664    fn offline_model_path(&self, scenario: &str) -> PathBuf {
665        self.storage
666            .base_dir()
667            .join("scenarios")
668            .join(scenario)
669            .join("offline_model.json")
670    }
671}
672
673// ============================================================================
674// Merge Function(独立関数として分離)
675// ============================================================================
676
677/// 複数のスナップショットをマージ
678pub fn merge_snapshots(
679    snapshots: &[LearningSnapshot],
680    strategy: MergeStrategy,
681) -> LearningSnapshot {
682    if snapshots.is_empty() {
683        return LearningSnapshot::empty();
684    }
685    if snapshots.len() == 1 {
686        return snapshots[0].clone();
687    }
688
689    let mut result = LearningSnapshot::empty();
690
691    // 重みを計算
692    let weights: Vec<f64> = match strategy {
693        MergeStrategy::Additive => vec![1.0; snapshots.len()],
694        MergeStrategy::TimeDecay { half_life_sessions } => {
695            let half_life = half_life_sessions as f64;
696            snapshots
697                .iter()
698                .enumerate()
699                .map(|(i, _)| {
700                    let age = (snapshots.len() - 1 - i) as f64;
701                    0.5_f64.powf(age / half_life)
702                })
703                .collect()
704        }
705        MergeStrategy::SuccessWeighted => snapshots
706            .iter()
707            .map(|s| {
708                let total = s.metadata.total_episodes as f64;
709                let success = s.episode_transitions.success_episodes as f64;
710                if total == 0.0 {
711                    1.0
712                } else {
713                    1.0 + success / total
714                }
715            })
716            .collect(),
717    };
718
719    // メタデータをマージ
720    result.metadata = SnapshotMetadata {
721        scenario_name: snapshots
722            .last()
723            .and_then(|s| s.metadata.scenario_name.clone()),
724        task_description: snapshots
725            .last()
726            .and_then(|s| s.metadata.task_description.clone()),
727        created_at: SystemTime::now()
728            .duration_since(UNIX_EPOCH)
729            .map(|d| d.as_secs())
730            .unwrap_or(0),
731        session_count: snapshots.iter().map(|s| s.metadata.session_count).sum(),
732        total_episodes: snapshots.iter().map(|s| s.metadata.total_episodes).sum(),
733        total_actions: snapshots.iter().map(|s| s.metadata.total_actions).sum(),
734    };
735
736    // エピソード遷移をマージ
737    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
738        for (key, &count) in &snapshot.episode_transitions.success_transitions {
739            let weighted_count = (count as f64 * weight).round() as u32;
740            *result
741                .episode_transitions
742                .success_transitions
743                .entry(key.clone())
744                .or_default() += weighted_count;
745        }
746        for (key, &count) in &snapshot.episode_transitions.failure_transitions {
747            let weighted_count = (count as f64 * weight).round() as u32;
748            *result
749                .episode_transitions
750                .failure_transitions
751                .entry(key.clone())
752                .or_default() += weighted_count;
753        }
754        result.episode_transitions.success_episodes +=
755            (snapshot.episode_transitions.success_episodes as f64 * weight).round() as u32;
756        result.episode_transitions.failure_episodes +=
757            (snapshot.episode_transitions.failure_episodes as f64 * weight).round() as u32;
758    }
759
760    // N-gram をマージ
761    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
762        for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
763            let entry = result
764                .ngram_stats
765                .trigrams
766                .entry(key.clone())
767                .or_insert((0, 0));
768            entry.0 += (success as f64 * weight).round() as u32;
769            entry.1 += (failure as f64 * weight).round() as u32;
770        }
771        for (key, &(success, failure)) in &snapshot.ngram_stats.quadgrams {
772            let entry = result
773                .ngram_stats
774                .quadgrams
775                .entry(key.clone())
776                .or_insert((0, 0));
777            entry.0 += (success as f64 * weight).round() as u32;
778            entry.1 += (failure as f64 * weight).round() as u32;
779        }
780    }
781
782    // コンテキスト統計をマージ
783    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
784        for (key, stats) in &snapshot.contextual_stats {
785            let entry = result.contextual_stats.entry(key.clone()).or_default();
786            entry.visits += (stats.visits as f64 * weight).round() as u32;
787            entry.successes += (stats.successes as f64 * weight).round() as u32;
788            entry.failures += (stats.failures as f64 * weight).round() as u32;
789        }
790    }
791
792    // アクション統計をマージ
793    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
794        for (key, stats) in &snapshot.action_stats {
795            let entry = result.action_stats.entry(key.clone()).or_default();
796            entry.visits += (stats.visits as f64 * weight).round() as u32;
797            entry.successes += (stats.successes as f64 * weight).round() as u32;
798            entry.failures += (stats.failures as f64 * weight).round() as u32;
799        }
800    }
801
802    // Selection パフォーマンスをマージ
803    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
804        for (strat, stats) in &snapshot.selection_performance.strategy_stats {
805            let entry = result
806                .selection_performance
807                .strategy_stats
808                .entry(strat.clone())
809                .or_default();
810            entry.visits += (stats.visits as f64 * weight).round() as u32;
811            entry.successes += (stats.successes as f64 * weight).round() as u32;
812            entry.failures += (stats.failures as f64 * weight).round() as u32;
813            entry.episodes_success += (stats.episodes_success as f64 * weight).round() as u32;
814            entry.episodes_failure += (stats.episodes_failure as f64 * weight).round() as u32;
815        }
816    }
817
818    result
819}
820
821// ============================================================================
822// JSON Export/Import
823// ============================================================================
824
825impl LearningSnapshot {
826    /// JSONとしてエクスポート
827    pub fn to_json(&self) -> serde_json::Result<String> {
828        serde_json::to_string_pretty(self)
829    }
830
831    /// JSONからインポート
832    pub fn from_json(json: &str) -> serde_json::Result<Self> {
833        serde_json::from_str(json)
834    }
835
836    /// JSONファイルとして保存
837    pub fn save_json(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
838        let json = self.to_json()?;
839        fs::write(path, json)
840    }
841
842    /// JSONファイルからロード
843    pub fn load_json(path: impl AsRef<Path>) -> std::io::Result<Self> {
844        let json = fs::read_to_string(path)?;
845        Self::from_json(&json).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
846    }
847}
848
849// ============================================================================
850// Tests
851// ============================================================================
852
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds the `(from, to)` transition key used throughout these tests.
    fn pair(from: &str, to: &str) -> (String, String) {
        (from.to_string(), to.to_string())
    }

    /// A snapshot whose only data is `count` successful `A -> B` transitions.
    fn with_ab_success(count: u32) -> LearningSnapshot {
        let mut snapshot = LearningSnapshot::empty();
        snapshot
            .episode_transitions
            .success_transitions
            .insert(pair("A", "B"), count);
        snapshot
    }

    #[test]
    fn test_snapshot_serialization() {
        let original = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("test"));

        let round_tripped = LearningSnapshot::from_json(&original.to_json().unwrap()).unwrap();

        assert_eq!(round_tripped.version, SNAPSHOT_VERSION);
        assert_eq!(round_tripped.metadata.scenario_name.as_deref(), Some("test"));
    }

    #[test]
    fn test_learning_store_save_load() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();
        let scenario = "troubleshooting";

        let snapshot = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario(scenario));
        store.save_scenario(scenario, &snapshot).unwrap();

        let loaded = store.load_scenario(scenario).unwrap();
        assert_eq!(loaded.metadata.scenario_name.as_deref(), Some(scenario));
    }

    #[test]
    fn test_merge_additive() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        let mut older = with_ab_success(5);
        older.metadata.total_episodes = 10;
        let mut newer = with_ab_success(3);
        newer.metadata.total_episodes = 5;

        let merged = store.merge(&[older, newer], MergeStrategy::Additive);

        let ab = merged.episode_transitions.success_transitions.get(&pair("A", "B"));
        assert_eq!(ab, Some(&8));
        assert_eq!(merged.metadata.total_episodes, 15);
    }

    #[test]
    fn test_merge_time_decay() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        // The first element is the older session, the last the newer one.
        let strategy = MergeStrategy::TimeDecay {
            half_life_sessions: 1,
        };
        let merged = store.merge(&[with_ab_success(100), with_ab_success(100)], strategy);

        // Older weight 0.5, newer weight 1.0: 50 + 100 = 150.
        let count = *merged
            .episode_transitions
            .success_transitions
            .get(&pair("A", "B"))
            .unwrap();
        assert_eq!(count, 150);
    }

    #[test]
    fn test_session_management() {
        let dir = tempdir().unwrap();
        let store = LearningStore::new(dir.path()).unwrap();

        // Build the snapshot with an explicit struct literal.
        let snapshot = LearningSnapshot {
            version: SNAPSHOT_VERSION,
            metadata: SnapshotMetadata::default().with_scenario("test"),
            episode_transitions: Default::default(),
            ngram_stats: Default::default(),
            selection_performance: Default::default(),
            contextual_stats: Default::default(),
            action_stats: Default::default(),
        };

        let session_id = store.save_session("test", &snapshot).unwrap();
        assert!(!session_id.0.is_empty());

        assert_eq!(store.list_sessions("test").unwrap(), vec![session_id]);
    }
}