1use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::util::{epoch_millis, epoch_millis_for_ordering};
32
/// Identifier for an [`Episode`], made of a millisecond timestamp and a counter.
///
/// [`EpisodeId::new`] fills both parts automatically, while
/// [`EpisodeId::from_parts`] rebuilds an id from known values. The `Display`
/// impl renders it as `<timestamp_ms>-<counter as 8 lowercase hex digits>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EpisodeId {
    /// Timestamp component in milliseconds (`new` takes it from
    /// `epoch_millis_for_ordering()`).
    pub timestamp_ms: u64,
    /// Counter component (`new` draws it from a static atomic counter).
    pub counter: u32,
}
49
50impl EpisodeId {
51 pub fn new() -> Self {
52 use std::sync::atomic::{AtomicU32, Ordering};
53 static COUNTER: AtomicU32 = AtomicU32::new(0);
54
55 Self {
56 timestamp_ms: epoch_millis_for_ordering(),
57 counter: COUNTER.fetch_add(1, Ordering::Relaxed),
58 }
59 }
60
61 pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
63 Self {
64 timestamp_ms,
65 counter,
66 }
67 }
68}
69
70impl Default for EpisodeId {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl std::fmt::Display for EpisodeId {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
90#[serde(tag = "type")]
91#[derive(Default)]
92pub enum Outcome {
93 Success {
95 score: f64,
97 },
98 Failure {
100 reason: String,
102 },
103 Timeout {
105 partial_score: Option<f64>,
107 },
108 #[default]
110 Unknown,
111}
112
113impl Outcome {
114 pub fn success(score: f64) -> Self {
115 Self::Success { score }
116 }
117
118 pub fn success_binary() -> Self {
119 Self::Success { score: 1.0 }
120 }
121
122 pub fn failure(reason: impl Into<String>) -> Self {
123 Self::Failure {
124 reason: reason.into(),
125 }
126 }
127
128 pub fn timeout(partial_score: Option<f64>) -> Self {
129 Self::Timeout { partial_score }
130 }
131
132 pub fn is_success(&self) -> bool {
133 matches!(self, Self::Success { .. })
134 }
135
136 pub fn is_failure(&self) -> bool {
137 matches!(self, Self::Failure { .. } | Self::Timeout { .. })
138 }
139
140 pub fn score(&self) -> f64 {
142 match self {
143 Self::Success { score } => *score,
144 Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
145 _ => 0.0,
146 }
147 }
148}
149
150
/// Collection of [`Record`]s attached to an episode.
///
/// Records are kept in the order they were pushed; typed views over them are
/// available through [`EpisodeContext::iter`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeContext {
    /// All records, in insertion order.
    pub records: Vec<Record>,
}
164
165impl EpisodeContext {
166 pub fn new() -> Self {
167 Self::default()
168 }
169
170 pub fn push(&mut self, record: impl Into<Record>) {
172 self.records.push(record.into());
173 }
174
175 pub fn with_record(mut self, record: impl Into<Record>) -> Self {
177 self.records.push(record.into());
178 self
179 }
180
181 pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
188 self.records.iter().filter_map(T::from_record)
189 }
190
191 pub fn len(&self) -> usize {
193 self.records.len()
194 }
195
196 pub fn is_empty(&self) -> bool {
198 self.records.is_empty()
199 }
200}
201
/// Descriptive information attached to an episode.
///
/// Note that the derived `Default` leaves `created_at` at `0`; use
/// [`EpisodeMetadata::new`] to get it stamped from `epoch_millis()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeMetadata {
    /// Optional strategy name.
    pub strategy_name: Option<String>,
    /// Optional scenario name.
    pub scenario_name: Option<String>,
    /// Creation timestamp (`new` sets it from `epoch_millis()`).
    pub created_at: u64,
    /// Optional start timestamp; `duration_ms` treats it as milliseconds.
    pub started_at: Option<u64>,
    /// Optional end timestamp; `duration_ms` treats it as milliseconds.
    pub ended_at: Option<u64>,
    /// Arbitrary key/value tags.
    pub tags: HashMap<String, String>,
}
222
223impl EpisodeMetadata {
224 pub fn new() -> Self {
225 Self {
226 created_at: epoch_millis(),
227 ..Default::default()
228 }
229 }
230
231 pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
232 self.strategy_name = Some(name.into());
233 self
234 }
235
236 pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
237 self.scenario_name = Some(name.into());
238 self
239 }
240
241 pub fn with_duration(mut self, start: u64, end: u64) -> Self {
242 self.started_at = Some(start);
243 self.ended_at = Some(end);
244 self
245 }
246
247 pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
248 self.tags.insert(key.into(), value.into());
249 self
250 }
251
252 pub fn duration_ms(&self) -> Option<u64> {
254 match (self.started_at, self.ended_at) {
255 (Some(start), Some(end)) => Some(end.saturating_sub(start)),
256 _ => None,
257 }
258 }
259}
260
/// A recorded episode: its records, its outcome and descriptive metadata.
///
/// Create one with [`Episode::new`] or step by step via [`Episode::builder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
    /// Identifier of this episode.
    pub id: EpisodeId,
    /// Learn-model name as a free-form string (`EpisodeBuilder::build` falls
    /// back to `"unknown"` when none is given).
    pub learn_model: String,
    /// Records collected for the episode.
    pub context: EpisodeContext,
    /// Result of the episode.
    pub outcome: Outcome,
    /// Names, timestamps and tags describing the episode.
    pub metadata: EpisodeMetadata,
}
282
283impl Episode {
284 pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
286 Self {
287 id: EpisodeId::new(),
288 learn_model: learn_model.into(),
289 context: EpisodeContext::default(),
290 outcome,
291 metadata: EpisodeMetadata::new(),
292 }
293 }
294
295 pub fn builder() -> EpisodeBuilder {
297 EpisodeBuilder::default()
298 }
299
300 pub fn is_success(&self) -> bool {
302 self.outcome.is_success()
303 }
304
305 pub fn worker_id(&self) -> Option<usize> {
307 self.context
308 .iter::<ActionRecord>()
309 .next()
310 .map(|a| a.worker_id)
311 }
312}
313
/// Step-by-step builder for [`Episode`].
///
/// Anything left unset is filled in by `build`: a missing id becomes
/// `EpisodeId::default()`, a missing learn model becomes `"unknown"`, and a
/// missing outcome becomes `Outcome::Unknown`.
#[derive(Debug, Default)]
pub struct EpisodeBuilder {
    // Explicit id, if the caller supplied one.
    id: Option<EpisodeId>,
    // Learn-model name, if the caller supplied one.
    learn_model: Option<String>,
    // Records accumulated so far (starts empty).
    context: EpisodeContext,
    // Outcome, if the caller supplied one.
    outcome: Option<Outcome>,
    // Metadata under construction (starts as `EpisodeMetadata::default()`).
    metadata: EpisodeMetadata,
}
327
328impl EpisodeBuilder {
329 pub fn id(mut self, id: EpisodeId) -> Self {
331 self.id = Some(id);
332 self
333 }
334
335 pub fn learn_model(mut self, name: impl Into<String>) -> Self {
337 self.learn_model = Some(name.into());
338 self
339 }
340
341 pub fn record(mut self, record: impl Into<Record>) -> Self {
343 self.context.push(record);
344 self
345 }
346
347 pub fn context(mut self, context: EpisodeContext) -> Self {
349 self.context = context;
350 self
351 }
352
353 pub fn outcome(mut self, outcome: Outcome) -> Self {
354 self.outcome = Some(outcome);
355 self
356 }
357
358 pub fn scenario(mut self, name: impl Into<String>) -> Self {
359 self.metadata.scenario_name = Some(name.into());
360 self
361 }
362
363 pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
364 self.metadata.tags.insert(key.into(), value.into());
365 self
366 }
367
368 pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
370 self.metadata = metadata;
371 self
372 }
373
374 pub fn build(self) -> Episode {
375 Episode {
376 id: self.id.unwrap_or_default(),
377 learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
378 context: self.context,
379 outcome: self.outcome.unwrap_or(Outcome::Unknown),
380 metadata: self.metadata,
381 }
382 }
383}
384
// Unit tests for episode construction, outcome helpers, typed record
// iteration and serde round-tripping.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    // Builds an `ActionEvent` with a fixed 50 ms duration and fixed context
    // ("UCB1" selection logic, "PrevAction" as previous action); `success`
    // picks between a success result and a "test error" failure.
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(50))
            .context(
                ActionContext::new()
                    .with_selection_logic("UCB1")
                    .with_previous_action("PrevAction"),
            )
            .build()
    }

    // Converting an `ActionEvent` into an `ActionRecord` carries over tick,
    // worker, action name, result, duration and context fields.
    #[test]
    fn test_action_record_from_action_event() {
        let event = make_action_event(10, 1, "CheckStatus", true);
        let record = ActionRecord::from(&event);

        assert_eq!(record.tick, 10);
        assert_eq!(record.worker_id, 1);
        assert_eq!(record.action, "CheckStatus");
        assert!(record.success);
        assert_eq!(record.duration_ms, 50);
        assert_eq!(record.selection_logic, Some("UCB1".to_string()));
        assert_eq!(record.previous_action, Some("PrevAction".to_string()));
    }

    // Records added through the builder are kept in insertion order and the
    // scenario name lands in the metadata.
    #[test]
    fn test_episode_builder_with_actions() {
        let event1 = make_action_event(1, 0, "Grep", true);
        let event2 = make_action_event(2, 0, "Read", true);
        let event3 = make_action_event(3, 0, "done", true);

        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::from(&event1))
            .record(ActionRecord::from(&event2))
            .record(ActionRecord::from(&event3))
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);

        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();
        assert_eq!(actions, vec!["Grep", "Read", "done"]);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
    }

    // An LLM call record can be stored and read back through the typed iterator.
    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record.clone())
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let llm_call = episode.context.iter::<LlmCallRecord>().next().unwrap();
        assert_eq!(llm_call.prompt, "What action?");
        assert_eq!(llm_call.response, "CheckStatus");
    }

    // `is_success` / `is_failure` / `score` for every variant; `Unknown` is
    // neither a success nor a failure.
    #[test]
    fn test_outcome_variants() {
        assert!(Outcome::success(1.0).is_success());
        assert!(!Outcome::success(1.0).is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        assert!(!Outcome::failure("test").is_success());
        assert!(Outcome::failure("test").is_failure());
        assert_eq!(Outcome::failure("test").score(), 0.0);

        assert!(!Outcome::timeout(Some(0.5)).is_success());
        assert!(Outcome::timeout(Some(0.5)).is_failure());
        assert_eq!(Outcome::timeout(Some(0.5)).score(), 0.5);

        assert!(!Outcome::Unknown.is_success());
        assert!(!Outcome::Unknown.is_failure());
    }

    // The typed iterator supports counting, filtering and mapping while
    // preserving insertion order.
    #[test]
    fn test_episode_context_iter() {
        let mut context = EpisodeContext::new();
        context.push(ActionRecord::new(1, 0, "A").success(true));
        context.push(ActionRecord::new(2, 0, "B").success(true));
        context.push(ActionRecord::new(3, 0, "C").success(false));

        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        let success_count = context.iter::<ActionRecord>().filter(|a| a.success).count();
        assert_eq!(success_count, 2);

        let actions: Vec<&str> = context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();
        assert_eq!(actions, vec!["A", "B", "C"]);
    }

    // An episode survives a JSON round trip with its records and outcome intact.
    #[test]
    fn test_episode_serialization() {
        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        let json = serde_json::to_string(&episode).unwrap();
        assert!(json.contains("\"learn_model\":\"worker_task\""));
        assert!(json.contains("\"action\":\"CheckStatus\""));

        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    // `LlmCallRecord` builder setters populate the fields; a record with an
    // error is not a success.
    #[test]
    fn test_llm_call_record_builder() {
        let record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(record.call_type, "decide");
        assert_eq!(record.model, "qwen2.5");
        assert_eq!(record.prompt, "prompt");
        assert_eq!(record.response, "response");
        assert_eq!(record.lora, Some("adapter1".to_string()));
        assert_eq!(record.worker_id, Some(5));
        assert!(record.is_success());

        let error_record = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!error_record.is_success());
    }

    // An explicit id and caller-provided metadata are passed through by the builder.
    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);
        let mut custom_metadata = EpisodeMetadata::new();
        custom_metadata.scenario_name = Some("custom-scenario".to_string());
        custom_metadata
            .tags
            .insert("key".to_string(), "value".to_string());

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name,
            Some("custom-scenario".to_string())
        );
        assert_eq!(episode.metadata.tags.get("key"), Some(&"value".to_string()));
    }
}