1use std::collections::HashMap;
12use std::sync::Arc;
13
14use super::stats::LearnStats;
15use super::stats_model::ScoreModel;
16
/// A borrowed question that can be put to a `LearnedProvider`.
///
/// Every variant holds only `&'a str` / `Option<&'a str>` fields, so the type
/// is cheap to copy. It is `Copy` because `LearnedProvider::query` takes the
/// query by value, which lets one query be handed to several providers without
/// an explicit `clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LearningQuery<'a> {
    /// Lookup keyed by `prev`, `action` and an optional `target`.
    Transition {
        /// Name of the preceding step.
        prev: &'a str,
        /// Name of the action being scored.
        action: &'a str,
        /// Optional target the action applies to.
        target: Option<&'a str>,
    },

    /// Lookup with the same key shape as [`LearningQuery::Transition`], answered by the
    /// provider's contextual statistics instead of its transition statistics.
    Contextual {
        /// Name of the preceding step.
        prev: &'a str,
        /// Name of the action being scored.
        action: &'a str,
        /// Optional target the action applies to.
        target: Option<&'a str>,
    },

    /// Lookup keyed by the two preceding steps plus the action.
    Ngram {
        /// Name of the step before `prev`.
        prev_prev: &'a str,
        /// Name of the preceding step.
        prev: &'a str,
        /// Name of the action being scored.
        action: &'a str,
        /// Optional target the action applies to.
        target: Option<&'a str>,
    },

    /// Confidence lookup for an action; all context fields are optional.
    Confidence {
        /// Name of the action being scored.
        action: &'a str,
        /// Optional target the action applies to.
        target: Option<&'a str>,
        /// Optional name of the preceding step.
        prev: Option<&'a str>,
        /// Optional name of the step before `prev`.
        prev_prev: Option<&'a str>,
    },
}

