Skip to main content

swarm_engine_core/learn/
provider.rs

1//! LearnedProvider - 学習データへの抽象アクセス
2//!
3//! Swarm 全体で学習済みデータにアクセスするための抽象層。
4//! Selection、Orchestrator、Worker など様々なコンポーネントから利用可能。
5//!
6//! # 設計思想
7//!
8//! - **Model**: 計算ロジックを持ち、スコアを事前計算して保持
9//! - **Provider**: Model の結果を返すだけ(計算ロジックを持たない)
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use super::stats::LearnStats;
15use super::stats_model::ScoreModel;
16
17// ============================================================================
18// LearningQuery - 学習データへのクエリ
19// ============================================================================
20
/// A query against learned data.
///
/// Every variant carries its complete context, including the optional
/// `target`, so all query kinds are handled uniformly.
///
/// The enum only stores borrowed string slices, so it is `Copy`: a single
/// query value can be passed (by value) to several providers without cloning,
/// and queries can be compared or used as hash keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LearningQuery<'a> {
    /// Transition bonus (`prev` → `action` @ `target`).
    ///
    /// Transitions that are frequent in successful episodes yield a positive
    /// value; transitions that are frequent in failed episodes yield a
    /// negative value.
    Transition {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// Context-conditioned bonus.
    ///
    /// Success-rate correction for `action` in combination with the previous
    /// action `prev`.
    Contextual {
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// N-gram bonus (value of a 3-gram pattern).
    Ngram {
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    },

    /// Confidence score (learned bonus, optionally with context).
    ///
    /// When `prev` / `prev_prev` are given, Transition + Contextual + Ngram
    /// are combined. Otherwise the average for `action` alone is returned.
    Confidence {
        action: &'a str,
        target: Option<&'a str>,
        /// The immediately preceding action (optional).
        prev: Option<&'a str>,
        /// The action two steps back (optional, used for the N-gram part).
        prev_prev: Option<&'a str>,
    },
}

impl<'a> LearningQuery<'a> {
    /// Creates a [`LearningQuery::Transition`] query.
    pub fn transition(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Transition {
            prev,
            action,
            target,
        }
    }

    /// Creates a [`LearningQuery::Contextual`] query.
    pub fn contextual(prev: &'a str, action: &'a str, target: Option<&'a str>) -> Self {
        Self::Contextual {
            prev,
            action,
            target,
        }
    }

    /// Creates a [`LearningQuery::Ngram`] query.
    pub fn ngram(
        prev_prev: &'a str,
        prev: &'a str,
        action: &'a str,
        target: Option<&'a str>,
    ) -> Self {
        Self::Ngram {
            prev_prev,
            prev,
            action,
            target,
        }
    }

    /// Creates a [`LearningQuery::Confidence`] query without any context
    /// (`prev` and `prev_prev` are both `None`).
    pub fn confidence(action: &'a str, target: Option<&'a str>) -> Self {
        // Delegate so the context-free form can never drift from the full one.
        Self::confidence_with_context(action, target, None, None)
    }

    /// Creates a [`LearningQuery::Confidence`] query with context.
    ///
    /// Passing `prev` / `prev_prev` requests the combined
    /// Transition + Contextual + Ngram computation.
    pub fn confidence_with_context(
        action: &'a str,
        target: Option<&'a str>,
        prev: Option<&'a str>,
        prev_prev: Option<&'a str>,
    ) -> Self {
        Self::Confidence {
            action,
            target,
            prev,
            prev_prev,
        }
    }
}
127
128// ============================================================================
129// LearningResult - クエリ結果
130// ============================================================================
131
/// Outcome of a learning query.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LearningResult {
    /// A score value.
    Score(f64),
    /// No data available (the caller should fall back to a prior).
    #[default]
    NotAvailable,
}

impl LearningResult {
    /// Returns the score, or `default` when the result is `NotAvailable`.
    pub fn score_or(&self, default: f64) -> f64 {
        if let Self::Score(value) = self {
            *value
        } else {
            default
        }
    }

    /// Returns the score, or `0.0` when the result is `NotAvailable`.
    pub fn score(&self) -> f64 {
        match self {
            Self::Score(value) => *value,
            Self::NotAvailable => 0.0,
        }
    }

