Skip to main content

swarm_engine_core/learn/
provider.rs

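//! LearnedProvider - abstract access to learned data
//!
//! An abstraction layer through which the whole Swarm accesses learned data.
//! It can be used from many components, such as Selection, Orchestrator and Worker.
//!
//! # Design
//!
//! - **Model**: owns the computation logic; it precomputes and holds the scores.
//! - **Provider**: only returns the Model's results (it has no computation logic of its own).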
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use super::stats::LearnStats;
15use super::stats_model::ScoreModel;
16
// ============================================================================
// LearningQuery - queries against learned data
// ============================================================================

/// A query against learned data.
///
/// Every variant carries an optional `target`, so all contexts (including the
/// target) are handled uniformly.
///
/// Queries are plain borrowed data and can be compared for equality (useful
/// in tests and when de-duplicating lookups).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LearningQuery<'a> {
    /// Transition bonus (`prev` -> `action` @ `target`).
    ///
    /// Transitions that are frequent in successful episodes should score
    /// positive; those frequent in failed episodes should score negative.
    Transition {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// Context-conditioned bonus.
    ///
    /// Success-rate correction for `action` when it follows `prev`.
    Contextual {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// N-gram bonus (value of the 3-gram `prev_prev` -> `prev` -> `action`).
    Ngram {
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// Confidence score (learning bonus with optional context).
    ///
    /// By design, when `prev` / `prev_prev` are present the Transition,
    /// Contextual and Ngram signals are combined; without them the
    /// per-action average is expected. The combination itself is done by
    /// the provider's model, not here.
    Confidence {
        action: &'a str,
        target: Option<&'a str>,
        /// The immediately preceding action (optional).
        prev: Option<&'a str>,
        /// The action two steps back (optional, used for N-grams).
        prev_prev: Option<&'a str>,
    },
}

impl<'a> LearningQuery<'a> {
    /// Builds a [`LearningQuery::Transition`] query.
    pub fn transition(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Transition {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Contextual`] query.
    pub fn contextual(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Contextual {
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Ngram`] query.
    pub fn ngram(
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    ) -> Self {
        Self::Ngram {
            prev_prev,
            prev,
            action,
            target,
        }
    }

    /// Builds a [`LearningQuery::Confidence`] query without context.
    ///
    /// Equivalent to `confidence_with_context(action, target, None, None)`.
    pub fn confidence(action: &'a str, target: Option<&'a str>) -> Self {
        // Single construction path for Confidence queries.
        Self::confidence_with_context(action, target, None, None)
    }

    /// Builds a [`LearningQuery::Confidence`] query with optional context.
    ///
    /// Passing `prev` / `prev_prev` asks the provider to combine the
    /// Transition, Contextual and Ngram signals.
    pub fn confidence_with_context(
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    ) -> Self {
        Self::Confidence {
            action,
            target,
            prev,
            prev_prev,
        }
    }
}

// ============================================================================
// LearningResult - query results
// ============================================================================

/// Result of a learned-data query.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LearningResult {
    /// A score value.
    Score(f64),
    /// No data available (the caller should fall back to its prior).
    #[default]
    NotAvailable,
}

impl LearningResult {
    /// Returns the score, or `default` when no data is available.
    pub fn score_or(&self, default: f64) -> f64 {
        match self {
            Self::Score(v) => *v,
            Self::NotAvailable => default,
        }
    }

    /// Returns the score, or `0.0` when no data is available.
    pub fn score(&self) -> f64 {
        self.score_or(0.0)
    }

    /// Returns `true` if a score is present.
    pub fn is_available(&self) -> bool {
        matches!(self, Self::Score(_))
    }
}

impl From<Option<f64>> for LearningResult {
    /// Maps `Some(score)` to [`LearningResult::Score`] and `None` to
    /// [`LearningResult::NotAvailable`].
    fn from(value: Option<f64>) -> Self {
        value.map_or(Self::NotAvailable, Self::Score)
    }
}

164// ============================================================================
165// LearnedProvider trait
166// ============================================================================
167
168/// 学習済みデータへのアクセス Provider
169///
170/// Swarm の各コンポーネントから利用される統一的なインターフェース。
171/// 計算ロジックは持たない。Model から結果を取得して返すだけ。
172pub trait LearnedProvider: Send + Sync {
173    /// クエリを実行してボーナス/スコアを取得
174    fn query(&self, q: LearningQuery<'_>) -> LearningResult;
175
176    /// 内部の LearnStats を取得(永続化用、実装がある場合のみ)
177    fn stats(&self) -> Option<&LearnStats> {
178        None
179    }
180
181    /// 内部の ScoreModel を取得(実装がある場合のみ)
182    fn model(&self) -> Option<&ScoreModel> {
183        None
184    }
185}
186
187/// Provider の共有参照型
188pub type SharedLearnedProvider = Arc<dyn LearnedProvider>;
189
190// ============================================================================
191// ScoreModelProvider - Model ベースの Provider
192// ============================================================================
193
194/// ScoreModel を使った Provider 実装
195///
196/// Model が事前計算したスコアを返すだけ。計算ロジックを持たない。
197pub struct ScoreModelProvider {
198    model: ScoreModel,
199    stats: Option<LearnStats>,
200}
201
202impl ScoreModelProvider {
203    /// Model から Provider を作成
204    pub fn new(model: ScoreModel) -> Self {
205        Self { model, stats: None }
206    }
207
208    /// LearnStats から Provider を作成(Model を自動構築)
209    pub fn from_stats(stats: LearnStats) -> Self {
210        let model = ScoreModel::from_stats(&stats);
211        Self {
212            model,
213            stats: Some(stats),
214        }
215    }
216
217    /// Model を取得
218    pub fn inner(&self) -> &ScoreModel {
219        &self.model
220    }
221
222    /// Model を更新
223    pub fn update_model(&mut self, model: ScoreModel) {
224        self.model = model;
225    }
226}
227
228impl LearnedProvider for ScoreModelProvider {
229    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
230        match q {
231            LearningQuery::Transition {
232                prev,
233                action,
234                target,
235            } => match self.model.transition(prev, action, target) {
236                Some(score) => LearningResult::Score(score),
237                None => LearningResult::NotAvailable,
238            },
239
240            LearningQuery::Contextual {
241                prev,
242                action,
243                target,
244            } => match self.model.contextual(prev, action, target) {
245                Some(score) => LearningResult::Score(score),
246                None => LearningResult::NotAvailable,
247            },
248
249            LearningQuery::Ngram {
250                prev_prev,
251                prev,
252                action,
253                target,
254            } => match self.model.ngram(prev_prev, prev, action, target) {
255                Some(score) => LearningResult::Score(score),
256                None => LearningResult::NotAvailable,
257            },
258
259            LearningQuery::Confidence {
260                action,
261                target,
262                prev,
263                prev_prev,
264            } => match self.model.confidence(action, target, prev, prev_prev) {
265                Some(score) => LearningResult::Score(score),
266                None => LearningResult::NotAvailable,
267            },
268        }
269    }
270
271    fn stats(&self) -> Option<&LearnStats> {
272        self.stats.as_ref()
273    }
274
275    fn model(&self) -> Option<&ScoreModel> {
276        Some(&self.model)
277    }
278}
279
280// ============================================================================
281// NullProvider
282// ============================================================================
283
284/// ボーナスを返さない Null Provider
285#[derive(Debug, Clone, Default)]
286pub struct NullProvider;
287
288impl LearnedProvider for NullProvider {
289    fn query(&self, _q: LearningQuery<'_>) -> LearningResult {
290        LearningResult::NotAvailable
291    }
292}
293
// ============================================================================
// ConfidenceMapProvider
// ============================================================================

/// Static provider backed by an action -> confidence `HashMap<String, f64>`.
///
/// Intended for the `confidence_map` generated from a `DependencyGraph`.
#[derive(Debug, Clone, Default)]
pub struct ConfidenceMapProvider {
    /// Raw confidence per action name.
    confidence: HashMap<String, f64>,
}

impl ConfidenceMapProvider {
    /// Creates a provider from an action -> confidence map.
    pub fn new(confidence: HashMap<String, f64>) -> Self {
        Self { confidence }
    }

    /// Returns the raw confidence stored for `action`, or `None` if absent.
    pub fn get(&self, action: &str) -> Option<f64> {
        let value = self.confidence.get(action)?;
        Some(*value)
    }
}

317impl LearnedProvider for ConfidenceMapProvider {
318    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
319        match q {
320            LearningQuery::Confidence { action, .. } => {
321                match self.get(action) {
322                    Some(c) => {
323                        // confidence - 0.5 を返して、0.5 を中立点とする
324                        LearningResult::Score(c - 0.5)
325                    }
326                    None => LearningResult::NotAvailable,
327                }
328            }
329
330            // ConfidenceMapProvider は他のクエリには対応しない
331            _ => LearningResult::NotAvailable,
332        }
333    }
334}
335
336// ============================================================================
337// LearnStatsProvider - Stats + Model を一体管理
338// ============================================================================
339
340/// LearnStats と ScoreModel を一体管理する Provider
341///
342/// Stats を更新すると自動的に Model も再構築される。
343/// `rebuild_model()` の呼び忘れを防止。
344pub struct LearnStatsProvider {
345    stats: LearnStats,
346    model: ScoreModel,
347}
348
349impl LearnStatsProvider {
350    pub fn new(stats: LearnStats) -> Self {
351        let model = ScoreModel::from_stats(&stats);
352        Self { stats, model }
353    }
354
355    /// Stats を取得(読み取り専用)
356    pub fn stats(&self) -> &LearnStats {
357        &self.stats
358    }
359
360    /// Model を取得
361    pub fn model(&self) -> &ScoreModel {
362        &self.model
363    }
364
365    /// Stats を更新し、Model を自動再構築
366    ///
367    /// Stats を直接変更したい場合はこのメソッドを使用。
368    /// Model の再構築忘れを防止。
369    pub fn update_stats<F>(&mut self, f: F)
370    where
371        F: FnOnce(&mut LearnStats),
372    {
373        f(&mut self.stats);
374        self.model = ScoreModel::from_stats(&self.stats);
375    }
376
377    /// Stats を置換し、Model を自動再構築
378    pub fn replace_stats(&mut self, stats: LearnStats) {
379        self.stats = stats;
380        self.model = ScoreModel::from_stats(&self.stats);
381    }
382}
383
384impl LearnedProvider for LearnStatsProvider {
385    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
386        match q {
387            LearningQuery::Transition {
388                prev,
389                action,
390                target,
391            } => match self.model.transition(prev, action, target) {
392                Some(score) => LearningResult::Score(score),
393                None => LearningResult::NotAvailable,
394            },
395
396            LearningQuery::Contextual {
397                prev,
398                action,
399                target,
400            } => match self.model.contextual(prev, action, target) {
401                Some(score) => LearningResult::Score(score),
402                None => LearningResult::NotAvailable,
403            },
404
405            LearningQuery::Ngram {
406                prev_prev,
407                prev,
408                action,
409                target,
410            } => match self.model.ngram(prev_prev, prev, action, target) {
411                Some(score) => LearningResult::Score(score),
412                None => LearningResult::NotAvailable,
413            },
414
415            LearningQuery::Confidence {
416                action,
417                target,
418                prev,
419                prev_prev,
420            } => match self.model.confidence(action, target, prev, prev_prev) {
421                Some(score) => LearningResult::Score(score),
422                None => LearningResult::NotAvailable,
423            },
424        }
425    }
426
427    fn stats(&self) -> Option<&LearnStats> {
428        Some(&self.stats)
429    }
430
431    fn model(&self) -> Option<&ScoreModel> {
432        Some(&self.model)
433    }
434}
435
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Stats where the `A -> B` transition mostly succeeds, plus an
    /// `X -> A -> B` trigram, so every query kind has data to work with.
    fn sample_stats() -> LearnStats {
        use crate::learn::stats::ContextualActionStats;

        let mut stats = LearnStats::default();
        stats
            .episode_transitions
            .success_transitions
            .insert(("A".to_string(), "B".to_string()), 10);
        stats
            .episode_transitions
            .failure_transitions
            .insert(("A".to_string(), "B".to_string()), 2);
        stats.contextual_stats.insert(
            ("A".to_string(), "B".to_string()),
            ContextualActionStats {
                visits: 12,
                successes: 10,
                failures: 2,
            },
        );
        stats
            .ngram_stats
            .trigrams
            .insert(("X".to_string(), "A".to_string(), "B".to_string()), (9, 1));
        stats
    }

    #[test]
    fn test_learning_result_score_or() {
        assert_eq!(LearningResult::Score(0.5).score_or(0.0), 0.5);
        assert_eq!(LearningResult::NotAvailable.score_or(0.0), 0.0);
        assert_eq!(LearningResult::NotAvailable.score_or(-1.0), -1.0);
        assert_eq!(LearningResult::NotAvailable.score(), 0.0);
    }

    #[test]
    fn test_learning_result_is_available() {
        assert!(LearningResult::Score(0.5).is_available());
        assert!(!LearningResult::NotAvailable.is_available());
        assert!(!LearningResult::default().is_available());
    }

    #[test]
    fn test_null_provider() {
        let provider = NullProvider;

        let queries = [
            LearningQuery::transition("A", "B", None),
            LearningQuery::contextual("A", "B", Some("svc1")),
            LearningQuery::ngram("A", "B", "C", None),
            LearningQuery::confidence("B", None),
            LearningQuery::confidence_with_context("B", None, Some("A"), None),
        ];
        for q in queries {
            assert_eq!(provider.query(q), LearningResult::NotAvailable);
        }

        assert!(provider.stats().is_none());
        assert!(provider.model().is_none());
    }

    #[test]
    fn test_confidence_map_provider() {
        let mut map = HashMap::new();
        map.insert("grep".to_string(), 0.8);
        map.insert("restart".to_string(), 0.3);

        let provider = ConfidenceMapProvider::new(map);

        // Confidence queries are supported (0.5 is the neutral point).
        let score = provider.query(LearningQuery::confidence("grep", None)).score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score); // 0.8 - 0.5

        let score = provider
            .query(LearningQuery::confidence("restart", None))
            .score();
        assert!(
            (score - (-0.2)).abs() < 1e-10,
            "expected ~-0.2, got {}",
            score
        ); // 0.3 - 0.5

        // Query context is ignored: the map is keyed by action only.
        let score = provider
            .query(LearningQuery::confidence_with_context(
                "grep",
                Some("svc1"),
                Some("A"),
                Some("X"),
            ))
            .score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);

        // Unknown action.
        let result = provider.query(LearningQuery::confidence("unknown", None));
        assert_eq!(result, LearningResult::NotAvailable);

        // No other query kind is supported.
        let others = [
            LearningQuery::transition("A", "B", None),
            LearningQuery::contextual("A", "grep", None),
            LearningQuery::ngram("X", "A", "grep", None),
        ];
        for q in others {
            assert_eq!(provider.query(q), LearningResult::NotAvailable);
        }
    }

    #[test]
    fn test_learning_query_constructors() {
        let q = LearningQuery::transition("A", "B", Some("svc1"));
        assert!(matches!(
            q,
            LearningQuery::Transition {
                prev: "A",
                action: "B",
                target: Some("svc1")
            }
        ));

        let q = LearningQuery::contextual("A", "B", None);
        assert!(matches!(
            q,
            LearningQuery::Contextual {
                prev: "A",
                action: "B",
                target: None
            }
        ));

        let q = LearningQuery::ngram("A", "B", "C", None);
        assert!(matches!(
            q,
            LearningQuery::Ngram {
                prev_prev: "A",
                prev: "B",
                action: "C",
                target: None
            }
        ));

        let q = LearningQuery::confidence("action", Some("svc1"));
        assert!(matches!(
            q,
            LearningQuery::Confidence {
                action: "action",
                target: Some("svc1"),
                prev: None,
                prev_prev: None
            }
        ));

        let q =
            LearningQuery::confidence_with_context("action", None, Some("prev"), Some("prev_prev"));
        assert!(matches!(
            q,
            LearningQuery::Confidence {
                action: "action",
                prev: Some("prev"),
                prev_prev: Some("prev_prev"),
                ..
            }
        ));
    }

    #[test]
    fn test_score_model_provider() {
        let provider = ScoreModelProvider::from_stats(sample_stats());

        // Transition query
        let result = provider.query(LearningQuery::transition("A", "B", None));
        assert!(result.is_available());

        // Contextual query
        let result = provider.query(LearningQuery::contextual("A", "B", None));
        assert!(result.is_available());
        assert!(result.score() > 0.0, "high success rate should give a positive score");

        // Ngram query
        let result = provider.query(LearningQuery::ngram("X", "A", "B", None));
        assert!(result.is_available());

        // Confidence query (with context)
        let result = provider.query(LearningQuery::confidence_with_context(
            "B",
            None,
            Some("A"),
            Some("X"),
        ));
        assert!(result.is_available());

        // from_stats keeps the stats and exposes the model.
        assert!(provider.stats().is_some());
        assert!(provider.model().is_some());
    }

    #[test]
    fn test_score_model_provider_new_has_no_stats() {
        let provider = ScoreModelProvider::new(ScoreModel::from_stats(&sample_stats()));

        assert!(provider.stats().is_none());
        assert!(provider.model().is_some());
        assert!(provider
            .query(LearningQuery::transition("A", "B", None))
            .is_available());
    }

    #[test]
    fn test_learn_stats_provider_rebuilds_model() {
        // Built directly from stats.
        let provider = LearnStatsProvider::new(sample_stats());
        assert!(provider
            .query(LearningQuery::transition("A", "B", None))
            .is_available());
        assert!(LearnedProvider::stats(&provider).is_some());
        assert!(LearnedProvider::model(&provider).is_some());

        // update_stats must rebuild the model from the mutated stats.
        let mut provider = LearnStatsProvider::new(LearnStats::default());
        provider.update_stats(|stats| *stats = sample_stats());
        let result = provider.query(LearningQuery::contextual("A", "B", None));
        assert!(result.is_available());
        assert!(result.score() > 0.0);

        // replace_stats must do the same.
        let mut provider = LearnStatsProvider::new(LearnStats::default());
        provider.replace_stats(sample_stats());
        assert!(provider
            .query(LearningQuery::ngram("X", "A", "B", None))
            .is_available());
    }
}