1use std::collections::HashMap;
12use std::sync::Arc;
13
14use super::stats::LearnStats;
15use super::stats_model::ScoreModel;
16
/// A request for a learned score, answered by [`LearnedProvider::query`].
///
/// Every field borrows from the caller, so building a query never allocates.
/// Because all fields are `&str` / `Option<&str>`, the type is `Copy`: one
/// query value can be sent to several providers without rebuilding it, and
/// queries can be compared or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LearningQuery<'a> {
    /// Score for performing `action` immediately after `prev`.
    Transition {
        /// The preceding action.
        prev: &'a str,
        /// The action being scored.
        action: &'a str,
        /// Optional target of `action`.
        target: Option<&'a str>,
    },

    /// Score for `action` given the preceding `prev`, answered from the
    /// model's contextual (per `(prev, action)` pair) statistics.
    Contextual {
        /// The preceding action.
        prev: &'a str,
        /// The action being scored.
        action: &'a str,
        /// Optional target of `action`.
        target: Option<&'a str>,
    },

    /// Trigram score for the sequence `prev_prev` -> `prev` -> `action`.
    Ngram {
        /// The action two steps back.
        prev_prev: &'a str,
        /// The preceding action.
        prev: &'a str,
        /// The action being scored.
        action: &'a str,
        /// Optional target of `action`.
        target: Option<&'a str>,
    },

    /// Confidence for `action`, optionally conditioned on up to two
    /// preceding actions.
    Confidence {
        /// The action being scored.
        action: &'a str,
        /// Optional target of `action`.
        target: Option<&'a str>,
        /// The preceding action, if known.
        prev: Option<&'a str>,
        /// The action two steps back, if known.
        prev_prev: Option<&'a str>,
    },
}

