Skip to main content

swarm_engine_core/learn/
snapshot.rs

1//! Learning Snapshot - 学習データの永続化システム
2//!
3//! Multi-Session 間で統計を共有・マージするための永続化層。
4//!
5//! # アーキテクチャ
6//!
7//! ```text
8//! ┌─────────────────────────────────────────────────────────────┐
9//! │                    LearningStore                            │
10//! │                   (Facade / Entry Point)                    │
11//! ├─────────────────────────────────────────────────────────────┤
12//! │                                                             │
13//! │  ┌─────────────────┐    ┌─────────────────────────────┐    │
14//! │  │ SnapshotStorage │    │     TimeSeriesQuery         │    │
15//! │  │     (trait)     │    │        (trait)              │    │
16//! │  ├─────────────────┤    ├─────────────────────────────┤    │
17//! │  │ save(key, snap) │    │ query_range(from, to)       │    │
18//! │  │ load(key)       │    │ query_latest(n)             │    │
19//! │  │ delete(key)     │    │ list_sessions(scenario)     │    │
20//! │  │ exists(key)     │    └─────────────────────────────┘    │
21//! │  └─────────────────┘                                       │
22//! │           │                         │                      │
23//! │           └────────────┬────────────┘                      │
24//! │                        ▼                                   │
25//! │              ┌─────────────────┐                           │
26//! │              │FileSystemStorage│                           │
27//! │              │  (implements)   │                           │
28//! │              └─────────────────┘                           │
29//! └─────────────────────────────────────────────────────────────┘
30//!
//! <data_dir>/swarm-engine/learning/     (see `LearningStore::default_path`)
//! ├── global_stats.json
//! └── scenarios/
//!     └── troubleshooting/
//!         ├── stats.json
//!         ├── offline_model.json
//!         └── sessions/
//!             └── {timestamp}/
//!                 ├── meta.json
//!                 └── stats.json
40//! ```
41
42use std::collections::HashMap;
43use std::fs;
44use std::path::{Path, PathBuf};
45use std::time::{SystemTime, UNIX_EPOCH};
46
47use serde::{Deserialize, Serialize};
48
49use super::session_group::{LearningPhase, SessionGroupId};
50use super::{EpisodeTransitions, NgramStats, SelectionPerformance};
51use crate::online_stats::ActionStats;
52
/// Snapshot format version; written into every snapshot's `version` field so readers can
/// check compatibility.
pub const SNAPSHOT_VERSION: u32 = 1;

