1use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::types::{GroupId, TaskId};
32use crate::util::{epoch_millis, epoch_millis_for_ordering};
33
34pub trait EpisodeTrait: Send + Sync {
62 fn id(&self) -> &EpisodeId;
64
65 fn learn_model_name(&self) -> &str;
67
68 fn task_id(&self) -> Option<TaskId>;
70
71 fn group_id(&self) -> Option<GroupId>;
73
74 fn outcome(&self) -> &Outcome;
76
77 fn is_success(&self) -> bool {
79 self.outcome().is_success()
80 }
81
82 fn scenario_name(&self) -> Option<&str>;
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
96pub struct EpisodeId {
97 pub timestamp_ms: u64,
99 pub counter: u32,
101}
102
103impl EpisodeId {
104 pub fn new() -> Self {
105 use std::sync::atomic::{AtomicU32, Ordering};
106 static COUNTER: AtomicU32 = AtomicU32::new(0);
107
108 Self {
109 timestamp_ms: epoch_millis_for_ordering(),
110 counter: COUNTER.fetch_add(1, Ordering::Relaxed),
111 }
112 }
113
114 pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
116 Self {
117 timestamp_ms,
118 counter,
119 }
120 }
121}
122
123impl Default for EpisodeId {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl std::fmt::Display for EpisodeId {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
153#[serde(tag = "type")]
154#[derive(Default)]
155pub enum Outcome {
156 Success {
161 score: f64,
163 },
164 Failure {
166 reason: String,
168 },
169 Timeout {
171 partial_score: Option<f64>,
173 },
174
175 #[default]
180 Unknown,
181}
182
183impl Outcome {
184 pub fn success(score: f64) -> Self {
189 Self::Success { score }
190 }
191
192 pub fn success_binary() -> Self {
193 Self::Success { score: 1.0 }
194 }
195
196 pub fn failure(reason: impl Into<String>) -> Self {
197 Self::Failure {
198 reason: reason.into(),
199 }
200 }
201
202 pub fn timeout(partial_score: Option<f64>) -> Self {
203 Self::Timeout { partial_score }
204 }
205
206 pub fn is_success(&self) -> bool {
211 matches!(self, Self::Success { .. })
212 }
213
214 pub fn is_failure(&self) -> bool {
215 matches!(self, Self::Failure { .. } | Self::Timeout { .. })
216 }
217
218 pub fn is_unknown(&self) -> bool {
220 matches!(self, Self::Unknown)
221 }
222
223 pub fn score(&self) -> f64 {
229 match self {
230 Self::Success { score } => *score,
231 Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
232 _ => 0.0,
233 }
234 }
235}
236
237#[derive(Debug, Clone, Default, Serialize, Deserialize)]
246pub struct EpisodeContext {
247 pub records: Vec<Record>,
249}
250
251impl EpisodeContext {
252 pub fn new() -> Self {
253 Self::default()
254 }
255
256 pub fn push(&mut self, record: impl Into<Record>) {
258 self.records.push(record.into());
259 }
260
261 pub fn with_record(mut self, record: impl Into<Record>) -> Self {
263 self.records.push(record.into());
264 self
265 }
266
267 pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
274 self.records.iter().filter_map(T::from_record)
275 }
276
277 pub fn first<T: FromRecord>(&self) -> Option<&T> {
284 self.iter::<T>().next()
285 }
286
287 pub fn len(&self) -> usize {
289 self.records.len()
290 }
291
292 pub fn is_empty(&self) -> bool {
294 self.records.is_empty()
295 }
296}
297
298#[derive(Debug, Clone, Default, Serialize, Deserialize)]
304pub struct EpisodeMetadata {
305 pub strategy_name: Option<String>,
307 pub scenario_name: Option<String>,
309 pub created_at: u64,
311 pub started_at: Option<u64>,
313 pub ended_at: Option<u64>,
315 pub tags: HashMap<String, String>,
317}
318
319impl EpisodeMetadata {
320 pub fn new() -> Self {
321 Self {
322 created_at: epoch_millis(),
323 ..Default::default()
324 }
325 }
326
327 pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
328 self.strategy_name = Some(name.into());
329 self
330 }
331
332 pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
333 self.scenario_name = Some(name.into());
334 self
335 }
336
337 pub fn with_duration(mut self, start: u64, end: u64) -> Self {
338 self.started_at = Some(start);
339 self.ended_at = Some(end);
340 self
341 }
342
343 pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
344 self.tags.insert(key.into(), value.into());
345 self
346 }
347
348 pub fn duration_ms(&self) -> Option<u64> {
350 match (self.started_at, self.ended_at) {
351 (Some(start), Some(end)) => Some(end.saturating_sub(start)),
352 _ => None,
353 }
354 }
355}
356
357#[derive(Debug, Clone, Serialize, Deserialize)]
371pub struct Episode {
372 pub id: EpisodeId,
374 pub learn_model: String,
376 #[serde(default, skip_serializing_if = "Option::is_none")]
378 pub task_id: Option<TaskId>,
379 #[serde(default, skip_serializing_if = "Option::is_none")]
381 pub group_id: Option<GroupId>,
382 pub context: EpisodeContext,
384 pub outcome: Outcome,
386 pub metadata: EpisodeMetadata,
388}
389
390impl Episode {
391 pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
393 Self {
394 id: EpisodeId::new(),
395 learn_model: learn_model.into(),
396 task_id: None,
397 group_id: None,
398 context: EpisodeContext::default(),
399 outcome,
400 metadata: EpisodeMetadata::new(),
401 }
402 }
403
404 pub fn builder() -> EpisodeBuilder {
406 EpisodeBuilder::default()
407 }
408
409 pub fn is_success(&self) -> bool {
411 self.outcome.is_success()
412 }
413
414 pub fn worker_id(&self) -> Option<usize> {
416 self.context
417 .iter::<ActionRecord>()
418 .next()
419 .map(|a| a.worker_id)
420 }
421
422 pub fn get_task_id(&self) -> Option<TaskId> {
424 self.task_id.or_else(|| {
425 self.context
426 .iter::<ActionRecord>()
427 .next()
428 .map(|a| a.task_id)
429 })
430 }
431
432 pub fn get_group_id(&self) -> Option<GroupId> {
434 self.group_id.or_else(|| {
435 self.context
436 .iter::<ActionRecord>()
437 .next()
438 .and_then(|a| a.group_id)
439 })
440 }
441}
442
443impl EpisodeTrait for Episode {
444 fn id(&self) -> &EpisodeId {
445 &self.id
446 }
447
448 fn learn_model_name(&self) -> &str {
449 &self.learn_model
450 }
451
452 fn task_id(&self) -> Option<TaskId> {
453 self.get_task_id()
454 }
455
456 fn group_id(&self) -> Option<GroupId> {
457 self.get_group_id()
458 }
459
460 fn outcome(&self) -> &Outcome {
461 &self.outcome
462 }
463
464 fn scenario_name(&self) -> Option<&str> {
465 self.metadata.scenario_name.as_deref()
466 }
467}
468
469#[derive(Debug, Default)]
475pub struct EpisodeBuilder {
476 id: Option<EpisodeId>,
477 learn_model: Option<String>,
478 task_id: Option<TaskId>,
479 group_id: Option<GroupId>,
480 context: EpisodeContext,
481 outcome: Option<Outcome>,
482 metadata: EpisodeMetadata,
483}
484
485impl EpisodeBuilder {
486 pub fn id(mut self, id: EpisodeId) -> Self {
488 self.id = Some(id);
489 self
490 }
491
492 pub fn learn_model(mut self, name: impl Into<String>) -> Self {
494 self.learn_model = Some(name.into());
495 self
496 }
497
498 pub fn task_id(mut self, task_id: TaskId) -> Self {
500 self.task_id = Some(task_id);
501 self
502 }
503
504 pub fn group_id(mut self, group_id: GroupId) -> Self {
506 self.group_id = Some(group_id);
507 self
508 }
509
510 pub fn record(mut self, record: impl Into<Record>) -> Self {
512 self.context.push(record);
513 self
514 }
515
516 pub fn context(mut self, context: EpisodeContext) -> Self {
518 self.context = context;
519 self
520 }
521
522 pub fn outcome(mut self, outcome: Outcome) -> Self {
523 self.outcome = Some(outcome);
524 self
525 }
526
527 pub fn scenario(mut self, name: impl Into<String>) -> Self {
528 self.metadata.scenario_name = Some(name.into());
529 self
530 }
531
532 pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
533 self.metadata.tags.insert(key.into(), value.into());
534 self
535 }
536
537 pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
539 self.metadata = metadata;
540 self
541 }
542
543 pub fn build(self) -> Episode {
544 Episode {
545 id: self.id.unwrap_or_default(),
546 learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
547 task_id: self.task_id,
548 group_id: self.group_id,
549 context: self.context,
550 outcome: self.outcome.unwrap_or(Outcome::Unknown),
551 metadata: self.metadata,
552 }
553 }
554}
555
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    /// Builds an action event with a 50 ms duration, selection logic `UCB1`
    /// and previous action `PrevAction`; `success` picks the result kind.
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };
        let context = ActionContext::new()
            .with_selection_logic("UCB1")
            .with_previous_action("PrevAction");

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(50))
            .context(context)
            .build()
    }

    /// Collects the action names of every `ActionRecord` in `context`.
    fn action_names(context: &EpisodeContext) -> Vec<&str> {
        context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect()
    }

    #[test]
    fn test_action_record_from_action_event() {
        let event = make_action_event(10, 1, "CheckStatus", true);
        let record = ActionRecord::from(&event);

        assert_eq!(record.tick, 10);
        assert_eq!(record.worker_id, 1);
        assert_eq!(record.action, "CheckStatus");
        assert!(record.success);
        assert_eq!(record.duration_ms, 50);
        assert_eq!(record.selection_logic.as_deref(), Some("UCB1"));
        assert_eq!(record.previous_action.as_deref(), Some("PrevAction"));
    }

    #[test]
    fn test_episode_builder_with_actions() {
        let expected_actions = ["Grep", "Read", "done"];

        // Ticks 1, 2, 3 are paired with the action names in order.
        let mut builder = Episode::builder().learn_model("worker_task");
        for (tick, name) in (1u64..).zip(expected_actions) {
            let event = make_action_event(tick, 0, name, true);
            builder = builder.record(ActionRecord::from(&event));
        }
        let episode = builder
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);
        assert_eq!(action_names(&episode.context), expected_actions);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("troubleshooting")
        );
    }

    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record)
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let llm_call = episode
            .context
            .first::<LlmCallRecord>()
            .expect("the episode was built with one LLM call record");
        assert_eq!(llm_call.prompt, "What action?");
        assert_eq!(llm_call.response, "CheckStatus");
    }

    #[test]
    fn test_outcome_variants() {
        let success = Outcome::success(1.0);
        assert!(success.is_success());
        assert!(!success.is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        let failure = Outcome::failure("test");
        assert!(!failure.is_success());
        assert!(failure.is_failure());
        assert_eq!(failure.score(), 0.0);

        let timeout = Outcome::timeout(Some(0.5));
        assert!(!timeout.is_success());
        assert!(timeout.is_failure());
        assert_eq!(timeout.score(), 0.5);

        let unknown = Outcome::Unknown;
        assert!(!unknown.is_success());
        assert!(!unknown.is_failure());
    }

    #[test]
    fn test_episode_context_iter() {
        let context = EpisodeContext::new()
            .with_record(ActionRecord::new(1, 0, "A").success(true))
            .with_record(ActionRecord::new(2, 0, "B").success(true))
            .with_record(ActionRecord::new(3, 0, "C").success(false));

        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        let success_count = context
            .iter::<ActionRecord>()
            .filter(|record| record.success)
            .count();
        assert_eq!(success_count, 2);

        assert_eq!(action_names(&context), ["A", "B", "C"]);
    }

    #[test]
    fn test_episode_serialization() {
        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        let json = serde_json::to_string(&episode).unwrap();
        for expected_fragment in [
            "\"learn_model\":\"worker_task\"",
            "\"action\":\"CheckStatus\"",
        ] {
            assert!(json.contains(expected_fragment));
        }

        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    #[test]
    fn test_llm_call_record_builder() {
        let record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(record.call_type, "decide");
        assert_eq!(record.model, "qwen2.5");
        assert_eq!(record.prompt, "prompt");
        assert_eq!(record.response, "response");
        assert_eq!(record.lora.as_deref(), Some("adapter1"));
        assert_eq!(record.worker_id, Some(5));
        assert!(record.is_success());

        let error_record = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!error_record.is_success());
    }

    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);
        let custom_metadata = EpisodeMetadata::new()
            .with_scenario("custom-scenario")
            .with_tag("key", "value");

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("custom-scenario")
        );
        assert_eq!(
            episode.metadata.tags.get("key").map(String::as_str),
            Some("value")
        );
    }
}