impl<'a> LearningQuery<'a> {
    /// Builds a [`LearningQuery::Transition`].
    #[must_use]
    pub fn transition(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Transition {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Contextual`].
    #[must_use]
    pub fn contextual(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Contextual {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Ngram`] for `prev_prev` -> `prev` -> `action`.
    #[must_use]
    pub fn ngram(
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    ) -> Self {
        Self::Ngram {
            prev_prev,
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Confidence`] without any preceding-action
    /// context (equivalent to `confidence_with_context(action, target, None, None)`).
    #[must_use]
    pub fn confidence(action: &'a str, target: Option<&'a str>) -> Self {
        Self::confidence_with_context(action, target, None, None)
    }

    /// Builds a [`LearningQuery::Confidence`] with optional preceding-action
    /// context.
    #[must_use]
    pub fn confidence_with_context(
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    ) -> Self {
        Self::Confidence {
            action,
            target,
            prev,
            prev_prev,
        }
    }
}
127
/// Outcome of a [`LearnedProvider::query`].
///
/// `NotAvailable` is the [`Default`] and means the provider produced no score
/// for the query; callers choose a fallback via [`LearningResult::score_or`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LearningResult {
    /// A score produced by the provider.
    Score(f64),
    /// The provider has no score for the query.
    #[default]
    NotAvailable,
}

impl LearningResult {
    /// Returns the score, or `default` when no score is available.
    #[must_use]
    pub fn score_or(&self, default: f64) -> f64 {
        match self {
            Self::Score(v) => *v,
            Self::NotAvailable => default,
        }
    }

    /// Returns the score, treating [`LearningResult::NotAvailable`] as `0.0`.
    #[must_use]
    pub fn score(&self) -> f64 {
        self.score_or(0.0)
    }

    /// Returns `true` if this result carries a score.
    #[must_use]
    pub fn is_available(&self) -> bool {
        matches!(self, Self::Score(_))
    }
}
161
162pub trait LearnedProvider: Send + Sync {
171 fn query(&self, q: LearningQuery<'_>) -> LearningResult;
173
174 fn stats(&self) -> Option<&LearnStats> {
176 None
177 }
178
179 fn model(&self) -> Option<&ScoreModel> {
181 None
182 }
183}
184
185pub type SharedLearnedProvider = Arc<dyn LearnedProvider>;
187
188pub struct ScoreModelProvider {
196 model: ScoreModel,
197 stats: Option<LearnStats>,
198}
199
200impl ScoreModelProvider {
201 pub fn new(model: ScoreModel) -> Self {
203 Self { model, stats: None }
204 }
205
206 pub fn from_stats(stats: LearnStats) -> Self {
208 let model = ScoreModel::from_stats(&stats);
209 Self {
210 model,
211 stats: Some(stats),
212 }
213 }
214
215 pub fn inner(&self) -> &ScoreModel {
217 &self.model
218 }
219
220 pub fn update_model(&mut self, model: ScoreModel) {
222 self.model = model;
223 }
224}
225
226impl LearnedProvider for ScoreModelProvider {
227 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
228 match q {
229 LearningQuery::Transition {
230 prev,
231 action,
232 target,
233 } => match self.model.transition(prev, action, target) {
234 Some(score) => LearningResult::Score(score),
235 None => LearningResult::NotAvailable,
236 },
237
238 LearningQuery::Contextual {
239 prev,
240 action,
241 target,
242 } => match self.model.contextual(prev, action, target) {
243 Some(score) => LearningResult::Score(score),
244 None => LearningResult::NotAvailable,
245 },
246
247 LearningQuery::Ngram {
248 prev_prev,
249 prev,
250 action,
251 target,
252 } => match self.model.ngram(prev_prev, prev, action, target) {
253 Some(score) => LearningResult::Score(score),
254 None => LearningResult::NotAvailable,
255 },
256
257 LearningQuery::Confidence {
258 action,
259 target,
260 prev,
261 prev_prev,
262 } => match self.model.confidence(action, target, prev, prev_prev) {
263 Some(score) => LearningResult::Score(score),
264 None => LearningResult::NotAvailable,
265 },
266 }
267 }
268
269 fn stats(&self) -> Option<&LearnStats> {
270 self.stats.as_ref()
271 }
272
273 fn model(&self) -> Option<&ScoreModel> {
274 Some(&self.model)
275 }
276}
277
278#[derive(Debug, Clone, Default)]
284pub struct NullProvider;
285
286impl LearnedProvider for NullProvider {
287 fn query(&self, _q: LearningQuery<'_>) -> LearningResult {
288 LearningResult::NotAvailable
289 }
290}
291
292#[derive(Debug, Clone, Default)]
300pub struct ConfidenceMapProvider {
301 confidence: HashMap<String, f64>,
302}
303
304impl ConfidenceMapProvider {
305 pub fn new(confidence: HashMap<String, f64>) -> Self {
306 Self { confidence }
307 }
308
309 pub fn get(&self, action: &str) -> Option<f64> {
311 self.confidence.get(action).copied()
312 }
313}
314
315impl LearnedProvider for ConfidenceMapProvider {
316 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
317 match q {
318 LearningQuery::Confidence { action, .. } => {
319 match self.get(action) {
320 Some(c) => {
321 LearningResult::Score(c - 0.5)
323 }
324 None => LearningResult::NotAvailable,
325 }
326 }
327
328 _ => LearningResult::NotAvailable,
330 }
331 }
332}
333
334pub struct LearnStatsProvider {
343 stats: LearnStats,
344 model: ScoreModel,
345}
346
347impl LearnStatsProvider {
348 pub fn new(stats: LearnStats) -> Self {
349 let model = ScoreModel::from_stats(&stats);
350 Self { stats, model }
351 }
352
353 pub fn stats(&self) -> &LearnStats {
355 &self.stats
356 }
357
358 pub fn model(&self) -> &ScoreModel {
360 &self.model
361 }
362
363 pub fn update_stats<F>(&mut self, f: F)
368 where
369 F: FnOnce(&mut LearnStats),
370 {
371 f(&mut self.stats);
372 self.model = ScoreModel::from_stats(&self.stats);
373 }
374
375 pub fn replace_stats(&mut self, stats: LearnStats) {
377 self.stats = stats;
378 self.model = ScoreModel::from_stats(&self.stats);
379 }
380}
381
382impl LearnedProvider for LearnStatsProvider {
383 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
384 match q {
385 LearningQuery::Transition {
386 prev,
387 action,
388 target,
389 } => match self.model.transition(prev, action, target) {
390 Some(score) => LearningResult::Score(score),
391 None => LearningResult::NotAvailable,
392 },
393
394 LearningQuery::Contextual {
395 prev,
396 action,
397 target,
398 } => match self.model.contextual(prev, action, target) {
399 Some(score) => LearningResult::Score(score),
400 None => LearningResult::NotAvailable,
401 },
402
403 LearningQuery::Ngram {
404 prev_prev,
405 prev,
406 action,
407 target,
408 } => match self.model.ngram(prev_prev, prev, action, target) {
409 Some(score) => LearningResult::Score(score),
410 None => LearningResult::NotAvailable,
411 },
412
413 LearningQuery::Confidence {
414 action,
415 target,
416 prev,
417 prev_prev,
418 } => match self.model.confidence(action, target, prev, prev_prev) {
419 Some(score) => LearningResult::Score(score),
420 None => LearningResult::NotAvailable,
421 },
422 }
423 }
424
425 fn stats(&self) -> Option<&LearnStats> {
426 Some(&self.stats)
427 }
428
429 fn model(&self) -> Option<&ScoreModel> {
430 Some(&self.model)
431 }
432}
433
// Unit tests for the query/result types and the providers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_result_score_or() {
        assert_eq!(LearningResult::Score(0.5).score_or(0.0), 0.5);
        assert_eq!(LearningResult::NotAvailable.score_or(0.0), 0.0);
        assert_eq!(LearningResult::NotAvailable.score_or(-1.0), -1.0);
    }

    #[test]
    fn test_learning_result_score_and_default() {
        // `score()` falls back to 0.0, and the default result is "no data".
        assert_eq!(LearningResult::Score(-0.25).score(), -0.25);
        assert_eq!(LearningResult::NotAvailable.score(), 0.0);
        assert_eq!(LearningResult::default(), LearningResult::NotAvailable);
    }

    #[test]
    fn test_learning_result_is_available() {
        assert!(LearningResult::Score(0.5).is_available());
        assert!(!LearningResult::NotAvailable.is_available());
    }

    #[test]
    fn test_null_provider() {
        let provider = NullProvider;

        assert_eq!(
            provider.query(LearningQuery::transition("A", "B", None)),
            LearningResult::NotAvailable
        );
        assert_eq!(
            provider.query(LearningQuery::contextual("A", "B", Some("svc1"))),
            LearningResult::NotAvailable
        );
        assert_eq!(
            provider.query(LearningQuery::ngram("A", "B", "C", None)),
            LearningResult::NotAvailable
        );
        assert_eq!(
            provider.query(LearningQuery::confidence("B", None)),
            LearningResult::NotAvailable
        );

        // NullProvider relies on the trait's default accessors.
        assert!(provider.stats().is_none());
        assert!(provider.model().is_none());
    }

    #[test]
    fn test_confidence_map_provider() {
        let mut map = HashMap::new();
        map.insert("grep".to_string(), 0.8);
        map.insert("restart".to_string(), 0.3);

        let provider = ConfidenceMapProvider::new(map);

        // Confidences are shifted by -0.5 so that 0.5 is neutral.
        let result = provider.query(LearningQuery::confidence("grep", None));
        let score = result.score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);

        let result = provider.query(LearningQuery::confidence("restart", None));
        let score = result.score();
        assert!(
            (score - (-0.2)).abs() < 1e-10,
            "expected ~-0.2, got {}",
            score
        );

        // Only the action is used for lookup; context fields are ignored.
        let result = provider.query(LearningQuery::confidence_with_context(
            "grep",
            Some("svc1"),
            Some("X"),
            None,
        ));
        let score = result.score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);

        let result = provider.query(LearningQuery::confidence("unknown", None));
        assert_eq!(result, LearningResult::NotAvailable);

        // Non-confidence queries are never answered.
        let result = provider.query(LearningQuery::transition("A", "B", None));
        assert_eq!(result, LearningResult::NotAvailable);

        assert!(provider.stats().is_none());
        assert!(provider.model().is_none());
    }

    #[test]
    fn test_learning_query_constructors() {
        let q = LearningQuery::transition("A", "B", Some("svc1"));
        assert!(matches!(
            q,
            LearningQuery::Transition {
                prev: "A",
                action: "B",
                target: Some("svc1")
            }
        ));

        let q = LearningQuery::contextual("A", "B", None);
        assert!(matches!(
            q,
            LearningQuery::Contextual {
                prev: "A",
                action: "B",
                target: None
            }
        ));

        let q = LearningQuery::ngram("A", "B", "C", None);
        assert!(matches!(
            q,
            LearningQuery::Ngram {
                prev_prev: "A",
                prev: "B",
                action: "C",
                target: None
            }
        ));

        let q = LearningQuery::confidence("action", Some("svc1"));
        assert!(matches!(
            q,
            LearningQuery::Confidence {
                action: "action",
                target: Some("svc1"),
                prev: None,
                prev_prev: None
            }
        ));

        let q =
            LearningQuery::confidence_with_context("action", None, Some("prev"), Some("prev_prev"));
        assert!(matches!(
            q,
            LearningQuery::Confidence {
                action: "action",
                prev: Some("prev"),
                prev_prev: Some("prev_prev"),
                ..
            }
        ));
    }

    #[test]
    fn test_score_model_provider() {
        use crate::learn::stats::{ContextualActionStats, LearnStats};

        let mut stats = LearnStats::default();

        stats
            .episode_transitions
            .success_transitions
            .insert(("A".to_string(), "B".to_string()), 10);
        stats
            .episode_transitions
            .failure_transitions
            .insert(("A".to_string(), "B".to_string()), 2);
        stats.contextual_stats.insert(
            ("A".to_string(), "B".to_string()),
            ContextualActionStats {
                visits: 12,
                successes: 10,
                failures: 2,
            },
        );
        stats
            .ngram_stats
            .trigrams
            .insert(("X".to_string(), "A".to_string(), "B".to_string()), (9, 1));

        let provider = ScoreModelProvider::from_stats(stats);

        // Built from stats: both trait accessors are populated.
        assert!(provider.stats().is_some());
        assert!(provider.model().is_some());

        let result = provider.query(LearningQuery::transition("A", "B", None));
        assert!(result.is_available());

        // A high success rate (10 of 12) should yield a positive score.
        let result = provider.query(LearningQuery::contextual("A", "B", None));
        assert!(result.is_available());
        assert!(result.score() > 0.0, "成功率が高いので正のスコア");

        let result = provider.query(LearningQuery::ngram("X", "A", "B", None));
        assert!(result.is_available());

        let result = provider.query(LearningQuery::confidence_with_context(
            "B",
            None,
            Some("A"),
            Some("X"),
        ));
        assert!(result.is_available());
    }
}