Skip to main content

swarm_engine_core/exploration/selection/
ucb1.rs

//! UCB1 selection (Upper Confidence Bound).
//!
//! A bandit algorithm that balances exploration and exploitation.
//! The learning bonus is obtained from a provider (with context).
5
6use std::fmt::Debug;
7
8use super::SelectionLogic;
9use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
10use crate::exploration::mutation::ActionExtractor;
11use crate::learn::{LearnedProvider, LearningQuery};
12use crate::online_stats::SwarmStats;
13
/// UCB1 (Upper Confidence Bound) selection strategy.
///
/// Every visited candidate is ranked by
///
/// `score = success_rate + c * sqrt(ln(total_visits) / visits) + learning_weight * learned_bonus`
///
/// - Exploitation: a higher success rate yields a higher score.
/// - Exploration: candidates with few visits receive a larger bonus.
/// - Learning: a bonus queried from the provider is added, scaled by `learning_weight`.
///
/// Candidates that have never been visited score `f64::INFINITY`.
#[derive(Clone, Debug)]
pub struct Ucb1 {
    /// Exploration coefficient; larger values favour less-visited candidates.
    c: f64,
    /// Multiplier applied to the learned bonus returned by the provider.
    learning_weight: f64,
}
28
29impl Default for Ucb1 {
30    fn default() -> Self {
31        Self {
32            c: std::f64::consts::SQRT_2,
33            learning_weight: 0.3,
34        }
35    }
36}
37
38impl Ucb1 {
39    /// 指定した探索係数で生成
40    pub fn new(c: f64) -> Self {
41        Self {
42            c,
43            learning_weight: 0.3,
44        }
45    }
46
47    /// 探索係数と学習ボーナス係数を指定して生成
48    pub fn with_weights(c: f64, learning_weight: f64) -> Self {
49        Self { c, learning_weight }
50    }
51
52    /// デフォルトの探索係数(√2)で生成
53    pub fn with_default_c() -> Self {
54        Self::default()
55    }
56
57    /// 探索係数を取得
58    pub fn c(&self) -> f64 {
59        self.c
60    }
61
62    /// 学習ボーナス係数を取得
63    pub fn learning_weight(&self) -> f64 {
64        self.learning_weight
65    }
66
67    /// UCB1 スコアを計算(コンテキストなし、後方互換用)
68    pub fn compute_score(
69        &self,
70        stats: &SwarmStats,
71        action: &str,
72        target: Option<&str>,
73        provider: &dyn LearnedProvider,
74    ) -> f64 {
75        self.compute_score_with_context(stats, action, target, provider, None, None)
76    }
77
78    /// UCB1 スコアを計算(コンテキスト付き)
79    ///
80    /// Provider に confidence_with_context クエリを投げて学習ボーナスを取得。
81    pub fn compute_score_with_context(
82        &self,
83        stats: &SwarmStats,
84        action: &str,
85        target: Option<&str>,
86        provider: &dyn LearnedProvider,
87        prev_action: Option<&str>,
88        prev_prev_action: Option<&str>,
89    ) -> f64 {
90        let action_stats = match target {
91            Some(t) => stats.get_action_target_stats(action, t),
92            None => stats.get_action_stats(action),
93        };
94        let total = stats.total_visits().max(1);
95
96        if action_stats.visits == 0 {
97            return f64::INFINITY; // 未訪問は必ず選択
98        }
99
100        let success_rate = action_stats.success_rate();
101        let exploration = self.c * ((total as f64).ln() / action_stats.visits as f64).sqrt();
102
103        // Provider から学習ボーナスを取得(コンテキスト付き)
104        let learning_bonus = provider
105            .query(LearningQuery::confidence_with_context(
106                action,
107                target,
108                prev_action,
109                prev_prev_action,
110            ))
111            .score();
112
113        success_rate + exploration + learning_bonus * self.learning_weight
114    }
115}
116
117impl<N, E, S> SelectionLogic<N, E, S> for Ucb1
118where
119    N: Debug + Clone + ActionExtractor,
120    E: Debug + Clone,
121    S: MapState,
122{
123    fn next(
124        &self,
125        map: &GraphMap<N, E, S>,
126        stats: &SwarmStats,
127        provider: &dyn LearnedProvider,
128    ) -> Option<MapNodeId> {
129        self.select(map, 1, stats, provider).into_iter().next()
130    }
131
132    fn select(
133        &self,
134        map: &GraphMap<N, E, S>,
135        count: usize,
136        stats: &SwarmStats,
137        provider: &dyn LearnedProvider,
138    ) -> Vec<MapNodeId> {
139        let frontiers = map.frontiers();
140        if frontiers.is_empty() || count == 0 {
141            return vec![];
142        }
143
144        // UCB1 スコア計算してソート(親ノードの情報を活用)
145        let mut scored: Vec<_> = frontiers
146            .iter()
147            .filter_map(|&id| {
148                map.get(id).map(|node| {
149                    let (action, target) = node.data.extract();
150
151                    // 親ノードのアクションを取得(学習ボーナス用)
152                    let parent_id = map.parent(id);
153                    let prev_action = parent_id.and_then(|pid| {
154                        map.get(pid)
155                            .map(|parent_node| parent_node.data.action_name().to_string())
156                    });
157
158                    // 親の親のアクションを取得(N-gram ボーナス用)
159                    let prev_prev_action =
160                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
161                            map.get(ppid).map(|grandparent_node| {
162                                grandparent_node.data.action_name().to_string()
163                            })
164                        });
165
166                    let score = self.compute_score_with_context(
167                        stats,
168                        action,
169                        target,
170                        provider,
171                        prev_action.as_deref(),
172                        prev_prev_action.as_deref(),
173                    );
174                    (id, score)
175                })
176            })
177            .collect();
178
179        // スコア降順でソート(高スコアを優先)
180        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182        scored.into_iter().take(count).map(|(id, _)| id).collect()
183    }
184
185    fn score(
186        &self,
187        action: &str,
188        target: Option<&str>,
189        stats: &SwarmStats,
190        provider: &dyn LearnedProvider,
191    ) -> f64 {
192        self.compute_score(stats, action, target, provider)
193    }
194
195    fn name(&self) -> &str {
196        "UCB1"
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Records one successful event for `action` (and `target`, if any) in `stats`.
    fn record_success(stats: &mut SwarmStats, action: &str, target: Option<&str>) {
        let base = ActionEventBuilder::new(0, WorkerId(0), action);
        let builder = match target {
            Some(t) => base.target(t),
            None => base,
        };
        let event = builder.result(ActionEventResult::success()).build();
        stats.record(&event);
    }

    #[test]
    fn test_ucb1_unvisited_is_infinity() {
        let selector = Ucb1::new(1.41);
        let stats = SwarmStats::new();
        let provider = NullProvider;

        // An action that was never visited scores INFINITY.
        let score = selector.compute_score(&stats, "grep", None, &provider);
        assert!(score.is_infinite());
    }

    #[test]
    fn test_ucb1_score_changes_with_visits() {
        let selector = Ucb1::new(1.41);
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        // Before any visit the score is INFINITY.
        let before_visit = selector.compute_score(&stats, "grep", None, &provider);
        assert!(before_visit.is_infinite());

        // After one visit: ln(1) = 0, so the exploration term vanishes.
        record_success(&mut stats, "grep", None);
        let after_one = selector.compute_score(&stats, "grep", None, &provider);
        assert!(after_one.is_finite());
        // success_rate = 1.0, exploration = c * sqrt(ln(1)/1) = 0, bonus = 0
        assert!((after_one - 1.0).abs() < 0.01);

        // Recording a different action raises the total, which must raise
        // the exploration term for "grep" (total = 2, visits = 1).
        record_success(&mut stats, "glob", None);
        let after_other = selector.compute_score(&stats, "grep", None, &provider);
        assert!(after_other > after_one);
    }

    #[test]
    fn test_ucb1_default_c() {
        let selector = Ucb1::default();
        let c_error = (selector.c() - std::f64::consts::SQRT_2).abs();
        let weight_error = (selector.learning_weight() - 0.3).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }

    #[test]
    fn test_ucb1_with_weights() {
        let selector = Ucb1::with_weights(2.0, 0.5);
        let c_error = (selector.c() - 2.0).abs();
        let weight_error = (selector.learning_weight() - 0.5).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }
}