1use std::collections::HashMap;
14
15use super::episode::Episode;
16use super::learn_model::LearnError;
17use super::learned_component::{
18 ComponentLearner, LearnedDepGraph, LearnedExploration, LearnedStrategy,
19};
20use super::record::ActionRecord;
21use super::RecommendedPath;
22use crate::exploration::DependencyGraph;
23
24#[derive(Debug, Clone, Default)]
39pub struct DepGraphLearner {
40 pub min_episodes: usize,
42 pub min_order_count: usize,
44}
45
46impl DepGraphLearner {
47 pub fn new() -> Self {
49 Self {
50 min_episodes: 3,
51 min_order_count: 2,
52 }
53 }
54
55 pub fn with_min_episodes(mut self, n: usize) -> Self {
57 self.min_episodes = n;
58 self
59 }
60
61 fn extract_order_relations(
63 &self,
64 action_sequences: &[Vec<String>],
65 ) -> HashMap<(String, String), usize> {
66 let mut relations: HashMap<(String, String), usize> = HashMap::new();
67
68 for sequence in action_sequences {
69 for i in 0..sequence.len() {
71 for j in (i + 1)..sequence.len() {
72 let key = (sequence[i].clone(), sequence[j].clone());
73 *relations.entry(key).or_insert(0) += 1;
74 }
75 }
76 }
77
78 relations
79 }
80
81 fn compute_action_order(&self, relations: &HashMap<(String, String), usize>) -> Vec<String> {
86 let mut scores: HashMap<String, i64> = HashMap::new();
88
89 for ((from, to), &count) in relations {
90 *scores.entry(from.clone()).or_insert(0) += count as i64;
92 *scores.entry(to.clone()).or_insert(0) -= count as i64;
93 }
94
95 let mut actions: Vec<_> = scores.into_iter().collect();
97 actions.sort_by(|a, b| b.1.cmp(&a.1));
98
99 actions.into_iter().map(|(action, _)| action).collect()
100 }
101
102 fn compute_recommended_paths(
104 &self,
105 success_count: &HashMap<Vec<String>, usize>,
106 total_success: usize,
107 ) -> Vec<RecommendedPath> {
108 let mut paths: Vec<_> = success_count
109 .iter()
110 .map(|(actions, &count)| {
111 let success_rate = count as f64 / total_success.max(1) as f64;
112 RecommendedPath {
113 actions: actions.clone(),
114 success_rate,
115 observations: count as u32,
116 }
117 })
118 .collect();
119
120 paths.sort_by(|a, b| {
122 b.success_rate
123 .partial_cmp(&a.success_rate)
124 .unwrap_or(std::cmp::Ordering::Equal)
125 });
126
127 paths.truncate(10);
129 paths
130 }
131}
132
133impl ComponentLearner for DepGraphLearner {
134 type Output = LearnedDepGraph;
135
136 fn name(&self) -> &str {
137 "dep_graph_learner"
138 }
139
140 fn objective(&self) -> &str {
141 "Learn action dependency graph from successful execution traces"
142 }
143
144 fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
145 let success_episodes: Vec<_> = episodes.iter().filter(|e| e.outcome.is_success()).collect();
147
148 if success_episodes.is_empty() {
149 return Err(LearnError::InsufficientData(
150 "No successful episodes to learn from".into(),
151 ));
152 }
153
154 let mut action_sequences: Vec<Vec<String>> = Vec::new();
156 let mut success_count: HashMap<Vec<String>, usize> = HashMap::new();
157 let mut session_ids: Vec<String> = Vec::new();
158
159 for episode in &success_episodes {
160 let actions: Vec<String> = episode
162 .context
163 .iter::<ActionRecord>()
164 .map(|r| r.action.clone())
165 .collect();
166
167 if !actions.is_empty() {
168 *success_count.entry(actions.clone()).or_insert(0) += 1;
169 action_sequences.push(actions);
170 }
171
172 let episode_id = episode.id.to_string();
174 if !session_ids.contains(&episode_id) {
175 session_ids.push(episode_id);
176 }
177 }
178
179 let relations = self.extract_order_relations(&action_sequences);
181
182 let action_order = self.compute_action_order(&relations);
184
185 let recommended_paths =
187 self.compute_recommended_paths(&success_count, success_episodes.len());
188
189 let confidence = if success_episodes.len() >= self.min_episodes {
191 (success_episodes.len() as f64 / (self.min_episodes as f64 * 2.0)).min(1.0)
192 } else {
193 success_episodes.len() as f64 / self.min_episodes as f64
194 };
195
196 let graph = DependencyGraph::new();
198
199 Ok(LearnedDepGraph::new(graph, action_order)
200 .with_confidence(confidence)
201 .with_sessions(session_ids)
202 .with_recommended_paths(recommended_paths))
203 }
204}
205
/// Learner that derives exploration parameters from episode outcomes.
#[derive(Debug, Clone)]
pub struct ExplorationLearner {
    /// Initial UCB1 exploration constant.
    ///
    /// NOTE(review): nothing visible here reads this field — confirm whether
    /// the learner was meant to use it as its baseline value.
    pub initial_ucb1_c: f64,
}

impl ExplorationLearner {
    /// Creates a learner with `initial_ucb1_c = 1.414`.
    pub fn new() -> Self {
        Self {
            initial_ucb1_c: 1.414,
        }
    }
}

