1use std::time::{SystemTime, UNIX_EPOCH};
7
8use serde::{Deserialize, Serialize};
9
10use super::snapshot::SessionId;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct SessionGroupId(pub String);
15
16impl SessionGroupId {
17 pub fn new() -> Self {
19 let timestamp = SystemTime::now()
20 .duration_since(UNIX_EPOCH)
21 .map(|d| d.as_millis())
22 .unwrap_or(0);
23 Self(format!("g{}", timestamp))
24 }
25
26 pub fn from_raw(s: impl Into<String>) -> Self {
28 Self(s.into())
29 }
30}
31
32impl Default for SessionGroupId {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl std::fmt::Display for SessionGroupId {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "{}", self.0)
41 }
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47pub enum LearningPhase {
48 Bootstrap,
50 Release,
52 Validate,
54}
55
56impl std::fmt::Display for LearningPhase {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 match self {
59 Self::Bootstrap => write!(f, "bootstrap"),
60 Self::Release => write!(f, "release"),
61 Self::Validate => write!(f, "validate"),
62 }
63 }
64}
65
66impl std::str::FromStr for LearningPhase {
67 type Err = String;
68
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 match s.to_lowercase().as_str() {
71 "bootstrap" => Ok(Self::Bootstrap),
72 "release" => Ok(Self::Release),
73 "validate" => Ok(Self::Validate),
74 _ => Err(format!("Unknown phase: {}", s)),
75 }
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct SessionGroupMetadata {
82 pub scenario: String,
84 pub created_at: u64,
86 pub completed_at: Option<u64>,
88 pub target_runs: usize,
90 pub success_count: usize,
92 pub failure_count: usize,
94 pub variant: Option<String>,
96}
97
98impl SessionGroupMetadata {
99 pub fn new(scenario: impl Into<String>, target_runs: usize) -> Self {
101 Self {
102 scenario: scenario.into(),
103 created_at: SystemTime::now()
104 .duration_since(UNIX_EPOCH)
105 .map(|d| d.as_secs())
106 .unwrap_or(0),
107 completed_at: None,
108 target_runs,
109 success_count: 0,
110 failure_count: 0,
111 variant: None,
112 }
113 }
114
115 pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
117 self.variant = Some(variant.into());
118 self
119 }
120
121 pub fn record_success(&mut self) {
123 self.success_count += 1;
124 }
125
126 pub fn record_failure(&mut self) {
128 self.failure_count += 1;
129 }
130
131 pub fn mark_completed(&mut self) {
133 self.completed_at = Some(
134 SystemTime::now()
135 .duration_since(UNIX_EPOCH)
136 .map(|d| d.as_secs())
137 .unwrap_or(0),
138 );
139 }
140
141 pub fn success_rate(&self) -> f64 {
143 let total = self.success_count + self.failure_count;
144 if total == 0 {
145 0.0
146 } else {
147 self.success_count as f64 / total as f64
148 }
149 }
150
151 pub fn completed_runs(&self) -> usize {
153 self.success_count + self.failure_count
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct SessionGroup {
163 pub id: SessionGroupId,
165 pub phase: LearningPhase,
167 pub session_ids: Vec<SessionId>,
169 pub metadata: SessionGroupMetadata,
171}
172
173impl SessionGroup {
174 pub fn new(phase: LearningPhase, scenario: impl Into<String>, target_runs: usize) -> Self {
176 let scenario = scenario.into();
177 Self {
178 id: SessionGroupId::new(),
179 phase,
180 session_ids: Vec::new(),
181 metadata: SessionGroupMetadata::new(&scenario, target_runs),
182 }
183 }
184
185 pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
187 self.metadata = self.metadata.with_variant(variant);
188 self
189 }
190
191 pub fn add_session(&mut self, session_id: SessionId, success: bool) {
193 self.session_ids.push(session_id);
194 if success {
195 self.metadata.record_success();
196 } else {
197 self.metadata.record_failure();
198 }
199 }
200
201 pub fn mark_completed(&mut self) {
203 self.metadata.mark_completed();
204 }
205
206 pub fn success_rate(&self) -> f64 {
208 self.metadata.success_rate()
209 }
210
211 pub fn is_target_reached(&self) -> bool {
213 self.metadata.completed_runs() >= self.metadata.target_runs
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_group_id_generation() {
        // Every generated id carries the `g` prefix.
        for id in [SessionGroupId::new(), SessionGroupId::new()] {
            assert!(id.0.starts_with('g'));
        }
    }

    #[test]
    fn test_learning_phase_display() {
        let cases = [
            (LearningPhase::Bootstrap, "bootstrap"),
            (LearningPhase::Release, "release"),
            (LearningPhase::Validate, "validate"),
        ];
        for (phase, text) in cases {
            assert_eq!(phase.to_string(), text);
        }
    }

    #[test]
    fn test_learning_phase_parse() {
        assert_eq!(
            "bootstrap".parse::<LearningPhase>(),
            Ok(LearningPhase::Bootstrap)
        );
        // Parsing is case-insensitive.
        assert_eq!(
            "RELEASE".parse::<LearningPhase>(),
            Ok(LearningPhase::Release)
        );
        assert!("unknown".parse::<LearningPhase>().is_err());
    }

    #[test]
    fn test_session_group_success_rate() {
        let mut group = SessionGroup::new(LearningPhase::Bootstrap, "test", 10);

        // No runs recorded yet.
        assert_eq!(group.success_rate(), 0.0);

        // Three successes then two failures -> 3 / 5.
        let outcomes = [true, true, true, false, false];
        for (idx, succeeded) in outcomes.iter().enumerate() {
            group.add_session(SessionId((idx + 1).to_string()), *succeeded);
        }
        assert_eq!(group.success_rate(), 0.6);
        assert!(!group.is_target_reached());

        // Five more runs bring the total to the target of 10.
        (6..=10).for_each(|n| group.add_session(SessionId(n.to_string()), true));
        assert!(group.is_target_reached());
    }
}