    /// Reports whether the result carries a score.
    pub fn is_available(&self) -> bool {
        !matches!(self, Self::NotAvailable)
    }
}
161
162// ============================================================================
163// LearnedProvider trait
164// ============================================================================
165
166/// 学習済みデータへのアクセス Provider
167///
168/// Swarm の各コンポーネントから利用される統一的なインターフェース。
169/// 計算ロジックは持たない。Model から結果を取得して返すだけ。
170pub trait LearnedProvider: Send + Sync {
171    /// クエリを実行してボーナス/スコアを取得
172    fn query(&self, q: LearningQuery<'_>) -> LearningResult;
173
174    /// 内部の LearnStats を取得(永続化用、実装がある場合のみ)
175    fn stats(&self) -> Option<&LearnStats> {
176        None
177    }
178
179    /// 内部の ScoreModel を取得(実装がある場合のみ)
180    fn model(&self) -> Option<&ScoreModel> {
181        None
182    }
183}
184
185/// Provider の共有参照型
186pub type SharedLearnedProvider = Arc<dyn LearnedProvider>;
187
188// ============================================================================
189// ScoreModelProvider - Model ベースの Provider
190// ============================================================================
191
192/// ScoreModel を使った Provider 実装
193///
194/// Model が事前計算したスコアを返すだけ。計算ロジックを持たない。
195pub struct ScoreModelProvider {
196    model: ScoreModel,
197    stats: Option<LearnStats>,
198}
199
200impl ScoreModelProvider {
201    /// Model から Provider を作成
202    pub fn new(model: ScoreModel) -> Self {
203        Self { model, stats: None }
204    }
205
206    /// LearnStats から Provider を作成(Model を自動構築)
207    pub fn from_stats(stats: LearnStats) -> Self {
208        let model = ScoreModel::from_stats(&stats);
209        Self {
210            model,
211            stats: Some(stats),
212        }
213    }
214
215    /// Model を取得
216    pub fn inner(&self) -> &ScoreModel {
217        &self.model
218    }
219
220    /// Model を更新
221    pub fn update_model(&mut self, model: ScoreModel) {
222        self.model = model;
223    }
224}
225
226impl LearnedProvider for ScoreModelProvider {
227    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
228        match q {
229            LearningQuery::Transition {
230                prev,
231                action,
232                target,
233            } => match self.model.transition(prev, action, target) {
234                Some(score) => LearningResult::Score(score),
235                None => LearningResult::NotAvailable,
236            },
237
238            LearningQuery::Contextual {
239                prev,
240                action,
241                target,
242            } => match self.model.contextual(prev, action, target) {
243                Some(score) => LearningResult::Score(score),
244                None => LearningResult::NotAvailable,
245            },
246
247            LearningQuery::Ngram {
248                prev_prev,
249                prev,
250                action,
251                target,
252            } => match self.model.ngram(prev_prev, prev, action, target) {
253                Some(score) => LearningResult::Score(score),
254                None => LearningResult::NotAvailable,
255            },
256
257            LearningQuery::Confidence {
258                action,
259                target,
260                prev,
261                prev_prev,
262            } => match self.model.confidence(action, target, prev, prev_prev) {
263                Some(score) => LearningResult::Score(score),
264                None => LearningResult::NotAvailable,
265            },
266        }
267    }
268
269    fn stats(&self) -> Option<&LearnStats> {
270        self.stats.as_ref()
271    }
272
273    fn model(&self) -> Option<&ScoreModel> {
274        Some(&self.model)
275    }
276}
277
278// ============================================================================
279// NullProvider
280// ============================================================================
281
282/// ボーナスを返さない Null Provider
283#[derive(Debug, Clone, Default)]
284pub struct NullProvider;
285
286impl LearnedProvider for NullProvider {
287    fn query(&self, _q: LearningQuery<'_>) -> LearningResult {
288        LearningResult::NotAvailable
289    }
290}
291
292// ============================================================================
293// ConfidenceMapProvider
294// ============================================================================
295
/// Static provider backed by a `HashMap<String, f64>`.
///
/// Per the original note, the map is the `confidence_map` generated from a
/// DependencyGraph.
#[derive(Debug, Clone, Default)]
pub struct ConfidenceMapProvider {
    confidence: HashMap<String, f64>,
}

impl ConfidenceMapProvider {
    /// Creates a provider that serves values from `confidence`.
    pub fn new(confidence: HashMap<String, f64>) -> Self {
        Self { confidence }
    }