impl Default for ExplorationLearner {
    /// Returns the same configuration as [`ExplorationLearner::new`].
    ///
    /// A derived `Default` would set `initial_ucb1_c` to `0.0` and disagree
    /// with `new()`, so the implementation delegates instead.
    fn default() -> Self {
        Self::new()
    }
}
227
228impl ComponentLearner for ExplorationLearner {
229 type Output = LearnedExploration;
230
231 fn name(&self) -> &str {
232 "exploration_learner"
233 }
234
235 fn objective(&self) -> &str {
236 "Optimize exploration parameters from session statistics"
237 }
238
239 fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
240 if episodes.is_empty() {
241 return Err(LearnError::InsufficientData(
242 "No episodes to learn from".into(),
243 ));
244 }
245
246 let total = episodes.len();
248 let success = episodes.iter().filter(|e| e.outcome.is_success()).count();
249 let success_rate = success as f64 / total as f64;
250
251 let ucb1_c = if success_rate < 0.3 {
255 2.0 } else if success_rate < 0.7 {
257 1.414 } else {
259 1.0 };
261
262 let confidence = (total as f64 / 10.0).min(1.0);
264
265 Ok(LearnedExploration {
266 ucb1_c,
267 learning_weight: 0.3,
268 ngram_weight: 1.0,
269 confidence,
270 session_count: total,
271 updated_at: std::time::SystemTime::now()
272 .duration_since(std::time::UNIX_EPOCH)
273 .map(|d| d.as_secs())
274 .unwrap_or(0),
275 })
276 }
277}
278
/// Stateless learner that picks strategy-selection settings from episode
/// outcomes.
#[derive(Debug, Clone, Default)]
pub struct StrategyLearner;

impl StrategyLearner {
    /// Creates a new learner; the type carries no configuration.
    pub fn new() -> Self {
        StrategyLearner
    }
}
295
296impl ComponentLearner for StrategyLearner {
297 type Output = LearnedStrategy;
298
299 fn name(&self) -> &str {
300 "strategy_learner"
301 }
302
303 fn objective(&self) -> &str {
304 "Determine optimal strategy selection settings"
305 }
306
307 fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
308 if episodes.is_empty() {
309 return Err(LearnError::InsufficientData(
310 "No episodes to learn from".into(),
311 ));
312 }
313
314 let total = episodes.len();
315 let success = episodes.iter().filter(|e| e.outcome.is_success()).count();
316 let success_rate = success as f64 / total as f64;
317
318 let initial_strategy = if success_rate < 0.5 {
320 "ucb1".to_string() } else {
322 "greedy".to_string() };
324
325 let error_rate_threshold = if success_rate < 0.3 {
327 0.6 } else {
329 0.45 };
331
332 let confidence = (total as f64 / 10.0).min(1.0);
333
334 Ok(LearnedStrategy {
335 initial_strategy,
336 maturity_threshold: 5,
337 error_rate_threshold,
338 confidence,
339 session_count: total,
340 updated_at: std::time::SystemTime::now()
341 .duration_since(std::time::UNIX_EPOCH)
342 .map(|d| d.as_secs())
343 .unwrap_or(0),
344 })
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::{Episode, EpisodeContext, Outcome};

    /// Builds a successful episode with an empty context.
    ///
    /// NOTE(review): `_actions` is ignored, so the episode carries no action
    /// records — confirm whether the context was meant to be populated.
    fn make_success_episode(_actions: Vec<&str>) -> Episode {
        Episode::builder()
            .learn_model("test")
            .context(EpisodeContext::new())
            .outcome(Outcome::success(1.0))
            .build()
    }

    /// Builds a failed episode with an empty context.
    fn make_failure_episode() -> Episode {
        let empty_context = EpisodeContext::new();
        Episode::builder()
            .learn_model("test")
            .context(empty_context)
            .outcome(Outcome::failure("test failure"))
            .build()
    }

    #[test]
    fn test_dep_graph_learner_empty() {
        let outcome = DepGraphLearner::new().learn(&[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dep_graph_learner_no_success() {
        let episodes: Vec<Episode> = (0..2).map(|_| make_failure_episode()).collect();
        let outcome = DepGraphLearner::new().learn(&episodes);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dep_graph_learner_with_success() {
        let episodes: Vec<Episode> = (0..3)
            .map(|_| make_success_episode(vec!["A", "B", "C"]))
            .collect();

        let outcome = DepGraphLearner::new().learn(&episodes);
        assert!(outcome.is_ok());

        let learned = outcome.unwrap();
        assert!(learned.confidence > 0.0);
    }

    #[test]
    fn test_exploration_learner() {
        let episodes = vec![
            make_success_episode(Vec::new()),
            make_success_episode(Vec::new()),
            make_failure_episode(),
        ];

        let outcome = ExplorationLearner::new().learn(&episodes);
        assert!(outcome.is_ok());

        let learned = outcome.unwrap();
        assert!(learned.ucb1_c > 0.0);
        assert_eq!(learned.session_count, 3);
    }

    #[test]
    fn test_strategy_learner() {
        let episodes = vec![make_success_episode(Vec::new()), make_failure_episode()];

        let outcome = StrategyLearner::new().learn(&episodes);
        assert!(outcome.is_ok());

        let learned = outcome.unwrap();
        assert!(!learned.initial_strategy.is_empty());
    }

    #[test]
    fn test_extract_order_relations() {
        let learner = DepGraphLearner::new().with_min_episodes(1);

        let abc = || -> Vec<String> { ["A", "B", "C"].iter().map(|s| s.to_string()).collect() };
        let sequences = vec![abc(), abc()];

        let relations = learner.extract_order_relations(&sequences);

        // Every ordered pair occurs once per sequence, hence twice in total.
        for &(earlier, later) in &[("A", "B"), ("A", "C"), ("B", "C")] {
            let key = (earlier.to_string(), later.to_string());
            assert_eq!(relations.get(&key), Some(&2));
        }
    }
}