Skip to main content

swarm_engine_core/learn/
session_group.rs

1//! Session Group - 複数セッションをまとめる単位
2//!
3//! Bootstrap / Release / Validate の各フェーズで実行された
4//! 複数のセッションをグループとして管理する。
5
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use serde::{Deserialize, Serialize};
9
10use super::snapshot::SessionId;
11
12/// セッショングループ ID
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct SessionGroupId(pub String);
15
16impl SessionGroupId {
17    /// 新しいグループ ID を生成(タイムスタンプベース)
18    pub fn new() -> Self {
19        let timestamp = SystemTime::now()
20            .duration_since(UNIX_EPOCH)
21            .map(|d| d.as_millis())
22            .unwrap_or(0);
23        Self(format!("g{}", timestamp))
24    }
25
26    /// 文字列から生成
27    pub fn from_raw(s: impl Into<String>) -> Self {
28        Self(s.into())
29    }
30}
31
32impl Default for SessionGroupId {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl std::fmt::Display for SessionGroupId {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "{}", self.0)
41    }
42}
43
44/// 学習フェーズ
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47pub enum LearningPhase {
48    /// Bootstrap: 正解グラフで強制的に成功させ、学習データを蓄積
49    Bootstrap,
50    /// Release: 学習済みモデルで自律実行
51    Release,
52    /// Validate: 検証・修正(将来用)
53    Validate,
54}
55
56impl std::fmt::Display for LearningPhase {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            Self::Bootstrap => write!(f, "bootstrap"),
60            Self::Release => write!(f, "release"),
61            Self::Validate => write!(f, "validate"),
62        }
63    }
64}
65
66impl std::str::FromStr for LearningPhase {
67    type Err = String;
68
69    fn from_str(s: &str) -> Result<Self, Self::Err> {
70        match s.to_lowercase().as_str() {
71            "bootstrap" => Ok(Self::Bootstrap),
72            "release" => Ok(Self::Release),
73            "validate" => Ok(Self::Validate),
74            _ => Err(format!("Unknown phase: {}", s)),
75        }
76    }
77}
78
79/// セッショングループのメタデータ
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct SessionGroupMetadata {
82    /// シナリオ名
83    pub scenario: String,
84    /// 作成日時(Unix timestamp)
85    pub created_at: u64,
86    /// 完了日時(Unix timestamp)
87    pub completed_at: Option<u64>,
88    /// 目標実行回数
89    pub target_runs: usize,
90    /// 成功回数
91    pub success_count: usize,
92    /// 失敗回数
93    pub failure_count: usize,
94    /// 使用した variant(with_graph 等)
95    pub variant: Option<String>,
96}
97
98impl SessionGroupMetadata {
99    /// 新しいメタデータを作成
100    pub fn new(scenario: impl Into<String>, target_runs: usize) -> Self {
101        Self {
102            scenario: scenario.into(),
103            created_at: SystemTime::now()
104                .duration_since(UNIX_EPOCH)
105                .map(|d| d.as_secs())
106                .unwrap_or(0),
107            completed_at: None,
108            target_runs,
109            success_count: 0,
110            failure_count: 0,
111            variant: None,
112        }
113    }
114
115    /// variant を設定
116    pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
117        self.variant = Some(variant.into());
118        self
119    }
120
121    /// 成功を記録
122    pub fn record_success(&mut self) {
123        self.success_count += 1;
124    }
125
126    /// 失敗を記録
127    pub fn record_failure(&mut self) {
128        self.failure_count += 1;
129    }
130
131    /// 完了をマーク
132    pub fn mark_completed(&mut self) {
133        self.completed_at = Some(
134            SystemTime::now()
135                .duration_since(UNIX_EPOCH)
136                .map(|d| d.as_secs())
137                .unwrap_or(0),
138        );
139    }
140
141    /// 成功率を計算
142    pub fn success_rate(&self) -> f64 {
143        let total = self.success_count + self.failure_count;
144        if total == 0 {
145            0.0
146        } else {
147            self.success_count as f64 / total as f64
148        }
149    }
150
151    /// 完了した実行回数
152    pub fn completed_runs(&self) -> usize {
153        self.success_count + self.failure_count
154    }
155}
156
157/// セッショングループ
158///
159/// 複数の Eval セッションをまとめて管理する単位。
160/// Bootstrap / Release / Validate の各フェーズで使用。
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct SessionGroup {
163    /// グループ ID
164    pub id: SessionGroupId,
165    /// フェーズ
166    pub phase: LearningPhase,
167    /// 含まれるセッション ID
168    pub session_ids: Vec<SessionId>,
169    /// メタデータ
170    pub metadata: SessionGroupMetadata,
171}
172
173impl SessionGroup {
174    /// 新しいセッショングループを作成
175    pub fn new(phase: LearningPhase, scenario: impl Into<String>, target_runs: usize) -> Self {
176        let scenario = scenario.into();
177        Self {
178            id: SessionGroupId::new(),
179            phase,
180            session_ids: Vec::new(),
181            metadata: SessionGroupMetadata::new(&scenario, target_runs),
182        }
183    }
184
185    /// variant を設定
186    pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
187        self.metadata = self.metadata.with_variant(variant);
188        self
189    }
190
191    /// セッションを追加
192    pub fn add_session(&mut self, session_id: SessionId, success: bool) {
193        self.session_ids.push(session_id);
194        if success {
195            self.metadata.record_success();
196        } else {
197            self.metadata.record_failure();
198        }
199    }
200
201    /// 完了をマーク
202    pub fn mark_completed(&mut self) {
203        self.metadata.mark_completed();
204    }
205
206    /// 成功率を取得
207    pub fn success_rate(&self) -> f64 {
208        self.metadata.success_rate()
209    }
210
211    /// 目標回数に達したか
212    pub fn is_target_reached(&self) -> bool {
213        self.metadata.completed_runs() >= self.metadata.target_runs
214    }
215}
216
217#[cfg(test)]
218mod tests {
219    use super::*;
220
221    #[test]
222    fn test_session_group_id_generation() {
223        let id1 = SessionGroupId::new();
224        let id2 = SessionGroupId::new();
225        // 同じミリ秒内でも異なる可能性があるが、フォーマットは一貫
226        assert!(id1.0.starts_with('g'));
227        assert!(id2.0.starts_with('g'));
228    }
229
230    #[test]
231    fn test_learning_phase_display() {
232        assert_eq!(LearningPhase::Bootstrap.to_string(), "bootstrap");
233        assert_eq!(LearningPhase::Release.to_string(), "release");
234        assert_eq!(LearningPhase::Validate.to_string(), "validate");
235    }
236
237    #[test]
238    fn test_learning_phase_parse() {
239        assert_eq!(
240            "bootstrap".parse::<LearningPhase>().unwrap(),
241            LearningPhase::Bootstrap
242        );
243        assert_eq!(
244            "RELEASE".parse::<LearningPhase>().unwrap(),
245            LearningPhase::Release
246        );
247        assert!("unknown".parse::<LearningPhase>().is_err());
248    }
249
250    #[test]
251    fn test_session_group_success_rate() {
252        let mut group = SessionGroup::new(LearningPhase::Bootstrap, "test", 10);
253
254        // 初期状態
255        assert_eq!(group.success_rate(), 0.0);
256
257        // 3 成功、2 失敗
258        group.add_session(SessionId("1".to_string()), true);
259        group.add_session(SessionId("2".to_string()), true);
260        group.add_session(SessionId("3".to_string()), true);
261        group.add_session(SessionId("4".to_string()), false);
262        group.add_session(SessionId("5".to_string()), false);
263
264        assert_eq!(group.success_rate(), 0.6);
265        assert!(!group.is_target_reached());
266
267        // 残り 5 回追加
268        for i in 6..=10 {
269            group.add_session(SessionId(i.to_string()), true);
270        }
271        assert!(group.is_target_reached());
272    }
273}