1use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::types::{GroupId, TaskId};
32use crate::util::{epoch_millis, epoch_millis_for_ordering};
33
34pub trait EpisodeTrait: Send + Sync {
62 fn id(&self) -> &EpisodeId;
64
65 fn learn_model_name(&self) -> &str;
67
68 fn task_id(&self) -> Option<TaskId>;
70
71 fn group_id(&self) -> Option<GroupId>;
73
74 fn outcome(&self) -> &Outcome;
76
77 fn is_success(&self) -> bool {
79 self.outcome().is_success()
80 }
81
82 fn scenario_name(&self) -> Option<&str>;
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
96pub struct EpisodeId {
97 pub timestamp_ms: u64,
99 pub counter: u32,
101}
102
103impl EpisodeId {
104 pub fn new() -> Self {
105 use std::sync::atomic::{AtomicU32, Ordering};
106 static COUNTER: AtomicU32 = AtomicU32::new(0);
107
108 Self {
109 timestamp_ms: epoch_millis_for_ordering(),
110 counter: COUNTER.fetch_add(1, Ordering::Relaxed),
111 }
112 }
113
114 pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
116 Self {
117 timestamp_ms,
118 counter,
119 }
120 }
121}
122
123impl Default for EpisodeId {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl std::fmt::Display for EpisodeId {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
149#[serde(tag = "type")]
150#[derive(Default)]
151pub enum Outcome {
152 Success {
157 score: f64,
159 },
160 Failure {
162 reason: String,
164 },
165 Timeout {
167 partial_score: Option<f64>,
169 },
170
171 Aggregated {
178 success_rate: f64,
180 total_tasks: u32,
182 successful_tasks: u32,
184 total_ticks: u32,
186 },
187
188 #[default]
193 Unknown,
194}
195
196impl Outcome {
197 pub fn success(score: f64) -> Self {
202 Self::Success { score }
203 }
204
205 pub fn success_binary() -> Self {
206 Self::Success { score: 1.0 }
207 }
208
209 pub fn failure(reason: impl Into<String>) -> Self {
210 Self::Failure {
211 reason: reason.into(),
212 }
213 }
214
215 pub fn timeout(partial_score: Option<f64>) -> Self {
216 Self::Timeout { partial_score }
217 }
218
219 pub fn aggregated(
225 success_rate: f64,
226 total_tasks: u32,
227 successful_tasks: u32,
228 total_ticks: u32,
229 ) -> Self {
230 Self::Aggregated {
231 success_rate,
232 total_tasks,
233 successful_tasks,
234 total_ticks,
235 }
236 }
237
238 pub fn from_eval_result(
259 total_tasks: u32,
260 successful_tasks: u32,
261 success_rate: f64,
262 total_ticks: u32,
263 ) -> Self {
264 Self::Aggregated {
265 success_rate,
266 total_tasks,
267 successful_tasks,
268 total_ticks,
269 }
270 }
271
272 pub fn is_success(&self) -> bool {
277 matches!(self, Self::Success { .. })
278 }
279
280 pub fn is_failure(&self) -> bool {
281 matches!(self, Self::Failure { .. } | Self::Timeout { .. })
282 }
283
284 pub fn is_aggregated(&self) -> bool {
286 matches!(self, Self::Aggregated { .. })
287 }
288
289 pub fn is_aggregated_success(&self, threshold: f64) -> bool {
291 match self {
292 Self::Aggregated { success_rate, .. } => *success_rate >= threshold,
293 _ => false,
294 }
295 }
296
297 pub fn is_aggregated_failure(&self, threshold: f64) -> bool {
299 match self {
300 Self::Aggregated { success_rate, .. } => *success_rate < threshold,
301 _ => false,
302 }
303 }
304
305 pub fn score(&self) -> f64 {
311 match self {
312 Self::Success { score } => *score,
313 Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
314 Self::Aggregated { success_rate, .. } => *success_rate,
315 _ => 0.0,
316 }
317 }
318
319 pub fn success_rate(&self) -> Option<f64> {
324 match self {
325 Self::Aggregated { success_rate, .. } => Some(*success_rate),
326 _ => None,
327 }
328 }
329
330 pub fn ticks(&self) -> Option<u32> {
332 match self {
333 Self::Aggregated { total_ticks, .. } => Some(*total_ticks),
334 _ => None,
335 }
336 }
337}
338
339#[derive(Debug, Clone, Default, Serialize, Deserialize)]
348pub struct EpisodeContext {
349 pub records: Vec<Record>,
351}
352
353impl EpisodeContext {
354 pub fn new() -> Self {
355 Self::default()
356 }
357
358 pub fn push(&mut self, record: impl Into<Record>) {
360 self.records.push(record.into());
361 }
362
363 pub fn with_record(mut self, record: impl Into<Record>) -> Self {
365 self.records.push(record.into());
366 self
367 }
368
369 pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
376 self.records.iter().filter_map(T::from_record)
377 }
378
379 pub fn first<T: FromRecord>(&self) -> Option<&T> {
386 self.iter::<T>().next()
387 }
388
389 pub fn len(&self) -> usize {
391 self.records.len()
392 }
393
394 pub fn is_empty(&self) -> bool {
396 self.records.is_empty()
397 }
398}
399
400#[derive(Debug, Clone, Default, Serialize, Deserialize)]
406pub struct EpisodeMetadata {
407 pub strategy_name: Option<String>,
409 pub scenario_name: Option<String>,
411 pub created_at: u64,
413 pub started_at: Option<u64>,
415 pub ended_at: Option<u64>,
417 pub tags: HashMap<String, String>,
419}
420
421impl EpisodeMetadata {
422 pub fn new() -> Self {
423 Self {
424 created_at: epoch_millis(),
425 ..Default::default()
426 }
427 }
428
429 pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
430 self.strategy_name = Some(name.into());
431 self
432 }
433
434 pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
435 self.scenario_name = Some(name.into());
436 self
437 }
438
439 pub fn with_duration(mut self, start: u64, end: u64) -> Self {
440 self.started_at = Some(start);
441 self.ended_at = Some(end);
442 self
443 }
444
445 pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
446 self.tags.insert(key.into(), value.into());
447 self
448 }
449
450 pub fn duration_ms(&self) -> Option<u64> {
452 match (self.started_at, self.ended_at) {
453 (Some(start), Some(end)) => Some(end.saturating_sub(start)),
454 _ => None,
455 }
456 }
457}
458
459#[derive(Debug, Clone, Serialize, Deserialize)]
473pub struct Episode {
474 pub id: EpisodeId,
476 pub learn_model: String,
478 #[serde(default, skip_serializing_if = "Option::is_none")]
480 pub task_id: Option<TaskId>,
481 #[serde(default, skip_serializing_if = "Option::is_none")]
483 pub group_id: Option<GroupId>,
484 pub context: EpisodeContext,
486 pub outcome: Outcome,
488 pub metadata: EpisodeMetadata,
490}
491
492impl Episode {
493 pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
495 Self {
496 id: EpisodeId::new(),
497 learn_model: learn_model.into(),
498 task_id: None,
499 group_id: None,
500 context: EpisodeContext::default(),
501 outcome,
502 metadata: EpisodeMetadata::new(),
503 }
504 }
505
506 pub fn builder() -> EpisodeBuilder {
508 EpisodeBuilder::default()
509 }
510
511 pub fn is_success(&self) -> bool {
513 self.outcome.is_success()
514 }
515
516 pub fn worker_id(&self) -> Option<usize> {
518 self.context
519 .iter::<ActionRecord>()
520 .next()
521 .map(|a| a.worker_id)
522 }
523
524 pub fn get_task_id(&self) -> Option<TaskId> {
526 self.task_id.or_else(|| {
527 self.context
528 .iter::<ActionRecord>()
529 .next()
530 .map(|a| a.task_id)
531 })
532 }
533
534 pub fn get_group_id(&self) -> Option<GroupId> {
536 self.group_id.or_else(|| {
537 self.context
538 .iter::<ActionRecord>()
539 .next()
540 .and_then(|a| a.group_id)
541 })
542 }
543}
544
545impl EpisodeTrait for Episode {
546 fn id(&self) -> &EpisodeId {
547 &self.id
548 }
549
550 fn learn_model_name(&self) -> &str {
551 &self.learn_model
552 }
553
554 fn task_id(&self) -> Option<TaskId> {
555 self.get_task_id()
556 }
557
558 fn group_id(&self) -> Option<GroupId> {
559 self.get_group_id()
560 }
561
562 fn outcome(&self) -> &Outcome {
563 &self.outcome
564 }
565
566 fn scenario_name(&self) -> Option<&str> {
567 self.metadata.scenario_name.as_deref()
568 }
569}
570
571#[derive(Debug, Default)]
577pub struct EpisodeBuilder {
578 id: Option<EpisodeId>,
579 learn_model: Option<String>,
580 task_id: Option<TaskId>,
581 group_id: Option<GroupId>,
582 context: EpisodeContext,
583 outcome: Option<Outcome>,
584 metadata: EpisodeMetadata,
585}
586
587impl EpisodeBuilder {
588 pub fn id(mut self, id: EpisodeId) -> Self {
590 self.id = Some(id);
591 self
592 }
593
594 pub fn learn_model(mut self, name: impl Into<String>) -> Self {
596 self.learn_model = Some(name.into());
597 self
598 }
599
600 pub fn task_id(mut self, task_id: TaskId) -> Self {
602 self.task_id = Some(task_id);
603 self
604 }
605
606 pub fn group_id(mut self, group_id: GroupId) -> Self {
608 self.group_id = Some(group_id);
609 self
610 }
611
612 pub fn record(mut self, record: impl Into<Record>) -> Self {
614 self.context.push(record);
615 self
616 }
617
618 pub fn context(mut self, context: EpisodeContext) -> Self {
620 self.context = context;
621 self
622 }
623
624 pub fn outcome(mut self, outcome: Outcome) -> Self {
625 self.outcome = Some(outcome);
626 self
627 }
628
629 pub fn scenario(mut self, name: impl Into<String>) -> Self {
630 self.metadata.scenario_name = Some(name.into());
631 self
632 }
633
634 pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
635 self.metadata.tags.insert(key.into(), value.into());
636 self
637 }
638
639 pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
641 self.metadata = metadata;
642 self
643 }
644
645 pub fn build(self) -> Episode {
646 Episode {
647 id: self.id.unwrap_or_default(),
648 learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
649 task_id: self.task_id,
650 group_id: self.group_id,
651 context: self.context,
652 outcome: self.outcome.unwrap_or(Outcome::Unknown),
653 metadata: self.metadata,
654 }
655 }
656}
657
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    /// Builds an action event with a fixed 50 ms duration and a fixed context
    /// ("UCB1" selection logic, "PrevAction" as the previous action).
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };
        let context = ActionContext::new()
            .with_selection_logic("UCB1")
            .with_previous_action("PrevAction");

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(50))
            .context(context)
            .build()
    }

    /// Collects the action names of every `ActionRecord` in `context`, in order.
    fn action_names(context: &EpisodeContext) -> Vec<&str> {
        context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect()
    }

    #[test]
    fn test_action_record_from_action_event() {
        let event = make_action_event(10, 1, "CheckStatus", true);
        let record = ActionRecord::from(&event);

        assert_eq!(record.tick, 10);
        assert_eq!(record.worker_id, 1);
        assert_eq!(record.action, "CheckStatus");
        assert!(record.success);
        assert_eq!(record.duration_ms, 50);
        assert_eq!(record.selection_logic.as_deref(), Some("UCB1"));
        assert_eq!(record.previous_action.as_deref(), Some("PrevAction"));
    }

    #[test]
    fn test_episode_builder_with_actions() {
        let events = [
            make_action_event(1, 0, "Grep", true),
            make_action_event(2, 0, "Read", true),
            make_action_event(3, 0, "done", true),
        ];

        // Feed the events to the builder one by one, in order.
        let episode = events
            .iter()
            .fold(
                Episode::builder().learn_model("worker_task"),
                |builder, event| builder.record(ActionRecord::from(event)),
            )
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);
        assert_eq!(action_names(&episode.context), vec!["Grep", "Read", "done"]);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("troubleshooting")
        );
    }

    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record)
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let llm_call = episode
            .context
            .first::<LlmCallRecord>()
            .expect("the episode was built with one LLM call record");
        assert_eq!(llm_call.prompt, "What action?");
        assert_eq!(llm_call.response, "CheckStatus");
    }

    #[test]
    fn test_outcome_variants() {
        // Success
        let success = Outcome::success(1.0);
        assert!(success.is_success());
        assert!(!success.is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        // Failure
        let failure = Outcome::failure("test");
        assert!(!failure.is_success());
        assert!(failure.is_failure());
        assert_eq!(failure.score(), 0.0);

        // Timeout counts as a failure but keeps its partial score
        let timeout = Outcome::timeout(Some(0.5));
        assert!(!timeout.is_success());
        assert!(timeout.is_failure());
        assert_eq!(timeout.score(), 0.5);

        // Unknown is neither success nor failure
        let unknown = Outcome::Unknown;
        assert!(!unknown.is_success());
        assert!(!unknown.is_failure());
    }

    #[test]
    fn test_outcome_aggregated() {
        // Success rate above the 0.5 threshold
        let high = Outcome::aggregated(0.9, 10, 9, 100);
        assert!(high.is_aggregated());
        assert!(high.is_aggregated_success(0.5));
        assert!(!high.is_aggregated_failure(0.5));
        assert_eq!(high.score(), 0.9);
        assert_eq!(high.success_rate(), Some(0.9));
        assert_eq!(high.ticks(), Some(100));

        // Success rate below the 0.5 threshold
        let low = Outcome::aggregated(0.3, 10, 3, 50);
        assert!(low.is_aggregated());
        assert!(!low.is_aggregated_success(0.5));
        assert!(low.is_aggregated_failure(0.5));
        assert_eq!(low.score(), 0.3);

        // Alternate constructor with a different argument order
        let from_eval = Outcome::from_eval_result(5, 4, 0.8, 20);
        assert!(from_eval.is_aggregated());
        assert_eq!(from_eval.success_rate(), Some(0.8));
        assert_eq!(from_eval.ticks(), Some(20));
    }

    #[test]
    fn test_episode_context_iter() {
        let mut context = EpisodeContext::new();
        for (tick, action, succeeded) in [(1, "A", true), (2, "B", true), (3, "C", false)] {
            context.push(ActionRecord::new(tick, 0, action).success(succeeded));
        }

        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        let success_count = context
            .iter::<ActionRecord>()
            .filter(|record| record.success)
            .count();
        assert_eq!(success_count, 2);

        assert_eq!(action_names(&context), vec!["A", "B", "C"]);
    }

    #[test]
    fn test_episode_serialization() {
        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        // Serialize and spot-check the JSON text
        let json = serde_json::to_string(&episode).unwrap();
        for needle in ["\"learn_model\":\"worker_task\"", "\"action\":\"CheckStatus\""] {
            assert!(json.contains(needle));
        }

        // Round-trip back into an Episode
        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    #[test]
    fn test_llm_call_record_builder() {
        let record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(record.call_type, "decide");
        assert_eq!(record.model, "qwen2.5");
        assert_eq!(record.prompt, "prompt");
        assert_eq!(record.response, "response");
        assert_eq!(record.lora.as_deref(), Some("adapter1"));
        assert_eq!(record.worker_id, Some(5));
        assert!(record.is_success());

        // A record carrying an error is not a success
        let error_record = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!error_record.is_success());
    }

    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);

        let mut custom_metadata = EpisodeMetadata::new();
        custom_metadata.scenario_name = Some("custom-scenario".to_string());
        custom_metadata
            .tags
            .insert("key".to_string(), "value".to_string());

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("custom-scenario")
        );
        assert_eq!(
            episode.metadata.tags.get("key").map(String::as_str),
            Some("value")
        );
    }
}