    /// Returns the stored confidence for `action`, or `None` if there is no
    /// entry for it.
    pub fn get(&self, action: &str) -> Option<f64> {
        let value = self.confidence.get(action)?;
        Some(*value)
    }
}
314
315impl LearnedProvider for ConfidenceMapProvider {
316    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
317        match q {
318            LearningQuery::Confidence { action, .. } => {
319                match self.get(action) {
320                    Some(c) => {
321                        // confidence - 0.5 を返して、0.5 を中立点とする
322                        LearningResult::Score(c - 0.5)
323                    }
324                    None => LearningResult::NotAvailable,
325                }
326            }
327
328            // ConfidenceMapProvider は他のクエリには対応しない
329            _ => LearningResult::NotAvailable,
330        }
331    }
332}
333
334// ============================================================================
335// LearnStatsProvider - Stats + Model を一体管理
336// ============================================================================
337
338/// LearnStats と ScoreModel を一体管理する Provider
339///
340/// Stats を更新すると自動的に Model も再構築される。
341/// `rebuild_model()` の呼び忘れを防止。
342pub struct LearnStatsProvider {
343    stats: LearnStats,
344    model: ScoreModel,
345}
346
347impl LearnStatsProvider {
348    pub fn new(stats: LearnStats) -> Self {
349        let model = ScoreModel::from_stats(&stats);
350        Self { stats, model }
351    }
352
353    /// Stats を取得(読み取り専用)
354    pub fn stats(&self) -> &LearnStats {
355        &self.stats
356    }
357
358    /// Model を取得
359    pub fn model(&self) -> &ScoreModel {
360        &self.model
361    }
362
363    /// Stats を更新し、Model を自動再構築
364    ///
365    /// Stats を直接変更したい場合はこのメソッドを使用。
366    /// Model の再構築忘れを防止。
367    pub fn update_stats<F>(&mut self, f: F)
368    where
369        F: FnOnce(&mut LearnStats),
370    {
371        f(&mut self.stats);
372        self.model = ScoreModel::from_stats(&self.stats);
373    }
374
375    /// Stats を置換し、Model を自動再構築
376    pub fn replace_stats(&mut self, stats: LearnStats) {
377        self.stats = stats;
378        self.model = ScoreModel::from_stats(&self.stats);
379    }
380}
381
382impl LearnedProvider for LearnStatsProvider {
383    fn query(&self, q: LearningQuery<'_>) -> LearningResult {
384        match q {
385            LearningQuery::Transition {
386                prev,
387                action,
388                target,
389            } => match self.model.transition(prev, action, target) {
390                Some(score) => LearningResult::Score(score),
391                None => LearningResult::NotAvailable,
392            },
393
394            LearningQuery::Contextual {
395                prev,
396                action,
397                target,
398            } => match self.model.contextual(prev, action, target) {
399                Some(score) => LearningResult::Score(score),
400                None => LearningResult::NotAvailable,
401            },
402
403            LearningQuery::Ngram {
404                prev_prev,
405                prev,
406                action,
407                target,
408            } => match self.model.ngram(prev_prev, prev, action, target) {
409                Some(score) => LearningResult::Score(score),
410                None => LearningResult::NotAvailable,
411            },
412
413            LearningQuery::Confidence {
414                action,
415                target,
416                prev,
417                prev_prev,
418            } => match self.model.confidence(action, target, prev, prev_prev) {
419                Some(score) => LearningResult::Score(score),
420                None => LearningResult::NotAvailable,
421            },
422        }
423    }
424
425    fn stats(&self) -> Option<&LearnStats> {
426        Some(&self.stats)
427    }
428
429    fn model(&self) -> Option<&ScoreModel> {
430        Some(&self.model)
431    }
432}
433
434// ============================================================================
435// Tests
436// ============================================================================
437
#[cfg(test)]
mod tests {
    use super::*;

    /// `score_or` returns the score when present and the fallback otherwise.
    #[test]
    fn test_learning_result_score_or() {
        let available = LearningResult::Score(0.5);
        let missing = LearningResult::NotAvailable;

        assert_eq!(available.score_or(0.0), 0.5);
        assert_eq!(missing.score_or(0.0), 0.0);
        assert_eq!(missing.score_or(-1.0), -1.0);
    }

