1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19pub struct LearnStats {
20 pub episode_transitions: EpisodeTransitions,
22 pub ngram_stats: NgramStats,
24 pub selection_performance: SelectionPerformance,
26 #[serde(
28 serialize_with = "serialize_tuple2_map",
29 deserialize_with = "deserialize_tuple2_map"
30 )]
31 pub contextual_stats: HashMap<(String, String), ContextualActionStats>,
32}
33
34impl LearnStats {
35 pub fn load_prior(&mut self, snapshot: &crate::learn::LearningSnapshot) {
37 self.episode_transitions = snapshot.episode_transitions.clone();
38 self.ngram_stats = snapshot.ngram_stats.clone();
39 self.selection_performance = snapshot.selection_performance.clone();
40 for ((prev, action), stats) in &snapshot.contextual_stats {
42 let contextual = ContextualActionStats {
43 visits: stats.visits,
44 successes: stats.successes,
45 failures: stats.failures,
46 };
47 self.contextual_stats
48 .insert((prev.clone(), action.clone()), contextual);
49 }
50 }
51}
52
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
55pub struct ContextualActionStats {
56 pub visits: u32,
57 pub successes: u32,
58 pub failures: u32,
59}
60
61impl ContextualActionStats {
62 pub fn success_rate(&self) -> f64 {
63 if self.visits == 0 {
64 0.5
65 } else {
66 self.successes as f64 / self.visits as f64
67 }
68 }
69}
70
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct EpisodeTransitions {
81 #[serde(
83 serialize_with = "serialize_tuple2_map",
84 deserialize_with = "deserialize_tuple2_map"
85 )]
86 pub success_transitions: HashMap<(String, String), u32>,
87 #[serde(
89 serialize_with = "serialize_tuple2_map",
90 deserialize_with = "deserialize_tuple2_map"
91 )]
92 pub failure_transitions: HashMap<(String, String), u32>,
93 pub success_episodes: u32,
95 pub failure_episodes: u32,
97}
98
99impl EpisodeTransitions {
100 pub fn success_transition_rate(&self, from: &str, to: &str) -> f64 {
102 let key = (from.to_string(), to.to_string());
103 let success_count = self.success_transitions.get(&key).copied().unwrap_or(0);
104 let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0);
105 let total = success_count + failure_count;
106
107 if total == 0 {
108 0.5
109 } else {
110 success_count as f64 / total as f64
111 }
112 }
113
114 pub fn transition_value(&self, from: &str, to: &str) -> f64 {
116 let key = (from.to_string(), to.to_string());
117 let success_count = self.success_transitions.get(&key).copied().unwrap_or(0) as f64;
118 let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0) as f64;
119
120 let total_success = self.success_transitions.values().sum::<u32>() as f64;
121 let total_failure = self.failure_transitions.values().sum::<u32>() as f64;
122
123 let success_rate = if total_success > 0.0 {
124 success_count / total_success
125 } else {
126 0.0
127 };
128 let failure_rate = if total_failure > 0.0 {
129 failure_count / total_failure
130 } else {
131 0.0
132 };
133
134 success_rate - failure_rate
135 }
136
137 pub fn recommended_next_actions(&self, from: &str) -> Vec<(String, f64)> {
139 let mut candidates: Vec<_> = self
140 .success_transitions
141 .iter()
142 .filter(|((f, _), _)| f == from)
143 .map(|((_, to), _)| {
144 let value = self.transition_value(from, to);
145 (to.clone(), value)
146 })
147 .collect();
148
149 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
150 candidates
151 }
152}
153
154#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160pub struct NgramStats {
161 #[serde(
163 serialize_with = "serialize_tuple3_map",
164 deserialize_with = "deserialize_tuple3_map"
165 )]
166 pub trigrams: HashMap<(String, String, String), (u32, u32)>,
167 #[serde(
169 serialize_with = "serialize_tuple4_map",
170 deserialize_with = "deserialize_tuple4_map"
171 )]
172 pub quadgrams: HashMap<(String, String, String, String), (u32, u32)>,
173}
174
175impl NgramStats {
176 pub fn trigram_success_rate(&self, a1: &str, a2: &str, a3: &str) -> f64 {
178 let key = (a1.to_string(), a2.to_string(), a3.to_string());
179 match self.trigrams.get(&key) {
180 Some(&(success, failure)) => {
181 let total = success + failure;
182 if total == 0 {
183 0.5
184 } else {
185 success as f64 / total as f64
186 }
187 }
188 None => 0.5,
189 }
190 }
191
192 pub fn quadgram_success_rate(&self, a1: &str, a2: &str, a3: &str, a4: &str) -> f64 {
194 let key = (
195 a1.to_string(),
196 a2.to_string(),
197 a3.to_string(),
198 a4.to_string(),
199 );
200 match self.quadgrams.get(&key) {
201 Some(&(success, failure)) => {
202 let total = success + failure;
203 if total == 0 {
204 0.5
205 } else {
206 success as f64 / total as f64
207 }
208 }
209 None => 0.5,
210 }
211 }
212
213 pub fn trigram_value(&self, a1: &str, a2: &str, a3: &str) -> f64 {
215 let key = (a1.to_string(), a2.to_string(), a3.to_string());
216 match self.trigrams.get(&key) {
217 Some(&(success, failure)) => {
218 let total = success + failure;
219 if total == 0 {
220 0.0
221 } else {
222 (success as f64 / total as f64) * 2.0 - 1.0
223 }
224 }
225 None => 0.0,
226 }
227 }
228
229 pub fn recommended_after(&self, a1: &str, a2: &str) -> Vec<(String, f64)> {
231 let mut candidates: Vec<_> = self
232 .trigrams
233 .iter()
234 .filter(|((x1, x2, _), _)| x1 == a1 && x2 == a2)
235 .map(|((_, _, a3), &(success, failure))| {
236 let total = success + failure;
237 let score = if total == 0 {
238 0.0
239 } else {
240 (success as f64 / total as f64) * 2.0 - 1.0
241 };
242 (a3.clone(), score)
243 })
244 .collect();
245
246 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
247 candidates
248 }
249
250 pub fn recommended_after_three(&self, a1: &str, a2: &str, a3: &str) -> Vec<(String, f64)> {
252 let mut candidates: Vec<_> = self
253 .quadgrams
254 .iter()
255 .filter(|((x1, x2, x3, _), _)| x1 == a1 && x2 == a2 && x3 == a3)
256 .map(|((_, _, _, a4), &(success, failure))| {
257 let total = success + failure;
258 let score = if total == 0 {
259 0.0
260 } else {
261 (success as f64 / total as f64) * 2.0 - 1.0
262 };
263 (a4.clone(), score)
264 })
265 .collect();
266
267 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
268 candidates
269 }
270
271 pub fn trigram_count(&self) -> usize {
272 self.trigrams.len()
273 }
274
275 pub fn quadgram_count(&self) -> usize {
276 self.quadgrams.len()
277 }
278}
279
280#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct SelectionPerformance {
287 pub strategy_stats: HashMap<String, StrategyStats>,
289 pub switch_history: Vec<StrategySwitchEvent>,
291 pub current_strategy: Option<String>,
293 pub strategy_start_visits: u32,
295 pub strategy_start_success_rate: f64,
297}
298
299#[derive(Debug, Clone, Default, Serialize, Deserialize)]
301pub struct StrategyStats {
302 pub visits: u32,
303 pub successes: u32,
304 pub failures: u32,
305 pub usage_count: u32,
306 pub episodes_success: u32,
307 pub episodes_failure: u32,
308}
309
310impl StrategyStats {
311 pub fn success_rate(&self) -> f64 {
312 if self.visits == 0 {
313 0.5
314 } else {
315 self.successes as f64 / self.visits as f64
316 }
317 }
318
319 pub fn episode_success_rate(&self) -> f64 {
320 let total = self.episodes_success + self.episodes_failure;
321 if total == 0 {
322 0.5
323 } else {
324 self.episodes_success as f64 / total as f64
325 }
326 }
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct StrategySwitchEvent {
332 pub from: String,
333 pub to: String,
334 pub visits_at_switch: u32,
335 pub success_rate_at_switch: f64,
336 pub from_strategy_success_rate: f64,
337}
338
339impl SelectionPerformance {
340 pub fn start_strategy(
342 &mut self,
343 strategy: &str,
344 current_visits: u32,
345 current_success_rate: f64,
346 ) {
347 if let Some(ref current) = self.current_strategy {
348 if current != strategy {
349 let from_stats = self
350 .strategy_stats
351 .get(current)
352 .cloned()
353 .unwrap_or_default();
354 self.switch_history.push(StrategySwitchEvent {
355 from: current.clone(),
356 to: strategy.to_string(),
357 visits_at_switch: current_visits,
358 success_rate_at_switch: current_success_rate,
359 from_strategy_success_rate: from_stats.success_rate(),
360 });
361 }
362 }
363
364 self.current_strategy = Some(strategy.to_string());
365 self.strategy_start_visits = current_visits;
366 self.strategy_start_success_rate = current_success_rate;
367
368 self.strategy_stats
369 .entry(strategy.to_string())
370 .or_default()
371 .usage_count += 1;
372 }
373
374 pub fn record_action(&mut self, success: bool) {
376 if let Some(ref strategy) = self.current_strategy {
377 let stats = self.strategy_stats.entry(strategy.clone()).or_default();
378 stats.visits += 1;
379 if success {
380 stats.successes += 1;
381 } else {
382 stats.failures += 1;
383 }
384 }
385 }
386
387 pub fn record_episode_end(&mut self, success: bool) {
389 if let Some(ref strategy) = self.current_strategy {
390 let stats = self.strategy_stats.entry(strategy.clone()).or_default();
391 if success {
392 stats.episodes_success += 1;
393 } else {
394 stats.episodes_failure += 1;
395 }
396 }
397 }
398
399 pub fn strategy_effectiveness(&self, strategy: &str) -> Option<f64> {
401 self.strategy_stats.get(strategy).map(|s| s.success_rate())
402 }
403
404 pub fn best_strategy(&self) -> Option<(&str, f64)> {
406 self.strategy_stats
407 .iter()
408 .filter(|(_, stats)| stats.visits >= 10)
409 .max_by(|(_, a), (_, b)| {
410 a.success_rate()
411 .partial_cmp(&b.success_rate())
412 .unwrap_or(std::cmp::Ordering::Equal)
413 })
414 .map(|(name, stats)| (name.as_str(), stats.success_rate()))
415 }
416
417 pub fn recommended_strategy(&self, failure_rate: f64, visits: u32) -> &str {
419 let ucb1_score = self.strategy_score_for_context("UCB1", failure_rate, visits);
420 let greedy_score = self.strategy_score_for_context("Greedy", failure_rate, visits);
421 let thompson_score = self.strategy_score_for_context("Thompson", failure_rate, visits);
422
423 if ucb1_score >= greedy_score && ucb1_score >= thompson_score {
424 "UCB1"
425 } else if greedy_score >= thompson_score {
426 "Greedy"
427 } else {
428 "Thompson"
429 }
430 }
431
432 fn strategy_score_for_context(&self, strategy: &str, failure_rate: f64, _visits: u32) -> f64 {
433 let base_score = self
434 .strategy_stats
435 .get(strategy)
436 .map(|s| s.success_rate())
437 .unwrap_or(0.5);
438
439 match strategy {
440 "UCB1" => base_score + failure_rate * 0.2,
441 "Greedy" => base_score + (1.0 - failure_rate) * 0.2,
442 "Thompson" => {
443 let distance_from_middle = (failure_rate - 0.5).abs();
444 base_score + (0.5 - distance_from_middle) * 0.2
445 }
446 _ => base_score,
447 }
448 }
449}
450
451fn serialize_tuple2_map<V, S>(
456 map: &HashMap<(String, String), V>,
457 serializer: S,
458) -> Result<S::Ok, S::Error>
459where
460 V: Serialize,
461 S: serde::Serializer,
462{
463 use serde::ser::SerializeMap;
464 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
465 for ((a, b), v) in map {
466 ser_map.serialize_entry(&format!("{}:{}", a, b), v)?;
467 }
468 ser_map.end()
469}
470
471fn deserialize_tuple2_map<'de, V, D>(
472 deserializer: D,
473) -> Result<HashMap<(String, String), V>, D::Error>
474where
475 V: Deserialize<'de>,
476 D: serde::Deserializer<'de>,
477{
478 use serde::de::Error;
479 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
480 let mut result = HashMap::new();
481 for (k, v) in string_map {
482 let parts: Vec<&str> = k.splitn(2, ':').collect();
483 if parts.len() != 2 {
484 return Err(D::Error::custom(format!("invalid tuple2 key: {}", k)));
485 }
486 result.insert((parts[0].to_string(), parts[1].to_string()), v);
487 }
488 Ok(result)
489}
490
491fn serialize_tuple3_map<V, S>(
492 map: &HashMap<(String, String, String), V>,
493 serializer: S,
494) -> Result<S::Ok, S::Error>
495where
496 V: Serialize,
497 S: serde::Serializer,
498{
499 use serde::ser::SerializeMap;
500 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
501 for ((a, b, c), v) in map {
502 ser_map.serialize_entry(&format!("{}:{}:{}", a, b, c), v)?;
503 }
504 ser_map.end()
505}
506
507fn deserialize_tuple3_map<'de, V, D>(
508 deserializer: D,
509) -> Result<HashMap<(String, String, String), V>, D::Error>
510where
511 V: Deserialize<'de>,
512 D: serde::Deserializer<'de>,
513{
514 use serde::de::Error;
515 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
516 let mut result = HashMap::new();
517 for (k, v) in string_map {
518 let parts: Vec<&str> = k.splitn(3, ':').collect();
519 if parts.len() != 3 {
520 return Err(D::Error::custom(format!("invalid tuple3 key: {}", k)));
521 }
522 result.insert(
523 (
524 parts[0].to_string(),
525 parts[1].to_string(),
526 parts[2].to_string(),
527 ),
528 v,
529 );
530 }
531 Ok(result)
532}
533
534fn serialize_tuple4_map<V, S>(
535 map: &HashMap<(String, String, String, String), V>,
536 serializer: S,
537) -> Result<S::Ok, S::Error>
538where
539 V: Serialize,
540 S: serde::Serializer,
541{
542 use serde::ser::SerializeMap;
543 let mut ser_map = serializer.serialize_map(Some(map.len()))?;
544 for ((a, b, c, d), v) in map {
545 ser_map.serialize_entry(&format!("{}:{}:{}:{}", a, b, c, d), v)?;
546 }
547 ser_map.end()
548}
549
550fn deserialize_tuple4_map<'de, V, D>(
551 deserializer: D,
552) -> Result<HashMap<(String, String, String, String), V>, D::Error>
553where
554 V: Deserialize<'de>,
555 D: serde::Deserializer<'de>,
556{
557 use serde::de::Error;
558 let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
559 let mut result = HashMap::new();
560 for (k, v) in string_map {
561 let parts: Vec<&str> = k.splitn(4, ':').collect();
562 if parts.len() != 4 {
563 return Err(D::Error::custom(format!("invalid tuple4 key: {}", k)));
564 }
565 result.insert(
566 (
567 parts[0].to_string(),
568 parts[1].to_string(),
569 parts[2].to_string(),
570 parts[3].to_string(),
571 ),
572 v,
573 );
574 }
575 Ok(result)
576}