1use crate::events::ActionEvent;
35
36use super::episode::{Episode, EpisodeContext, Outcome};
37use super::record::{ActionRecord, Record, RecordStream};
38use super::training::TrainingData;
39
/// Action names that this module treats as system bookkeeping rather than
/// worker decisions.
pub mod system_events {
    /// Action name `"tick_start"` (presumably emitted when a tick begins — confirm with the emitter).
    pub const TICK_START: &str = "tick_start";
    /// Action name `"tick_end"` (presumably emitted when a tick finishes — confirm with the emitter).
    pub const TICK_END: &str = "tick_end";
    /// Action name `"done"`.
    pub const DONE: &str = "done";

    /// All built-in system event names. `WorkerDecisionSequenceLearn::new`
    /// copies this list into its initial set of system events.
    pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[TICK_START, TICK_END, DONE];
}
56
57pub trait LearnModel: Send + Sync {
75 fn name(&self) -> &str;
77
78 fn objective(&self) -> &str;
80
81 fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;
87
88 fn evaluate(&self, context: &EpisodeContext) -> Outcome;
93
94 fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;
96
97 fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
99 episodes
100 .iter()
101 .filter_map(|ep| self.convert(ep).ok())
102 .collect()
103 }
104
105 fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
107 let records: Vec<Record> = actions.iter().map(Record::from).collect();
108 self.build_episodes(&records)
109 }
110}
111
/// Errors reported by [`LearnModel`] implementations.
///
/// Every variant carries a free-form message that is interpolated into the
/// `Display` output shown in its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum LearnError {
    /// Displayed as `Build error: <message>`.
    #[error("Build error: {0}")]
    Build(String),

    /// Displayed as `Conversion error: <message>`.
    #[error("Conversion error: {0}")]
    Conversion(String),

    /// Displayed as `Missing data: <message>`.
    #[error("Missing data: {0}")]
    MissingData(String),

    /// Displayed as `Invalid episode: <message>`. The `convert`
    /// implementations in this file return it for unsuccessful episodes and
    /// for episodes with too few actions.
    #[error("Invalid episode: {0}")]
    InvalidEpisode(String),

    /// Displayed as the bare message, without any prefix.
    #[error("{0}")]
    Other(String),
}
134
/// Learning model that turns each worker's action sequence, up to and
/// including a terminal action, into one episode.
///
/// Derives `Debug` and `Clone` so configured models can be logged and
/// duplicated like other public configuration types.
#[derive(Debug, Clone)]
pub struct WorkerTaskLearn {
    /// System prompt attached to every generated training sample.
    system_prompt: String,
    /// Minimum number of actions a sequence must contain to become an
    /// episode and to be converted into training data.
    min_actions: usize,
}
148
149impl WorkerTaskLearn {
150 pub fn new() -> Self {
151 Self {
152 system_prompt:
153 "You are an intelligent agent that diagnoses and resolves system issues."
154 .to_string(),
155 min_actions: 2,
156 }
157 }
158
159 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
160 self.system_prompt = prompt.into();
161 self
162 }
163
164 pub fn with_min_actions(mut self, min: usize) -> Self {
165 self.min_actions = min;
166 self
167 }
168}
169
170impl Default for WorkerTaskLearn {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176impl LearnModel for WorkerTaskLearn {
177 fn name(&self) -> &str {
178 "worker_task"
179 }
180
181 fn objective(&self) -> &str {
182 "Learn complete worker task sequences from start to done"
183 }
184
185 fn evaluate(&self, context: &EpisodeContext) -> Outcome {
186 if context.is_empty() {
188 return Outcome::failure("Empty context: no actions to evaluate");
189 }
190
191 let last_action = context.iter::<ActionRecord>().last();
193
194 match last_action {
195 Some(action) if action.is_terminal() => {
196 if action.success {
197 Outcome::success_binary()
198 } else {
199 Outcome::failure("Task failed")
200 }
201 }
202 _ => Outcome::Unknown,
203 }
204 }
205
206 fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
207 use std::collections::HashMap;
208
209 let stream = RecordStream::new(records);
210
211 let mut worker_actions: HashMap<usize, Vec<&ActionRecord>> = HashMap::new();
213 for record in stream.actions() {
214 worker_actions
215 .entry(record.worker_id)
216 .or_default()
217 .push(record);
218 }
219
220 let mut episodes = Vec::new();
221
222 for (_worker_id, worker_records) in worker_actions {
223 let mut current_sequence: Vec<&ActionRecord> = Vec::new();
225
226 for record in worker_records {
227 current_sequence.push(record);
228
229 if record.is_terminal() {
230 if current_sequence.len() >= self.min_actions {
232 let mut context = EpisodeContext::new();
234 for r in ¤t_sequence {
235 context.push((*r).clone());
236 }
237
238 let outcome = self.evaluate(&context);
240
241 let episode = Episode::builder()
242 .learn_model("worker_task")
243 .context(context)
244 .outcome(outcome)
245 .build();
246
247 episodes.push(episode);
248 }
249
250 current_sequence.clear();
251 }
252 }
253 }
254
255 episodes
256 }
257
258 fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
259 if !episode.outcome.is_success() {
260 return Err(LearnError::InvalidEpisode(
261 "Episode is not successful".into(),
262 ));
263 }
264
265 let action_count = episode.context.iter::<ActionRecord>().count();
266 if action_count < self.min_actions {
267 return Err(LearnError::InvalidEpisode(format!(
268 "Too few actions: {} < {}",
269 action_count, self.min_actions
270 )));
271 }
272
273 let actions: Vec<&str> = episode
274 .context
275 .iter::<ActionRecord>()
276 .filter(|a| !a.is_terminal())
277 .map(|a| a.action.as_str())
278 .collect();
279
280 let prompt = format!(
281 "Diagnose and resolve the issue.\nAvailable actions: {}",
282 actions.join(", ")
283 );
284
285 let response = format!("Execute the following sequence: {}", actions.join(" -> "));
286
287 Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
288 .with_episode_id(episode.id.to_string())
289 .with_outcome_score(episode.outcome.score()))
290 }
291}
292
/// Learning model that collects the successful, non-system actions of a
/// record set into a single action-sequence episode.
///
/// Derives `Debug` and `Clone` so configured models can be logged and
/// duplicated like other public configuration types.
#[derive(Debug, Clone)]
pub struct WorkerDecisionSequenceLearn {
    /// System prompt attached to every generated training sample.
    system_prompt: String,
    /// Minimum number of successful, non-system actions required for an
    /// episode.
    min_actions: usize,
    /// Action names listed as "available" in the generated prompt.
    available_actions: Vec<String>,
    /// Action names treated as system events and excluded from sequences.
    system_events: Vec<String>,
}
322
323impl WorkerDecisionSequenceLearn {
324 pub fn new() -> Self {
325 Self {
326 system_prompt: "You are an intelligent agent that diagnoses and resolves system issues. \
327 Given a context and available actions, determine the optimal action sequence.".to_string(),
328 min_actions: 3,
329 available_actions: vec![
330 "CheckStatus".to_string(),
331 "ReadLogs".to_string(),
332 "AnalyzeMetrics".to_string(),
333 "Diagnose".to_string(),
334 "Restart".to_string(),
335 ],
336 system_events: system_events::DEFAULT_SYSTEM_EVENTS
337 .iter()
338 .map(|s| s.to_string())
339 .collect(),
340 }
341 }
342
343 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
345 self.system_prompt = prompt.into();
346 self
347 }
348
349 pub fn with_min_actions(mut self, min: usize) -> Self {
351 self.min_actions = min;
352 self
353 }
354
355 pub fn with_available_actions(mut self, actions: Vec<String>) -> Self {
357 self.available_actions = actions;
358 self
359 }
360
361 pub fn with_system_event(mut self, event: impl Into<String>) -> Self {
363 self.system_events.push(event.into());
364 self
365 }
366
367 fn is_system_event(&self, action: &str) -> bool {
369 self.system_events.iter().any(|e| e == action)
370 }
371}
372
373impl Default for WorkerDecisionSequenceLearn {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379impl LearnModel for WorkerDecisionSequenceLearn {
380 fn name(&self) -> &str {
381 "worker_decision_sequence"
382 }
383
384 fn objective(&self) -> &str {
385 "Learn successful action sequences for problem resolution"
386 }
387
388 fn evaluate(&self, context: &EpisodeContext) -> Outcome {
389 if context.is_empty() {
391 return Outcome::failure("Empty context: no actions to evaluate");
392 }
393
394 let successful_actions: Vec<_> = context
396 .iter::<ActionRecord>()
397 .filter(|a| a.success && !self.is_system_event(&a.action))
398 .collect();
399
400 if successful_actions.len() >= self.min_actions {
401 Outcome::success(1.0)
402 } else {
403 Outcome::failure(format!(
404 "Insufficient successful actions: {} < {}",
405 successful_actions.len(),
406 self.min_actions
407 ))
408 }
409 }
410
411 fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
412 let successful_actions: Vec<&ActionRecord> = records
414 .iter()
415 .filter_map(Record::as_action)
416 .filter(|a| a.success && !self.is_system_event(&a.action))
417 .collect();
418
419 if successful_actions.len() < self.min_actions {
420 return vec![];
421 }
422
423 let mut context = EpisodeContext::new();
425 for action in &successful_actions {
426 context.push((*action).clone());
427 }
428
429 let outcome = self.evaluate(&context);
430
431 let episode = Episode::builder()
432 .learn_model("worker_decision_sequence")
433 .context(context)
434 .outcome(outcome)
435 .build();
436
437 vec![episode]
438 }
439
440 fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
441 if !episode.outcome.is_success() {
442 return Err(LearnError::InvalidEpisode(
443 "Episode is not successful".into(),
444 ));
445 }
446
447 let actions: Vec<&str> = episode
448 .context
449 .iter::<ActionRecord>()
450 .map(|a| a.action.as_str())
451 .collect();
452
453 if actions.len() < self.min_actions {
454 return Err(LearnError::InvalidEpisode(format!(
455 "Too few actions: {} < {}",
456 actions.len(),
457 self.min_actions
458 )));
459 }
460
461 let available = self.available_actions.join(", ");
463 let prompt = format!(
464 "Current context: default\n\
465 Available actions: {}\n\n\
466 What is the best sequence of actions to resolve this issue?",
467 available
468 );
469
470 let action_sequence = actions.join(" -> ");
472 let response = format!(
473 "Based on the context, the optimal action sequence is: {}",
474 action_sequence
475 );
476
477 Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
478 .with_episode_id(episode.id.to_string())
479 .with_outcome_score(episode.outcome.score()))
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an action event for `worker_id` at `tick` with a fixed 10 ms
    /// duration and an empty action context.
    fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let builder = ActionEventBuilder::new(tick, WorkerId(worker_id), action);

        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        builder
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Converts action events into records, preserving order.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        let mut records = Vec::with_capacity(actions.len());
        for event in actions {
            records.push(Record::from(event));
        }
        records
    }

    #[test]
    fn test_record_accessors() {
        let event = make_action(1, 5, "CheckStatus", true);
        let record = Record::from(&event);

        // An action event must surface as an action record, never as an LLM record.
        assert!(record.is_action());
        assert!(record.as_action().is_some());
        assert!(!record.is_llm());
        assert!(record.as_llm().is_none());
        assert_eq!(record.worker_id(), Some(5));
    }

    #[test]
    fn test_record_stream_group_by_worker() {
        let records = make_records(&[
            make_action(1, 0, "A", true),
            make_action(2, 1, "B", true),
            make_action(3, 0, "C", true),
            make_action(4, 1, "D", true),
        ]);

        let groups = RecordStream::new(&records).group_by_worker();

        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get(&0).map(|group| group.len()), Some(2));
        assert_eq!(groups.get(&1).map(|group| group.len()), Some(2));
    }

    #[test]
    fn test_worker_task_learn_build_episodes() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);

        // Worker 0 finishes successfully; worker 1 ends with a failed "done".
        let records = make_records(&[
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", true),
            make_action(3, 0, "done", true),
            make_action(4, 1, "Grep", true),
            make_action(5, 1, "done", false),
        ]);

        let episodes = learn.build_episodes(&records);
        assert_eq!(episodes.len(), 2);

        let worker0_ep = episodes
            .iter()
            .find(|ep| ep.worker_id() == Some(0))
            .expect("worker 0 should produce an episode");
        assert!(worker0_ep.outcome.is_success());

        let worker1_ep = episodes
            .iter()
            .find(|ep| ep.worker_id() == Some(1))
            .expect("worker 1 should produce an episode");
        assert!(worker1_ep.outcome.is_failure());
    }

    #[test]
    fn test_worker_task_learn_convert_success_only() {
        let learn = WorkerTaskLearn::new();

        // A failed episode must be rejected.
        let failed_ep = Episode::builder()
            .learn_model("worker_task")
            .outcome(Outcome::failure("test"))
            .build();
        assert!(learn.convert(&failed_ep).is_err());

        // A successful episode with enough actions must convert.
        let success_ep = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "Check").success(true))
            .record(ActionRecord::new(2, 0, "Fix").success(true))
            .record(ActionRecord::new(3, 0, "done").success(true))
            .outcome(Outcome::success_binary())
            .build();
        assert!(learn.convert(&success_ep).is_ok());
    }

    #[test]
    fn test_worker_decision_sequence_learn_build_episodes() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        // "tick_end" and "done" are system events and must be filtered out.
        let records = make_records(&[
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", true),
            make_action(3, 0, "Restart", true),
            make_action(4, 0, "tick_end", true),
            make_action(5, 0, "done", true),
        ]);

        let episodes = learn.build_episodes(&records);

        assert_eq!(episodes.len(), 1);
        let episode = &episodes[0];
        assert!(episode.outcome.is_success());

        // Only the three real actions remain in the episode context.
        let action_count = episode.context.iter::<ActionRecord>().count();
        assert_eq!(action_count, 3);
    }

    #[test]
    fn test_worker_decision_sequence_learn_filters_failed_actions() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        // The failed "ReadLogs" leaves only two usable actions, below the minimum.
        let records = make_records(&[
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", false),
            make_action(3, 0, "Restart", true),
        ]);

        let episodes = learn.build_episodes(&records);

        assert_eq!(episodes.len(), 0);
    }

    #[test]
    fn test_worker_decision_sequence_learn_convert() {
        let available: Vec<String> = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let learn = WorkerDecisionSequenceLearn::new().with_available_actions(available);

        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .record(ActionRecord::new(3, 0, "C").success(true))
            .outcome(Outcome::success(1.0))
            .build();

        let training_data = learn
            .convert(&episode)
            .expect("a successful three-action episode should convert");

        assert!(training_data.prompt.contains("Available actions: A, B, C"));
        assert!(training_data.chosen.contains("A -> B -> C"));
    }
}