1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
/// Map keyed by an owned triple of strings.
///
/// Used as the return type of `deserialize_tuple3_map`.
type Tuple3Map<V> = HashMap<(String, String, String), V>;

/// Map keyed by an owned quadruple of strings.
///
/// Used as the return type of `deserialize_tuple4_map`.
type Tuple4Map<V> = HashMap<(String, String, String, String), V>;
19
/// Aggregate of all learned statistics kept by this module: episode-level
/// transition counters, n-gram counters, strategy-selection bookkeeping and
/// per-context action counters.
///
/// The whole structure is serializable with serde. `contextual_stats` has a
/// tuple key, so it goes through `serialize_tuple2_map` /
/// `deserialize_tuple2_map`, which flatten the key into a single
/// `"first:second"` string.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearnStats {
    /// Success/failure counters per `(from, to)` transition.
    pub episode_transitions: EpisodeTransitions,
    /// Success/failure counters per action trigram and quadgram.
    pub ngram_stats: NgramStats,
    /// Per-strategy counters and the strategy switch history.
    pub selection_performance: SelectionPerformance,
    /// Action counters keyed by a pair of strings; `load_prior` names the two
    /// components `(prev, action)`.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub contextual_stats: HashMap<(String, String), ContextualActionStats>,
}
43
44impl LearnStats {
45 pub fn load_prior(&mut self, snapshot: &crate::learn::LearningSnapshot) {
47 self.episode_transitions = snapshot.episode_transitions.clone();
48 self.ngram_stats = snapshot.ngram_stats.clone();
49 self.selection_performance = snapshot.selection_performance.clone();
50 for ((prev, action), stats) in &snapshot.contextual_stats {
52 let contextual = ContextualActionStats {
53 visits: stats.visits,
54 successes: stats.successes,
55 failures: stats.failures,
56 };
57 self.contextual_stats
58 .insert((prev.clone(), action.clone()), contextual);
59 }
60 }
61}
62
/// Visit/success/failure counters for one entry of
/// `LearnStats::contextual_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextualActionStats {
    /// Number of visits; the denominator of `success_rate`.
    pub visits: u32,
    /// Number of successes; the numerator of `success_rate`.
    pub successes: u32,
    /// Number of failures. Not read by `success_rate`.
    pub failures: u32,
}
70
71impl ContextualActionStats {
72 pub fn success_rate(&self) -> f64 {
73 if self.visits == 0 {
74 0.5
75 } else {
76 self.successes as f64 / self.visits as f64
77 }
78 }
79}
80
/// Counters of action-to-action transitions, kept separately for the success
/// and failure cases.
///
/// Both maps are keyed by `(from, to)` and are serialized with the key
/// flattened to `"from:to"` via `serialize_tuple2_map` /
/// `deserialize_tuple2_map`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeTransitions {
    /// Count per `(from, to)` transition on the success side.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub success_transitions: HashMap<(String, String), u32>,
    /// Count per `(from, to)` transition on the failure side.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub failure_transitions: HashMap<(String, String), u32>,
    /// Presumably the number of successful episodes recorded; not read by any
    /// method in this impl — confirm against the code that updates it.
    pub success_episodes: u32,
    /// Presumably the number of failed episodes recorded; not read by any
    /// method in this impl — confirm against the code that updates it.
    pub failure_episodes: u32,
}
108
109impl EpisodeTransitions {
110 pub fn success_transition_rate(&self, from: &str, to: &str) -> f64 {
112 let key = (from.to_string(), to.to_string());
113 let success_count = self.success_transitions.get(&key).copied().unwrap_or(0);
114 let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0);
115 let total = success_count + failure_count;
116
117 if total == 0 {
118 0.5
119 } else {
120 success_count as f64 / total as f64
121 }
122 }
123
124 pub fn transition_value(&self, from: &str, to: &str) -> f64 {
126 let key = (from.to_string(), to.to_string());
127 let success_count = self.success_transitions.get(&key).copied().unwrap_or(0) as f64;
128 let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0) as f64;
129
130 let total_success = self.success_transitions.values().sum::<u32>() as f64;
131 let total_failure = self.failure_transitions.values().sum::<u32>() as f64;
132
133 let success_rate = if total_success > 0.0 {
134 success_count / total_success
135 } else {
136 0.0
137 };
138 let failure_rate = if total_failure > 0.0 {
139 failure_count / total_failure
140 } else {
141 0.0
142 };
143
144 success_rate - failure_rate
145 }
146
147 pub fn recommended_next_actions(&self, from: &str) -> Vec<(String, f64)> {
149 let mut candidates: Vec<_> = self
150 .success_transitions
151 .iter()
152 .filter(|((f, _), _)| f == from)
153 .map(|((_, to), _)| {
154 let value = self.transition_value(from, to);
155 (to.clone(), value)
156 })
157 .collect();
158
159 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160 candidates
161 }
162}
163
/// Success/failure counters for sequences of three and four actions.
///
/// Each map value is a `(success, failure)` pair of counts. The tuple keys
/// are serialized as `:`-joined strings by the `serialize_tuple3_map` /
/// `serialize_tuple4_map` helpers and parsed back by the matching
/// deserializers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NgramStats {
    /// `(success, failure)` counts per action triple `(a1, a2, a3)`.
    #[serde(
        serialize_with = "serialize_tuple3_map",
        deserialize_with = "deserialize_tuple3_map"
    )]
    pub trigrams: HashMap<(String, String, String), (u32, u32)>,
    /// `(success, failure)` counts per action quadruple `(a1, a2, a3, a4)`.
    #[serde(
        serialize_with = "serialize_tuple4_map",
        deserialize_with = "deserialize_tuple4_map"
    )]
    pub quadgrams: HashMap<(String, String, String, String), (u32, u32)>,
}
184
185impl NgramStats {
186 pub fn trigram_success_rate(&self, a1: &str, a2: &str, a3: &str) -> f64 {
188 let key = (a1.to_string(), a2.to_string(), a3.to_string());
189 match self.trigrams.get(&key) {
190 Some(&(success, failure)) => {
191 let total = success + failure;
192 if total == 0 {
193 0.5
194 } else {
195 success as f64 / total as f64
196 }
197 }
198 None => 0.5,
199 }
200 }
201
202 pub fn quadgram_success_rate(&self, a1: &str, a2: &str, a3: &str, a4: &str) -> f64 {
204 let key = (
205 a1.to_string(),
206 a2.to_string(),
207 a3.to_string(),
208 a4.to_string(),
209 );
210 match self.quadgrams.get(&key) {
211 Some(&(success, failure)) => {
212 let total = success + failure;
213 if total == 0 {
214 0.5
215 } else {
216 success as f64 / total as f64
217 }
218 }
219 None => 0.5,
220 }
221 }
222
223 pub fn trigram_value(&self, a1: &str, a2: &str, a3: &str) -> f64 {
225 let key = (a1.to_string(), a2.to_string(), a3.to_string());
226 match self.trigrams.get(&key) {
227 Some(&(success, failure)) => {
228 let total = success + failure;
229 if total == 0 {
230 0.0
231 } else {
232 (success as f64 / total as f64) * 2.0 - 1.0
233 }
234 }
235 None => 0.0,
236 }
237 }
238
239 pub fn recommended_after(&self, a1: &str, a2: &str) -> Vec<(String, f64)> {
241 let mut candidates: Vec<_> = self
242 .trigrams
243 .iter()
244 .filter(|((x1, x2, _), _)| x1 == a1 && x2 == a2)
245 .map(|((_, _, a3), &(success, failure))| {
246 let total = success + failure;
247 let score = if total == 0 {
248 0.0
249 } else {
250 (success as f64 / total as f64) * 2.0 - 1.0
251 };
252 (a3.clone(), score)
253 })
254 .collect();
255
256 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
257 candidates
258 }
259
260 pub fn recommended_after_three(&self, a1: &str, a2: &str, a3: &str) -> Vec<(String, f64)> {
262 let mut candidates: Vec<_> = self
263 .quadgrams
264 .iter()
265 .filter(|((x1, x2, x3, _), _)| x1 == a1 && x2 == a2 && x3 == a3)
266 .map(|((_, _, _, a4), &(success, failure))| {
267 let total = success + failure;
268 let score = if total == 0 {
269 0.0
270 } else {
271 (success as f64 / total as f64) * 2.0 - 1.0
272 };
273 (a4.clone(), score)
274 })
275 .collect();
276
277 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
278 candidates
279 }
280
281 pub fn trigram_count(&self) -> usize {
282 self.trigrams.len()
283 }
284
285 pub fn quadgram_count(&self) -> usize {
286 self.quadgrams.len()
287 }
288}
289
/// Bookkeeping for action-selection strategies: per-strategy counters, the
/// currently active strategy and a log of strategy switches.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionPerformance {
    /// Counters per strategy name.
    pub strategy_stats: HashMap<String, StrategyStats>,
    /// One event per change of the active strategy, appended by
    /// `start_strategy`.
    pub switch_history: Vec<StrategySwitchEvent>,
    /// Name most recently passed to `start_strategy`, if any.
    pub current_strategy: Option<String>,
    /// The `current_visits` value supplied to the latest `start_strategy`
    /// call.
    pub strategy_start_visits: u32,
    /// The `current_success_rate` value supplied to the latest
    /// `start_strategy` call.
    pub strategy_start_success_rate: f64,
}
308
/// Counters accumulated for a single selection strategy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StrategyStats {
    /// Incremented once per `SelectionPerformance::record_action` call.
    pub visits: u32,
    /// Actions recorded with `success == true`.
    pub successes: u32,
    /// Actions recorded with `success == false`.
    pub failures: u32,
    /// Incremented on every `SelectionPerformance::start_strategy` call for
    /// this strategy.
    pub usage_count: u32,
    /// Episodes ended with `success == true` while this strategy was active.
    pub episodes_success: u32,
    /// Episodes ended with `success == false` while this strategy was active.
    pub episodes_failure: u32,
}
319
320impl StrategyStats {
321 pub fn success_rate(&self) -> f64 {
322 if self.visits == 0 {
323 0.5
324 } else {
325 self.successes as f64 / self.visits as f64
326 }
327 }
328
329 pub fn episode_success_rate(&self) -> f64 {
330 let total = self.episodes_success + self.episodes_failure;
331 if total == 0 {
332 0.5
333 } else {
334 self.episodes_success as f64 / total as f64
335 }
336 }
337}
338
/// Record of one change of the active strategy, created by
/// `SelectionPerformance::start_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategySwitchEvent {
    /// Strategy that was active before the switch.
    pub from: String,
    /// Strategy that became active.
    pub to: String,
    /// The `current_visits` argument given to `start_strategy` at the switch.
    pub visits_at_switch: u32,
    /// The `current_success_rate` argument given to `start_strategy` at the
    /// switch.
    pub success_rate_at_switch: f64,
    /// `success_rate()` of the outgoing strategy's counters at the switch.
    pub from_strategy_success_rate: f64,
}
348
349impl SelectionPerformance {
350 pub fn start_strategy(
352 &mut self,
353 strategy: &str,
354 current_visits: u32,
355 current_success_rate: f64,
356 ) {
357 if let Some(ref current) = self.current_strategy {
358 if current != strategy {
359 let from_stats = self
360 .strategy_stats
361 .get(current)
362 .cloned()
363 .unwrap_or_default();
364 self.switch_history.push(StrategySwitchEvent {
365 from: current.clone(),
366 to: strategy.to_string(),
367 visits_at_switch: current_visits,
368 success_rate_at_switch: current_success_rate,
369 from_strategy_success_rate: from_stats.success_rate(),
370 });
371 }
372 }
373
374 self.current_strategy = Some(strategy.to_string());
375 self.strategy_start_visits = current_visits;
376 self.strategy_start_success_rate = current_success_rate;
377
378 self.strategy_stats
379 .entry(strategy.to_string())
380 .or_default()
381 .usage_count += 1;
382 }
383
384 pub fn record_action(&mut self, success: bool) {
386 if let Some(ref strategy) = self.current_strategy {
387 let stats = self.strategy_stats.entry(strategy.clone()).or_default();
388 stats.visits += 1;
389 if success {
390 stats.successes += 1;
391 } else {
392 stats.failures += 1;
393 }
394 }
395 }
396
397 pub fn record_episode_end(&mut self, success: bool) {
399 if let Some(ref strategy) = self.current_strategy {
400 let stats = self.strategy_stats.entry(strategy.clone()).or_default();
401 if success {
402 stats.episodes_success += 1;
403 } else {
404 stats.episodes_failure += 1;
405 }
406 }
407 }
408
409 pub fn strategy_effectiveness(&self, strategy: &str) -> Option<f64> {
411 self.strategy_stats.get(strategy).map(|s| s.success_rate())
412 }
413
414 pub fn best_strategy(&self) -> Option<(&str, f64)> {
416 self.strategy_stats
417 .iter()
418 .filter(|(_, stats)| stats.visits >= 10)
419 .max_by(|(_, a), (_, b)| {
420 a.success_rate()
421 .partial_cmp(&b.success_rate())
422 .unwrap_or(std::cmp::Ordering::Equal)
423 })
424 .map(|(name, stats)| (name.as_str(), stats.success_rate()))
425 }
426
427 pub fn recommended_strategy(&self, failure_rate: f64, visits: u32) -> &str {
429 let ucb1_score = self.strategy_score_for_context("UCB1", failure_rate, visits);
430 let greedy_score = self.strategy_score_for_context("Greedy", failure_rate, visits);
431 let thompson_score = self.strategy_score_for_context("Thompson", failure_rate, visits);
432
433 if ucb1_score >= greedy_score && ucb1_score >= thompson_score {
434 "UCB1"
435 } else if greedy_score >= thompson_score {
436 "Greedy"
437 } else {
438 "Thompson"
439 }
440 }
441
442 fn strategy_score_for_context(&self, strategy: &str, failure_rate: f64, _visits: u32) -> f64 {
443 let base_score = self
444 .strategy_stats
445 .get(strategy)
446 .map(|s| s.success_rate())
447 .unwrap_or(0.5);
448
449 match strategy {
450 "UCB1" => base_score + failure_rate * 0.2,
451 "Greedy" => base_score + (1.0 - failure_rate) * 0.2,
452 "Thompson" => {
453 let distance_from_middle = (failure_rate - 0.5).abs();
454 base_score + (0.5 - distance_from_middle) * 0.2
455 }
456 _ => base_score,
457 }
458 }
459}
460
461fn serialize_tuple2_map<V, S>(
466 map: &HashMap<(String, String), V>,
467 serializer: S,
468) -> Result<S::Ok, S::Error>
469where
470 V: Serialize,
471 S: serde::Serializer,
472{
473 use serde::ser::SerializeMap;
474 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
475 for ((a, b), v) in map {
476 ser_map.serialize_entry(&format!("{}:{}", a, b), v)?;
477 }
478 ser_map.end()
479}
480
481fn deserialize_tuple2_map<'de, V, D>(
482 deserializer: D,
483) -> Result<HashMap<(String, String), V>, D::Error>
484where
485 V: Deserialize<'de>,
486 D: serde::Deserializer<'de>,
487{
488 use serde::de::Error;
489 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
490 let mut result = HashMap::new();
491 for (k, v) in string_map {
492 let parts: Vec<&str> = k.splitn(2, ':').collect();
493 if parts.len() != 2 {
494 return Err(D::Error::custom(format!("invalid tuple2 key: {}", k)));
495 }
496 result.insert((parts[0].to_string(), parts[1].to_string()), v);
497 }
498 Ok(result)
499}
500
501fn serialize_tuple3_map<V, S>(
502 map: &HashMap<(String, String, String), V>,
503 serializer: S,
504) -> Result<S::Ok, S::Error>
505where
506 V: Serialize,
507 S: serde::Serializer,
508{
509 use serde::ser::SerializeMap;
510 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
511 for ((a, b, c), v) in map {
512 ser_map.serialize_entry(&format!("{}:{}:{}", a, b, c), v)?;
513 }
514 ser_map.end()
515}
516
517fn deserialize_tuple3_map<'de, V, D>(deserializer: D) -> Result<Tuple3Map<V>, D::Error>
518where
519 V: Deserialize<'de>,
520 D: serde::Deserializer<'de>,
521{
522 use serde::de::Error;
523 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
524 let mut result = HashMap::new();
525 for (k, v) in string_map {
526 let parts: Vec<&str> = k.splitn(3, ':').collect();
527 if parts.len() != 3 {
528 return Err(D::Error::custom(format!("invalid tuple3 key: {}", k)));
529 }
530 result.insert(
531 (
532 parts[0].to_string(),
533 parts[1].to_string(),
534 parts[2].to_string(),
535 ),
536 v,
537 );
538 }
539 Ok(result)
540}
541
542fn serialize_tuple4_map<V, S>(
543 map: &HashMap<(String, String, String, String), V>,
544 serializer: S,
545) -> Result<S::Ok, S::Error>
546where
547 V: Serialize,
548 S: serde::Serializer,
549{
550 use serde::ser::SerializeMap;
551 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
552 for ((a, b, c, d), v) in map {
553 ser_map.serialize_entry(&format!("{}:{}:{}:{}", a, b, c, d), v)?;
554 }
555 ser_map.end()
556}
557
558fn deserialize_tuple4_map<'de, V, D>(deserializer: D) -> Result<Tuple4Map<V>, D::Error>
559where
560 V: Deserialize<'de>,
561 D: serde::Deserializer<'de>,
562{
563 use serde::de::Error;
564 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
565 let mut result = HashMap::new();
566 for (k, v) in string_map {
567 let parts: Vec<&str> = k.splitn(4, ':').collect();
568 if parts.len() != 4 {
569 return Err(D::Error::custom(format!("invalid tuple4 key: {}", k)));
570 }
571 result.insert(
572 (
573 parts[0].to_string(),
574 parts[1].to_string(),
575 parts[2].to_string(),
576 parts[3].to_string(),
577 ),
578 v,
579 );
580 }
581 Ok(result)
582}