impl<'a> LearningQuery<'a> {
    /// Builds a [`LearningQuery::Transition`] query.
    pub fn transition(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Transition {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Contextual`] query.
    pub fn contextual(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Contextual {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Ngram`] query.
    pub fn ngram(
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    ) -> Self {
        Self::Ngram {
            prev_prev,
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Confidence`] query without any context
    /// (`prev` and `prev_prev` are both `None`).
    pub fn confidence(action: &'a str, target: Option<&'a str>) -> Self {
        // Single construction path: the context-free form is just the
        // contextual form with both context slots empty.
        Self::confidence_with_context(action, target, None, None)
    }

    /// Builds a [`LearningQuery::Confidence`] query with optional context.
    pub fn confidence_with_context(
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    ) -> Self {
        Self::Confidence {
            action,
            target,
            prev,
            prev_prev,
        }
    }
}
127
/// Outcome of a learned-score lookup.
///
/// The type wraps at most one `f64`, so it is `Copy`. Its default value is
/// [`LearningResult::NotAvailable`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LearningResult {
    /// A score was found.
    Score(f64),
    /// The provider has no answer for the query.
    #[default]
    NotAvailable,
}

impl LearningResult {
    /// Returns the contained score, or `default` when no score is available.
    pub fn score_or(&self, default: f64) -> f64 {
        if let Self::Score(value) = *self {
            value
        } else {
            default
        }
    }

    /// Returns the contained score, or `0.0` when no score is available.
    pub fn score(&self) -> f64 {
        self.score_or(0.0)
    }

    /// Returns `true` when this result carries a score.
    pub fn is_available(&self) -> bool {
        matches!(self, Self::Score(_))
    }
}
162
163
164pub trait LearnedProvider: Send + Sync {
173 fn query(&self, q: LearningQuery<'_>) -> LearningResult;
175
176 fn stats(&self) -> Option<&LearnStats> {
178 None
179 }
180
181 fn model(&self) -> Option<&ScoreModel> {
183 None
184 }
185}
186
187pub type SharedLearnedProvider = Arc<dyn LearnedProvider>;
189
190pub struct ScoreModelProvider {
198 model: ScoreModel,
199 stats: Option<LearnStats>,
200}
201
202impl ScoreModelProvider {
203 pub fn new(model: ScoreModel) -> Self {
205 Self { model, stats: None }
206 }
207
208 pub fn from_stats(stats: LearnStats) -> Self {
210 let model = ScoreModel::from_stats(&stats);
211 Self {
212 model,
213 stats: Some(stats),
214 }
215 }
216
217 pub fn inner(&self) -> &ScoreModel {
219 &self.model
220 }
221
222 pub fn update_model(&mut self, model: ScoreModel) {
224 self.model = model;
225 }
226}
227
228impl LearnedProvider for ScoreModelProvider {
229 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
230 match q {
231 LearningQuery::Transition {
232 prev,
233 action,
234 target,
235 } => match self.model.transition(prev, action, target) {
236 Some(score) => LearningResult::Score(score),
237 None => LearningResult::NotAvailable,
238 },
239
240 LearningQuery::Contextual {
241 prev,
242 action,
243 target,
244 } => match self.model.contextual(prev, action, target) {
245 Some(score) => LearningResult::Score(score),
246 None => LearningResult::NotAvailable,
247 },
248
249 LearningQuery::Ngram {
250 prev_prev,
251 prev,
252 action,
253 target,
254 } => match self.model.ngram(prev_prev, prev, action, target) {
255 Some(score) => LearningResult::Score(score),
256 None => LearningResult::NotAvailable,
257 },
258
259 LearningQuery::Confidence {
260 action,
261 target,
262 prev,
263 prev_prev,
264 } => match self.model.confidence(action, target, prev, prev_prev) {
265 Some(score) => LearningResult::Score(score),
266 None => LearningResult::NotAvailable,
267 },
268 }
269 }
270
271 fn stats(&self) -> Option<&LearnStats> {
272 self.stats.as_ref()
273 }
274
275 fn model(&self) -> Option<&ScoreModel> {
276 Some(&self.model)
277 }
278}
279
280#[derive(Debug, Clone, Default)]
286pub struct NullProvider;
287
288impl LearnedProvider for NullProvider {
289 fn query(&self, _q: LearningQuery<'_>) -> LearningResult {
290 LearningResult::NotAvailable
291 }
292}
293
/// Provider backed by a plain map from action name to confidence value.
#[derive(Debug, Clone, Default)]
pub struct ConfidenceMapProvider {
    /// Confidence value stored per action name.
    confidence: HashMap<String, f64>,
}

impl ConfidenceMapProvider {
    /// Creates a provider that serves the given map.
    pub fn new(confidence: HashMap<String, f64>) -> Self {
        Self { confidence }
    }

    /// Returns the raw confidence stored for `action`, or `None` when the
    /// action is not in the map.
    pub fn get(&self, action: &str) -> Option<f64> {
        let stored = self.confidence.get(action)?;
        Some(*stored)
    }
}
316
317impl LearnedProvider for ConfidenceMapProvider {
318 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
319 match q {
320 LearningQuery::Confidence { action, .. } => {
321 match self.get(action) {
322 Some(c) => {
323 LearningResult::Score(c - 0.5)
325 }
326 None => LearningResult::NotAvailable,
327 }
328 }
329
330 _ => LearningResult::NotAvailable,
332 }
333 }
334}
335
336pub struct LearnStatsProvider {
345 stats: LearnStats,
346 model: ScoreModel,
347}
348
349impl LearnStatsProvider {
350 pub fn new(stats: LearnStats) -> Self {
351 let model = ScoreModel::from_stats(&stats);
352 Self { stats, model }
353 }
354
355 pub fn stats(&self) -> &LearnStats {
357 &self.stats
358 }
359
360 pub fn model(&self) -> &ScoreModel {
362 &self.model
363 }
364
365 pub fn update_stats<F>(&mut self, f: F)
370 where
371 F: FnOnce(&mut LearnStats),
372 {
373 f(&mut self.stats);
374 self.model = ScoreModel::from_stats(&self.stats);
375 }
376
377 pub fn replace_stats(&mut self, stats: LearnStats) {
379 self.stats = stats;
380 self.model = ScoreModel::from_stats(&self.stats);
381 }
382}
383
384impl LearnedProvider for LearnStatsProvider {
385 fn query(&self, q: LearningQuery<'_>) -> LearningResult {
386 match q {
387 LearningQuery::Transition {
388 prev,
389 action,
390 target,
391 } => match self.model.transition(prev, action, target) {
392 Some(score) => LearningResult::Score(score),
393 None => LearningResult::NotAvailable,
394 },
395
396 LearningQuery::Contextual {
397 prev,
398 action,
399 target,
400 } => match self.model.contextual(prev, action, target) {
401 Some(score) => LearningResult::Score(score),
402 None => LearningResult::NotAvailable,
403 },
404
405 LearningQuery::Ngram {
406 prev_prev,
407 prev,
408 action,
409 target,
410 } => match self.model.ngram(prev_prev, prev, action, target) {
411 Some(score) => LearningResult::Score(score),
412 None => LearningResult::NotAvailable,
413 },
414
415 LearningQuery::Confidence {
416 action,
417 target,
418 prev,
419 prev_prev,
420 } => match self.model.confidence(action, target, prev, prev_prev) {
421 Some(score) => LearningResult::Score(score),
422 None => LearningResult::NotAvailable,
423 },
424 }
425 }
426
427 fn stats(&self) -> Option<&LearnStats> {
428 Some(&self.stats)
429 }
430
431 fn model(&self) -> Option<&ScoreModel> {
432 Some(&self.model)
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// `score_or` yields the wrapped score, or the fallback when there is none.
    #[test]
    fn test_learning_result_score_or() {
        let present = LearningResult::Score(0.5);
        let missing = LearningResult::NotAvailable;

        assert_eq!(present.score_or(0.0), 0.5);
        assert_eq!(missing.score_or(0.0), 0.0);
        assert_eq!(missing.score_or(-1.0), -1.0);
    }

    /// `is_available` is true only for `Score`.
    #[test]
    fn test_learning_result_is_available() {
        let present = LearningResult::Score(0.5);
        let missing = LearningResult::NotAvailable;

        assert!(present.is_available());
        assert!(!missing.is_available());
    }

    /// The null provider answers every query kind with `NotAvailable`.
    #[test]
    fn test_null_provider() {
        let provider = NullProvider;

        for query in [
            LearningQuery::transition("A", "B", None),
            LearningQuery::contextual("A", "B", Some("svc1")),
            LearningQuery::ngram("A", "B", "C", None),
        ] {
            assert_eq!(provider.query(query), LearningResult::NotAvailable);
        }
    }

    /// Confidence queries return the stored value minus 0.5; everything else
    /// is unavailable.
    #[test]
    fn test_confidence_map_provider() {
        let provider = ConfidenceMapProvider::new(HashMap::from([
            ("grep".to_string(), 0.8),
            ("restart".to_string(), 0.3),
        ]));

        // Known action with confidence above 0.5.
        let score = provider
            .query(LearningQuery::confidence("grep", None))
            .score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);

        // Known action with confidence below 0.5.
        let score = provider
            .query(LearningQuery::confidence("restart", None))
            .score();
        assert!(
            (score - (-0.2)).abs() < 1e-10,
            "expected ~-0.2, got {}",
            score
        );

        // Unknown action.
        let unknown = provider.query(LearningQuery::confidence("unknown", None));
        assert_eq!(unknown, LearningResult::NotAvailable);

        // Query kind other than confidence.
        let other_kind = provider.query(LearningQuery::transition("A", "B", None));
        assert_eq!(other_kind, LearningResult::NotAvailable);
    }

    /// Constructors place their arguments in the expected variant fields.
    #[test]
    fn test_learning_query_constructors() {
        let transition = LearningQuery::transition("A", "B", Some("svc1"));
        assert!(matches!(
            transition,
            LearningQuery::Transition {
                prev: "A",
                action: "B",
                target: Some("svc1")
            }
        ));

        let ngram = LearningQuery::ngram("A", "B", "C", None);
        assert!(matches!(
            ngram,
            LearningQuery::Ngram {
                prev_prev: "A",
                prev: "B",
                action: "C",
                target: None
            }
        ));

        let confidence =
            LearningQuery::confidence_with_context("action", None, Some("prev"), Some("prev_prev"));
        assert!(matches!(
            confidence,
            LearningQuery::Confidence {
                action: "action",
                prev: Some("prev"),
                prev_prev: Some("prev_prev"),
                ..
            }
        ));
    }

    /// A provider built from populated statistics answers every query kind.
    #[test]
    fn test_score_model_provider() {
        use crate::learn::stats::{ContextualActionStats, LearnStats};

        let a_to_b = ("A".to_string(), "B".to_string());
        let mut stats = LearnStats::default();

        // Transition counts for A -> B: 10 successes, 2 failures.
        stats
            .episode_transitions
            .success_transitions
            .insert(a_to_b.clone(), 10);
        stats
            .episode_transitions
            .failure_transitions
            .insert(a_to_b.clone(), 2);

        // Matching contextual record for the same pair.
        stats.contextual_stats.insert(
            a_to_b,
            ContextualActionStats {
                visits: 12,
                successes: 10,
                failures: 2,
            },
        );

        // One trigram X -> A -> B.
        stats
            .ngram_stats
            .trigrams
            .insert(("X".to_string(), "A".to_string(), "B".to_string()), (9, 1));

        let provider = ScoreModelProvider::from_stats(stats);

        let transition = provider.query(LearningQuery::transition("A", "B", None));
        assert!(transition.is_available());

        let contextual = provider.query(LearningQuery::contextual("A", "B", None));
        assert!(contextual.is_available());
        assert!(contextual.score() > 0.0, "成功率が高いので正のスコア");

        let ngram = provider.query(LearningQuery::ngram("X", "A", "B", None));
        assert!(ngram.is_available());

        let confidence = provider.query(LearningQuery::confidence_with_context(
            "B",
            None,
            Some("A"),
            Some("X"),
        ));
        assert!(confidence.is_available());
    }
}