    /// Only `Score` counts as available.
    #[test]
    fn test_learning_result_is_available() {
        let available = LearningResult::Score(0.5);
        let missing = LearningResult::NotAvailable;

        assert!(available.is_available());
        assert!(!missing.is_available());
    }

    /// The null provider answers `NotAvailable` regardless of the query.
    #[test]
    fn test_null_provider() {
        let provider = NullProvider;
        let queries = vec![
            LearningQuery::transition("A", "B", None),
            LearningQuery::contextual("A", "B", Some("svc1")),
            LearningQuery::ngram("A", "B", "C", None),
        ];

        for q in queries {
            assert_eq!(provider.query(q), LearningResult::NotAvailable);
        }
    }

    /// The map provider serves `confidence - 0.5` for known actions only.
    #[test]
    fn test_confidence_map_provider() {
        let map: HashMap<String, f64> = vec![("grep".to_string(), 0.8), ("restart".to_string(), 0.3)]
            .into_iter()
            .collect();
        let provider = ConfidenceMapProvider::new(map);

        // Confidence queries are supported: 0.8 - 0.5
        let score = provider
            .query(LearningQuery::confidence("grep", None))
            .score();
        assert!((score - 0.3).abs() < 1e-10, "expected ~0.3, got {}", score);

        // 0.3 - 0.5
        let score = provider
            .query(LearningQuery::confidence("restart", None))
            .score();
        assert!(
            (score - (-0.2)).abs() < 1e-10,
            "expected ~-0.2, got {}",
            score
        );

        // Action that is not in the map
        assert_eq!(
            provider.query(LearningQuery::confidence("unknown", None)),
            LearningResult::NotAvailable
        );

        // Other query kinds are not supported
        assert_eq!(
            provider.query(LearningQuery::transition("A", "B", None)),
            LearningResult::NotAvailable
        );
    }

    /// The constructor helpers populate the matching variant and fields.
    #[test]
    fn test_learning_query_constructors() {
        assert!(matches!(
            LearningQuery::transition("A", "B", Some("svc1")),
            LearningQuery::Transition {
                prev: "A",
                action: "B",
                target: Some("svc1")
            }
        ));

        assert!(matches!(
            LearningQuery::ngram("A", "B", "C", None),
            LearningQuery::Ngram {
                prev_prev: "A",
                prev: "B",
                action: "C",
                target: None
            }
        ));

        assert!(matches!(
            LearningQuery::confidence_with_context("action", None, Some("prev"), Some("prev_prev")),
            LearningQuery::Confidence {
                action: "action",
                prev: Some("prev"),
                prev_prev: Some("prev_prev"),
                ..
            }
        ));
    }

    /// A provider built from stats answers all four query kinds.
    #[test]
    fn test_score_model_provider() {
        use crate::learn::stats::{ContextualActionStats, LearnStats};

        let pair = |a: &str, b: &str| (a.to_string(), b.to_string());

        // Test data: A -> B succeeded 10 times and failed twice.
        let mut stats = LearnStats::default();
        stats
            .episode_transitions
            .success_transitions
            .insert(pair("A", "B"), 10);
        stats
            .episode_transitions
            .failure_transitions
            .insert(pair("A", "B"), 2);
        stats.contextual_stats.insert(
            pair("A", "B"),
            ContextualActionStats {
                visits: 12,
                successes: 10,
                failures: 2,
            },
        );
        stats
            .ngram_stats
            .trigrams
            .insert(("X".to_string(), "A".to_string(), "B".to_string()), (9, 1));

        let provider = ScoreModelProvider::from_stats(stats);

        // Transition query
        let transition = provider.query(LearningQuery::transition("A", "B", None));
        assert!(transition.is_available());

        // Contextual query
        let contextual = provider.query(LearningQuery::contextual("A", "B", None));
        assert!(contextual.is_available());
        assert!(contextual.score() > 0.0, "成功率が高いので正のスコア");

        // Ngram query
        let ngram = provider.query(LearningQuery::ngram("X", "A", "B", None));
        assert!(ngram.is_available());

        // Confidence query (with context)
        let confidence = provider.query(LearningQuery::confidence_with_context(
            "B",
            None,
            Some("A"),
            Some("X"),
        ));
        assert!(confidence.is_available());
    }
}
591}