swarm_engine_core/exploration/selection/
thompson.rs

//! Thompson Selection - Thompson Sampling selection
//!
//! A bandit algorithm that selects based on the expected value of a Beta
//! distribution. The learning bonus is obtained from the Provider (with context).
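//!
//! A minimal usage sketch (`map`, `stats`, and the `SelectionLogic` trait
//! import are assumed to be in scope; `NullProvider` contributes a zero bonus):
//!
//! ```ignore
//! let thompson = Thompson::with_learning_weight(0.3);
//! // Pick up to 4 frontier nodes, highest Thompson score first.
//! let picked = thompson.select(&map, 4, &stats, &NullProvider);
//! ```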

use std::fmt::Debug;

use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;

/// Thompson Sampling selection
///
/// Scores as the expected value of Beta(successes + 1, failures + 1)
/// plus a weighted learning bonus.
///
/// True Thompson Sampling draws a sample from the Beta distribution;
/// the expected value is used instead to keep selection deterministic.
#[derive(Debug, Clone)]
pub struct Thompson {
    /// Learning bonus coefficient
    learning_weight: f64,
}

impl Default for Thompson {
    fn default() -> Self {
        Self {
            learning_weight: 0.3,
        }
    }
}

impl Thompson {
    pub fn new() -> Self {
        Self::default()
    }

    /// Create with the given learning bonus coefficient
    pub fn with_learning_weight(learning_weight: f64) -> Self {
        Self { learning_weight }
    }

    /// Get the learning bonus coefficient
    pub fn learning_weight(&self) -> f64 {
        self.learning_weight
    }

    /// Compute the Thompson score (context-free, for backward compatibility)
    pub fn compute_score(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score_with_context(stats, action, target, provider, None, None)
    }

    /// Compute the Thompson score (with context)
    ///
    /// Issues a confidence_with_context query to the Provider to obtain
    /// the learning bonus.
    pub fn compute_score_with_context(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
        prev_action: Option<&str>,
        prev_prev_action: Option<&str>,
    ) -> f64 {
        let action_stats = match target {
            Some(t) => stats.get_action_target_stats(action, t),
            None => stats.get_action_stats(action),
        };

        // Expected value of the Beta distribution: α / (α + β)
        let alpha = action_stats.successes as f64 + 1.0;
        let beta = action_stats.failures as f64 + 1.0;
        let base_score = alpha / (alpha + beta);

        // Get the learning bonus from the Provider (with context)
        let learning_bonus = provider
            .query(LearningQuery::confidence_with_context(
                action,
                target,
                prev_action,
                prev_prev_action,
            ))
            .score();

        base_score + learning_bonus * self.learning_weight
    }
}

impl<N, E, S> SelectionLogic<N, E, S> for Thompson
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return vec![];
        }

        // Sort by Thompson score (using information from parent nodes)
        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // Get the parent node's action (for the learning bonus)
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(|pid| {
                        map.get(pid)
                            .map(|parent_node| parent_node.data.action_name().to_string())
                    });

                    // Get the grandparent's action (for the N-gram bonus)
                    let prev_prev_action =
                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
                            map.get(ppid).map(|grandparent_node| {
                                grandparent_node.data.action_name().to_string()
                            })
                        });

                    let score = self.compute_score_with_context(
                        stats,
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        // Sort by score in descending order (highest first)
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(stats, action, target, provider)
    }

    fn name(&self) -> &str {
        "Thompson"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    fn record_success(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::success())
            .build();
        stats.record(&event);
    }

    fn record_failure(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::failure("error"))
            .build();
        stats.record(&event);
    }

    #[test]
    fn test_thompson_initial_score() {
        let thompson = Thompson::new();
        let stats = SwarmStats::new();
        let provider = NullProvider;

        // Expected value of Beta(1, 1) = 0.5, bonus = 0
        assert_eq!(thompson.compute_score(&stats, "grep", None, &provider), 0.5);
    }
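
    // A minimal sketch of the base-score arithmetic, assuming (as in the
    // test above) that NullProvider contributes a zero bonus: 3 successes
    // and 1 failure give the Beta(3 + 1, 1 + 1) mean, i.e. 4 / 6.
    #[test]
    fn test_thompson_beta_mean_after_observations() {
        let thompson = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        for _ in 0..3 {
            record_success(&mut stats, "grep");
        }
        record_failure(&mut stats, "grep");

        let score = thompson.compute_score(&stats, "grep", None, &provider);
        assert!((score - 4.0 / 6.0).abs() < 1e-10);
    }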

    #[test]
    fn test_thompson_score_increases_with_success() {
        let thompson = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        let initial = thompson.compute_score(&stats, "grep", None, &provider);

        record_success(&mut stats, "grep");
        let after_success = thompson.compute_score(&stats, "grep", None, &provider);

        // A success raises the score
        assert!(after_success > initial);
    }

    #[test]
    fn test_thompson_score_decreases_with_failure() {
        let thompson = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        let initial = thompson.compute_score(&stats, "grep", None, &provider);

        record_failure(&mut stats, "grep");
        let after_failure = thompson.compute_score(&stats, "grep", None, &provider);

        // A failure lowers the score
        assert!(after_failure < initial);
    }

    #[test]
    fn test_thompson_with_learning_weight() {
        let thompson = Thompson::with_learning_weight(0.5);
        assert!((thompson.learning_weight() - 0.5).abs() < 1e-10);
    }
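
    // A small check of the backward-compatibility note on `compute_score`:
    // the context-free call should match the context-aware call with both
    // previous actions set to None.
    #[test]
    fn test_thompson_context_free_matches_empty_context() {
        let thompson = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        record_success(&mut stats, "grep");

        let plain = thompson.compute_score(&stats, "grep", None, &provider);
        let with_ctx =
            thompson.compute_score_with_context(&stats, "grep", None, &provider, None, None);
        assert_eq!(plain, with_ctx);
    }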
}