/// A point in time, in whole seconds since the Unix epoch.
pub type Timestamp = u64;
58
59/// セッションID(タイムスタンプベース)
60#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
61pub struct SessionId(pub String);
62
63impl SessionId {
64    /// タイムスタンプから SessionId を取得
65    pub fn timestamp(&self) -> Option<Timestamp> {
66        self.0.parse().ok()
67    }
68}
69
70// ============================================================================
71// Storage Traits - 永続化の抽象化
72// ============================================================================
73
74/// スナップショットのキー(階層的)
75#[derive(Debug, Clone, PartialEq, Eq, Hash)]
76pub enum SnapshotKey {
77    /// グローバル統計
78    Global,
79    /// シナリオ別統計
80    Scenario(String),
81    /// セッション別統計
82    Session {
83        scenario: String,
84        session_id: SessionId,
85    },
86}
87
88impl std::fmt::Display for SnapshotKey {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        match self {
91            Self::Global => write!(f, "global"),
92            Self::Scenario(s) => write!(f, "scenario:{}", s),
93            Self::Session {
94                scenario,
95                session_id,
96            } => {
97                write!(f, "session:{}:{}", scenario, session_id.0)
98            }
99        }
100    }
101}
102
103/// スナップショットの永続化(CRUD)
104pub trait SnapshotStorage {
105    /// エラー型
106    type Error: std::error::Error + Send + Sync + 'static;
107
108    /// スナップショットを保存
109    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error>;
110
111    /// スナップショットをロード
112    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error>;
113
114    /// スナップショットを削除
115    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
116
117    /// スナップショットの存在確認
118    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error>;
119}
120
121/// 時系列クエリ
122pub trait TimeSeriesQuery {
123    /// エラー型
124    type Error: std::error::Error + Send + Sync + 'static;
125
126    /// 時間範囲でセッションを取得
127    fn query_range(
128        &self,
129        scenario: &str,
130        from: Timestamp,
131        to: Timestamp,
132    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
133
134    /// 最新 N 件のセッションを取得
135    fn query_latest(
136        &self,
137        scenario: &str,
138        limit: usize,
139    ) -> Result<Vec<LearningSnapshot>, Self::Error>;
140
141    /// 指定時刻以降のセッションを取得
142    fn query_since(
143        &self,
144        scenario: &str,
145        since: Timestamp,
146    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
147        self.query_range(scenario, since, u64::MAX)
148    }
149
150    /// シナリオのセッション一覧を取得
151    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error>;
152}
153
154// ============================================================================
155// LearningSnapshot - 永続化単位
156// ============================================================================
157
158/// 学習データのスナップショット(永続化単位)
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct LearningSnapshot {
161    /// フォーマットバージョン(互換性用)
162    pub version: u32,
163    /// メタデータ
164    pub metadata: SnapshotMetadata,
165    /// エピソード遷移統計(2-gram)
166    pub episode_transitions: EpisodeTransitions,
167    /// N-gram 統計(3-gram, 4-gram)
168    pub ngram_stats: NgramStats,
169    /// Selection 戦略効果
170    pub selection_performance: SelectionPerformance,
171    /// コンテキスト条件付き統計
172    pub contextual_stats: HashMap<(String, String), ActionStats>,
173    /// アクション別統計
174    pub action_stats: HashMap<String, ActionStats>,
175}
176
177/// スナップショットのメタデータ
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct SnapshotMetadata {
180    /// シナリオ名
181    pub scenario_name: Option<String>,
182    /// タスクの説明
183    pub task_description: Option<String>,
184    /// 作成日時(Unix timestamp)
185    pub created_at: u64,
186    /// マージされたセッション数
187    pub session_count: u32,
188    /// 総エピソード数
189    pub total_episodes: u32,
190    /// 総アクション数
191    pub total_actions: u32,
192    /// 学習フェーズ(Bootstrap / Release / Validate)
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub phase: Option<LearningPhase>,
195    /// セッショングループ ID
196    #[serde(default, skip_serializing_if = "Option::is_none")]
197    pub group_id: Option<SessionGroupId>,
198}
199
200impl Default for SnapshotMetadata {
201    fn default() -> Self {
202        Self {
203            scenario_name: None,
204            task_description: None,
205            created_at: SystemTime::now()
206                .duration_since(UNIX_EPOCH)
207                .map(|d| d.as_secs())
208                .unwrap_or(0),
209            session_count: 1,
210            total_episodes: 0,
211            total_actions: 0,
212            phase: None,
213            group_id: None,
214        }
215    }
216}
217
218impl SnapshotMetadata {
219    /// シナリオ名を設定
220    pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
221        self.scenario_name = Some(name.into());
222        self
223    }
224
225    /// タスク説明を設定
226    pub fn with_task(mut self, desc: impl Into<String>) -> Self {
227        self.task_description = Some(desc.into());
228        self
229    }
230
231    /// 学習フェーズを設定
232    pub fn with_phase(mut self, phase: LearningPhase) -> Self {
233        self.phase = Some(phase);
234        self
235    }
236
237    /// セッショングループ ID を設定
238    pub fn with_group_id(mut self, group_id: SessionGroupId) -> Self {
239        self.group_id = Some(group_id);
240        self
241    }
242}
243
244impl LearningSnapshot {
245    /// 空のスナップショットを作成
246    pub fn empty() -> Self {
247        Self {
248            version: SNAPSHOT_VERSION,
249            metadata: SnapshotMetadata::default(),
250            episode_transitions: EpisodeTransitions::default(),
251            ngram_stats: NgramStats::default(),
252            selection_performance: SelectionPerformance::default(),
253            contextual_stats: HashMap::new(),
254            action_stats: HashMap::new(),
255        }
256    }
257
258    /// メタデータを設定
259    pub fn with_metadata(mut self, metadata: SnapshotMetadata) -> Self {
260        self.metadata = metadata;
261        self
262    }
263}
264
265// ============================================================================
266// MergeStrategy - マージ戦略
267// ============================================================================
268
/// How [`merge_snapshots`] weights each input snapshot.
#[derive(Debug, Clone, Copy, Default)]
pub enum MergeStrategy {
    /// Every snapshot gets weight 1, so counts are simply summed (the default).
    #[default]
    Additive,
    /// Older snapshots count less. A snapshot's weight halves for every
    /// `half_life_sessions` positions it sits before the newest (last) snapshot.
    TimeDecay {
        /// Half-life, measured in sessions.
        half_life_sessions: u32,
    },
    /// Snapshots with a higher success rate count more:
    /// weight = `1 + success_episodes / total_episodes`.
    SuccessWeighted,
}
283
284// ============================================================================
285// FileSystemStorage - ファイルシステムベースの実装
286// ============================================================================
287
288/// ファイルシステムベースのストレージ実装
289#[derive(Clone)]
290pub struct FileSystemStorage {
291    base_dir: PathBuf,
292}
293
294impl FileSystemStorage {
295    /// 新しい FileSystemStorage を作成
296    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
297        let base_dir = base_dir.as_ref().to_path_buf();
298        fs::create_dir_all(&base_dir)?;
299        Ok(Self { base_dir })
300    }
301
302    /// ベースディレクトリを取得
303    pub fn base_dir(&self) -> &Path {
304        &self.base_dir
305    }
306
307    fn key_to_path(&self, key: &SnapshotKey) -> PathBuf {
308        match key {
309            SnapshotKey::Global => self.base_dir.join("global_stats.json"),
310            SnapshotKey::Scenario(scenario) => self
311                .base_dir
312                .join("scenarios")
313                .join(scenario)
314                .join("stats.json"),
315            SnapshotKey::Session {
316                scenario,
317                session_id,
318            } => self
319                .base_dir
320                .join("scenarios")
321                .join(scenario)
322                .join("sessions")
323                .join(&session_id.0)
324                .join("stats.json"),
325        }
326    }
327
328    fn sessions_dir(&self, scenario: &str) -> PathBuf {
329        self.base_dir
330            .join("scenarios")
331            .join(scenario)
332            .join("sessions")
333    }
334}
335
336impl SnapshotStorage for FileSystemStorage {
337    type Error = std::io::Error;
338
339    fn save(&self, key: &SnapshotKey, snapshot: &LearningSnapshot) -> Result<(), Self::Error> {
340        let path = self.key_to_path(key);
341        if let Some(parent) = path.parent() {
342            fs::create_dir_all(parent)?;
343        }
344        snapshot.save_json(&path)
345    }
346
347    fn load(&self, key: &SnapshotKey) -> Result<Option<LearningSnapshot>, Self::Error> {
348        let path = self.key_to_path(key);
349        if !path.exists() {
350            return Ok(None);
351        }
352        LearningSnapshot::load_json(&path).map(Some)
353    }
354
355    fn delete(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
356        let path = self.key_to_path(key);
357        if path.exists() {
358            fs::remove_file(&path)?;
359            Ok(true)
360        } else {
361            Ok(false)
362        }
363    }
364
365    fn exists(&self, key: &SnapshotKey) -> Result<bool, Self::Error> {
366        Ok(self.key_to_path(key).exists())
367    }
368}
369
370impl TimeSeriesQuery for FileSystemStorage {
371    type Error = std::io::Error;
372
373    fn query_range(
374        &self,
375        scenario: &str,
376        from: Timestamp,
377        to: Timestamp,
378    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
379        let sessions = self.list_sessions(scenario)?;
380        let mut results = Vec::new();
381
382        for session_id in sessions {
383            if let Some(ts) = session_id.timestamp() {
384                if ts >= from && ts <= to {
385                    let key = SnapshotKey::Session {
386                        scenario: scenario.to_string(),
387                        session_id,
388                    };
389                    if let Some(snapshot) = self.load(&key)? {
390                        results.push(snapshot);
391                    }
392                }
393            }
394        }
395        Ok(results)
396    }
397
398    fn query_latest(
399        &self,
400        scenario: &str,
401        limit: usize,
402    ) -> Result<Vec<LearningSnapshot>, Self::Error> {
403        let mut sessions = self.list_sessions(scenario)?;
404        // 降順ソート(最新が先頭)
405        sessions.sort_by(|a, b| b.0.cmp(&a.0));
406        sessions.truncate(limit);
407
408        let mut results = Vec::new();
409        for session_id in sessions {
410            let key = SnapshotKey::Session {
411                scenario: scenario.to_string(),
412                session_id,
413            };
414            if let Some(snapshot) = self.load(&key)? {
415                results.push(snapshot);
416            }
417        }
418        Ok(results)
419    }
420
421    fn list_sessions(&self, scenario: &str) -> Result<Vec<SessionId>, Self::Error> {
422        let sessions_dir = self.sessions_dir(scenario);
423        if !sessions_dir.exists() {
424            return Ok(Vec::new());
425        }
426
427        let mut sessions = Vec::new();
428        for entry in fs::read_dir(sessions_dir)? {
429            let entry = entry?;
430            if entry.file_type()?.is_dir() {
431                if let Some(name) = entry.file_name().to_str() {
432                    sessions.push(SessionId(name.to_string()));
433                }
434            }
435        }
436        sessions.sort_by(|a, b| a.0.cmp(&b.0));
437        Ok(sessions)
438    }
439}
440
441// ============================================================================
442// LearningStore - Facade
443// ============================================================================
444
445/// 学習データの永続化マネージャ(Facade)
446///
447/// 内部で FileSystemStorage を使用。高レベル API を提供。
448#[derive(Clone)]
449pub struct LearningStore {
450    storage: FileSystemStorage,
451}
452
453impl LearningStore {
454    /// 新しい LearningStore を作成
455    pub fn new(base_dir: impl AsRef<Path>) -> std::io::Result<Self> {
456        let storage = FileSystemStorage::new(base_dir)?;
457        Ok(Self { storage })
458    }
459
460    /// デフォルトのパスで作成
461    pub fn default_path() -> PathBuf {
462        dirs::data_dir()
463            .unwrap_or_else(|| PathBuf::from("."))
464            .join("swarm-engine")
465            .join("learning")
466    }
467
468    /// 内部ストレージへの参照を取得
469    pub fn storage(&self) -> &FileSystemStorage {
470        &self.storage
471    }
472
473    // ========================================================================
474    // High-Level API(後方互換性)
475    // ========================================================================
476
477    /// グローバル統計をロード
478    pub fn load_global(&self) -> std::io::Result<LearningSnapshot> {
479        self.storage.load(&SnapshotKey::Global)?.ok_or_else(|| {
480            std::io::Error::new(std::io::ErrorKind::NotFound, "global stats not found")
481        })
482    }
483
484    /// グローバル統計を保存
485    pub fn save_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
486        self.storage.save(&SnapshotKey::Global, snapshot)
487    }
488
489    /// シナリオ別統計をロード
490    pub fn load_scenario(&self, scenario: &str) -> std::io::Result<LearningSnapshot> {
491        self.storage
492            .load(&SnapshotKey::Scenario(scenario.to_string()))?
493            .ok_or_else(|| {
494                std::io::Error::new(std::io::ErrorKind::NotFound, "scenario stats not found")
495            })
496    }
497
498    /// シナリオ別統計を保存
499    pub fn save_scenario(
500        &self,
501        scenario: &str,
502        snapshot: &LearningSnapshot,
503    ) -> std::io::Result<()> {
504        self.storage
505            .save(&SnapshotKey::Scenario(scenario.to_string()), snapshot)
506    }
507
508    /// セッション統計を保存(自動マージ)
509    pub fn save_session(
510        &self,
511        scenario: &str,
512        snapshot: &LearningSnapshot,
513    ) -> std::io::Result<SessionId> {
514        // セッションIDを生成
515        let session_id = self.generate_session_id();
516
517        // セッション統計を保存
518        let key = SnapshotKey::Session {
519            scenario: scenario.to_string(),
520            session_id: session_id.clone(),
521        };
522        self.storage.save(&key, snapshot)?;
523
524        // メタデータを別途 JSON で保存(可読性用)
525        let meta_path = self.storage.key_to_path(&key).with_file_name("meta.json");
526        let meta_json = serde_json::to_string_pretty(&snapshot.metadata)?;
527        fs::write(meta_path, meta_json)?;
528
529        // シナリオ別統計にマージ
530        self.merge_into_scenario(scenario, snapshot)?;
531
532        // グローバル統計にマージ
533        self.merge_into_global(snapshot)?;
534
535        Ok(session_id)
536    }
537
538    /// シナリオのセッション一覧を取得
539    pub fn list_sessions(&self, scenario: &str) -> std::io::Result<Vec<SessionId>> {
540        self.storage.list_sessions(scenario)
541    }
542
543    /// 特定セッションの統計をロード
544    pub fn load_session(
545        &self,
546        scenario: &str,
547        session_id: &SessionId,
548    ) -> std::io::Result<LearningSnapshot> {
549        let key = SnapshotKey::Session {
550            scenario: scenario.to_string(),
551            session_id: session_id.clone(),
552        };
553        self.storage
554            .load(&key)?
555            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "session not found"))
556    }
557
558    // ========================================================================
559    // Time Series Query(新規 API)
560    // ========================================================================
561
562    /// 時間範囲でセッションを取得
563    pub fn query_range(
564        &self,
565        scenario: &str,
566        from: Timestamp,
567        to: Timestamp,
568    ) -> std::io::Result<Vec<LearningSnapshot>> {
569        self.storage.query_range(scenario, from, to)
570    }
571
572    /// 最新 N 件のセッションを取得
573    pub fn query_latest(
574        &self,
575        scenario: &str,
576        limit: usize,
577    ) -> std::io::Result<Vec<LearningSnapshot>> {
578        self.storage.query_latest(scenario, limit)
579    }
580
581    // ========================================================================
582    // Merge Logic
583    // ========================================================================
584
585    /// 複数のスナップショットをマージ
586    pub fn merge(
587        &self,
588        snapshots: &[LearningSnapshot],
589        strategy: MergeStrategy,
590    ) -> LearningSnapshot {
591        merge_snapshots(snapshots, strategy)
592    }
593
594    fn generate_session_id(&self) -> SessionId {
595        let timestamp = SystemTime::now()
596            .duration_since(UNIX_EPOCH)
597            .map(|d| d.as_secs())
598            .unwrap_or(0);
599        SessionId(format!("{:010}", timestamp))
600    }
601
602    fn merge_into_scenario(
603        &self,
604        scenario: &str,
605        snapshot: &LearningSnapshot,
606    ) -> std::io::Result<()> {
607        let existing = self
608            .storage
609            .load(&SnapshotKey::Scenario(scenario.to_string()))?;
610        let merged = match existing {
611            Some(existing) => {
612                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
613            }
614            None => snapshot.clone(),
615        };
616        self.save_scenario(scenario, &merged)
617    }
618
619    fn merge_into_global(&self, snapshot: &LearningSnapshot) -> std::io::Result<()> {
620        let existing = self.storage.load(&SnapshotKey::Global)?;
621        let merged = match existing {
622            Some(existing) => {
623                merge_snapshots(&[existing, snapshot.clone()], MergeStrategy::Additive)
624            }
625            None => snapshot.clone(),
626        };
627        self.save_global(&merged)
628    }
629
630    // ========================================================================
631    // Offline Model API
632    // ========================================================================
633
634    /// Offline Model をロード
635    pub fn load_offline_model(
636        &self,
637        scenario: &str,
638    ) -> std::io::Result<super::offline::OfflineModel> {
639        let path = self.offline_model_path(scenario);
640        if !path.exists() {
641            return Err(std::io::Error::new(
642                std::io::ErrorKind::NotFound,
643                "offline model not found",
644            ));
645        }
646        let json = fs::read_to_string(&path)?;
647        serde_json::from_str(&json)
648            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
649    }
650
651    /// Offline Model を保存
652    pub fn save_offline_model(
653        &self,
654        scenario: &str,
655        model: &super::offline::OfflineModel,
656    ) -> std::io::Result<()> {
657        let path = self.offline_model_path(scenario);
658        if let Some(parent) = path.parent() {
659            fs::create_dir_all(parent)?;
660        }
661        let json = serde_json::to_string_pretty(model)?;
662        fs::write(path, json)
663    }
664
665    /// Offline 学習を実行
666    pub fn run_offline_learning(
667        &self,
668        scenario: &str,
669        session_limit: usize,
670    ) -> std::io::Result<super::offline::OfflineModel> {
671        let snapshots = self.query_latest(scenario, session_limit)?;
672        if snapshots.is_empty() {
673            return Err(std::io::Error::new(
674                std::io::ErrorKind::NotFound,
675                "no sessions found for offline learning",
676            ));
677        }
678
679        let analyzer = super::offline::OfflineAnalyzer::new(&snapshots);
680        let model = analyzer.analyze();
681
682        self.save_offline_model(scenario, &model)?;
683
684        Ok(model)
685    }
686
687    fn offline_model_path(&self, scenario: &str) -> PathBuf {
688        self.storage
689            .base_dir()
690            .join("scenarios")
691            .join(scenario)
692            .join("offline_model.json")
693    }
694}
695
696// ============================================================================
697// Merge Function(独立関数として分離)
698// ============================================================================
699
700/// 複数のスナップショットをマージ
701pub fn merge_snapshots(
702    snapshots: &[LearningSnapshot],
703    strategy: MergeStrategy,
704) -> LearningSnapshot {
705    if snapshots.is_empty() {
706        return LearningSnapshot::empty();
707    }
708    if snapshots.len() == 1 {
709        return snapshots[0].clone();
710    }
711
712    let mut result = LearningSnapshot::empty();
713
714    // 重みを計算
715    let weights: Vec<f64> = match strategy {
716        MergeStrategy::Additive => vec![1.0; snapshots.len()],
717        MergeStrategy::TimeDecay { half_life_sessions } => {
718            let half_life = half_life_sessions as f64;
719            snapshots
720                .iter()
721                .enumerate()
722                .map(|(i, _)| {
723                    let age = (snapshots.len() - 1 - i) as f64;
724                    0.5_f64.powf(age / half_life)
725                })
726                .collect()
727        }
728        MergeStrategy::SuccessWeighted => snapshots
729            .iter()
730            .map(|s| {
731                let total = s.metadata.total_episodes as f64;
732                let success = s.episode_transitions.success_episodes as f64;
733                if total == 0.0 {
734                    1.0
735                } else {
736                    1.0 + success / total
737                }
738            })
739            .collect(),
740    };
741
742    // メタデータをマージ
743    result.metadata = SnapshotMetadata {
744        scenario_name: snapshots
745            .last()
746            .and_then(|s| s.metadata.scenario_name.clone()),
747        task_description: snapshots
748            .last()
749            .and_then(|s| s.metadata.task_description.clone()),
750        created_at: SystemTime::now()
751            .duration_since(UNIX_EPOCH)
752            .map(|d| d.as_secs())
753            .unwrap_or(0),
754        session_count: snapshots.iter().map(|s| s.metadata.session_count).sum(),
755        total_episodes: snapshots.iter().map(|s| s.metadata.total_episodes).sum(),
756        total_actions: snapshots.iter().map(|s| s.metadata.total_actions).sum(),
757        // マージ時は phase/group_id は引き継がない(マージ結果は複数のグループを含む可能性があるため)
758        phase: None,
759        group_id: None,
760    };
761
762    // エピソード遷移をマージ
763    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
764        for (key, &count) in &snapshot.episode_transitions.success_transitions {
765            let weighted_count = (count as f64 * weight).round() as u32;
766            *result
767                .episode_transitions
768                .success_transitions
769                .entry(key.clone())
770                .or_default() += weighted_count;
771        }
772        for (key, &count) in &snapshot.episode_transitions.failure_transitions {
773            let weighted_count = (count as f64 * weight).round() as u32;
774            *result
775                .episode_transitions
776                .failure_transitions
777                .entry(key.clone())
778                .or_default() += weighted_count;
779        }
780        result.episode_transitions.success_episodes +=
781            (snapshot.episode_transitions.success_episodes as f64 * weight).round() as u32;
782        result.episode_transitions.failure_episodes +=
783            (snapshot.episode_transitions.failure_episodes as f64 * weight).round() as u32;
784    }
785
786    // N-gram をマージ
787    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
788        for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
789            let entry = result
790                .ngram_stats
791                .trigrams
792                .entry(key.clone())
793                .or_insert((0, 0));
794            entry.0 += (success as f64 * weight).round() as u32;
795            entry.1 += (failure as f64 * weight).round() as u32;
796        }
797        for (key, &(success, failure)) in &snapshot.ngram_stats.quadgrams {
798            let entry = result
799                .ngram_stats
800                .quadgrams
801                .entry(key.clone())
802                .or_insert((0, 0));
803            entry.0 += (success as f64 * weight).round() as u32;
804            entry.1 += (failure as f64 * weight).round() as u32;
805        }
806    }
807
808    // コンテキスト統計をマージ
809    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
810        for (key, stats) in &snapshot.contextual_stats {
811            let entry = result.contextual_stats.entry(key.clone()).or_default();
812            entry.visits += (stats.visits as f64 * weight).round() as u32;
813            entry.successes += (stats.successes as f64 * weight).round() as u32;
814            entry.failures += (stats.failures as f64 * weight).round() as u32;
815        }
816    }
817
818    // アクション統計をマージ
819    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
820        for (key, stats) in &snapshot.action_stats {
821            let entry = result.action_stats.entry(key.clone()).or_default();
822            entry.visits += (stats.visits as f64 * weight).round() as u32;
823            entry.successes += (stats.successes as f64 * weight).round() as u32;
824            entry.failures += (stats.failures as f64 * weight).round() as u32;
825        }
826    }
827
828    // Selection パフォーマンスをマージ
829    for (snapshot, &weight) in snapshots.iter().zip(weights.iter()) {
830        for (strat, stats) in &snapshot.selection_performance.strategy_stats {
831            let entry = result
832                .selection_performance
833                .strategy_stats
834                .entry(strat.clone())
835                .or_default();
836            entry.visits += (stats.visits as f64 * weight).round() as u32;
837            entry.successes += (stats.successes as f64 * weight).round() as u32;
838            entry.failures += (stats.failures as f64 * weight).round() as u32;
839            entry.episodes_success += (stats.episodes_success as f64 * weight).round() as u32;
840            entry.episodes_failure += (stats.episodes_failure as f64 * weight).round() as u32;
841        }
842    }
843
844    result
845}
846
847// ============================================================================
848// JSON Export/Import
849// ============================================================================
850
851impl LearningSnapshot {
852    /// JSONとしてエクスポート
853    pub fn to_json(&self) -> serde_json::Result<String> {
854        serde_json::to_string_pretty(self)
855    }
856
857    /// JSONからインポート
858    pub fn from_json(json: &str) -> serde_json::Result<Self> {
859        serde_json::from_str(json)
860    }
861
862    /// JSONファイルとして保存
863    pub fn save_json(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
864        let json = self.to_json()?;
865        fs::write(path, json)
866    }
867
868    /// JSONファイルからロード
869    pub fn load_json(path: impl AsRef<Path>) -> std::io::Result<Self> {
870        let json = fs::read_to_string(path)?;
871        Self::from_json(&json).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
872    }
873}
874
875// ============================================================================
876// Tests
877// ============================================================================
878
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a `(from, to)` transition key.
    fn key(from: &str, to: &str) -> (String, String) {
        (from.to_string(), to.to_string())
    }

    #[test]
    fn test_snapshot_serialization() {
        let original = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario("test"));

        let round_tripped = LearningSnapshot::from_json(&original.to_json().unwrap()).unwrap();

        assert_eq!(round_tripped.version, SNAPSHOT_VERSION);
        assert_eq!(round_tripped.metadata.scenario_name.as_deref(), Some("test"));
    }

    #[test]
    fn test_learning_store_save_load() {
        let tmp = tempdir().unwrap();
        let store = LearningStore::new(tmp.path()).unwrap();
        let scenario = "troubleshooting";

        let snapshot = LearningSnapshot::empty()
            .with_metadata(SnapshotMetadata::default().with_scenario(scenario));
        store.save_scenario(scenario, &snapshot).unwrap();

        let loaded = store.load_scenario(scenario).unwrap();
        assert_eq!(loaded.metadata.scenario_name.as_deref(), Some(scenario));
    }

    #[test]
    fn test_merge_additive() {
        let tmp = tempdir().unwrap();
        let store = LearningStore::new(tmp.path()).unwrap();

        let mut older = LearningSnapshot::empty();
        older
            .episode_transitions
            .success_transitions
            .insert(key("A", "B"), 5);
        older.metadata.total_episodes = 10;

        let mut newer = LearningSnapshot::empty();
        newer
            .episode_transitions
            .success_transitions
            .insert(key("A", "B"), 3);
        newer.metadata.total_episodes = 5;

        let merged = store.merge(&[older, newer], MergeStrategy::Additive);

        assert_eq!(
            merged
                .episode_transitions
                .success_transitions
                .get(&key("A", "B")),
            Some(&8)
        );
        assert_eq!(merged.metadata.total_episodes, 15);
    }

    #[test]
    fn test_merge_time_decay() {
        let tmp = tempdir().unwrap();
        let store = LearningStore::new(tmp.path()).unwrap();

        // Both snapshots hold the same count; the first one is the older session.
        let mut older = LearningSnapshot::empty();
        older
            .episode_transitions
            .success_transitions
            .insert(key("A", "B"), 100);
        let mut newer = LearningSnapshot::empty();
        newer
            .episode_transitions
            .success_transitions
            .insert(key("A", "B"), 100);

        let merged = store.merge(
            &[older, newer],
            MergeStrategy::TimeDecay {
                half_life_sessions: 1,
            },
        );

        // With a one-session half-life the older snapshot weighs 0.5 and the newer 1.0,
        // giving 50 + 100 = 150.
        let count = *merged
            .episode_transitions
            .success_transitions
            .get(&key("A", "B"))
            .unwrap();
        assert_eq!(count, 150);
    }

    #[test]
    fn test_session_management() {
        let tmp = tempdir().unwrap();
        let store = LearningStore::new(tmp.path()).unwrap();

        // Construct the snapshot field by field to exercise the public struct literal.
        let snapshot = LearningSnapshot {
            version: SNAPSHOT_VERSION,
            metadata: SnapshotMetadata::default().with_scenario("test"),
            action_stats: Default::default(),
            episode_transitions: Default::default(),
            ngram_stats: Default::default(),
            selection_performance: Default::default(),
            contextual_stats: Default::default(),
        };

        let saved_id = store.save_session("test", &snapshot).unwrap();
        assert!(!saved_id.0.is_empty());

        let listed = store.list_sessions("test").unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0], saved_id);
    }
}