Skip to main content

swarm_engine_core/exploration/selection/
greedy.rs

1//! Greedy Selection - LearnedProvider ボーナス最大優先選択
2//!
3//! Provider から得られた学習済みボーナスに基づいて選択する。
4
5use std::fmt::Debug;
6
7use super::SelectionLogic;
8use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
9use crate::exploration::mutation::ActionExtractor;
10use crate::learn::{LearnedProvider, LearningQuery};
11use crate::online_stats::SwarmStats;
12
/// Greedy selection: always prefer the candidate with the largest learned bonus.
///
/// Performs no exploration. Frontier nodes are ranked purely by the bonus
/// obtained from the [`LearnedProvider`], which makes this strategy useful when
/// exploiting already-learned data.
#[derive(Debug, Clone)]
pub struct Greedy {
    /// Multiplier applied to the provider bonus (default: 1.0).
    weight: f64,
}

impl Default for Greedy {
    /// Returns a selector with `weight = 1.0`, the same as `Greedy::new()`.
    ///
    /// This is written by hand on purpose: a derived `Default` would set
    /// `weight` to `0.0`, which zeroes out every learned bonus and makes all
    /// candidates tie at the base score.
    fn default() -> Self {
        Self { weight: 1.0 }
    }
}
22
23impl Greedy {
24    pub fn new() -> Self {
25        Self { weight: 1.0 }
26    }
27
28    /// 重みを指定して生成
29    pub fn with_weight(weight: f64) -> Self {
30        Self { weight }
31    }
32
33    /// スコアを計算(コンテキストなし、後方互換用)
34    pub fn compute_score(
35        &self,
36        action: &str,
37        target: Option<&str>,
38        provider: &dyn LearnedProvider,
39    ) -> f64 {
40        self.compute_score_with_context(action, target, provider, None, None)
41    }
42
43    /// スコアを計算(コンテキスト付き)
44    ///
45    /// Provider に confidence_with_context クエリを投げてボーナスを取得。
46    pub fn compute_score_with_context(
47        &self,
48        action: &str,
49        target: Option<&str>,
50        provider: &dyn LearnedProvider,
51        prev_action: Option<&str>,
52        prev_prev_action: Option<&str>,
53    ) -> f64 {
54        // Provider からボーナス取得(コンテキスト付き)
55        let bonus = provider
56            .query(LearningQuery::confidence_with_context(
57                action,
58                target,
59                prev_action,
60                prev_prev_action,
61            ))
62            .score();
63
64        // 基本スコア 0.5 + ボーナス * weight
65        0.5 + bonus * self.weight
66    }
67}
68
69impl<N, E, S> SelectionLogic<N, E, S> for Greedy
70where
71    N: Debug + Clone + ActionExtractor,
72    E: Debug + Clone,
73    S: MapState,
74{
75    fn next(
76        &self,
77        map: &GraphMap<N, E, S>,
78        stats: &SwarmStats,
79        provider: &dyn LearnedProvider,
80    ) -> Option<MapNodeId> {
81        self.select(map, 1, stats, provider).into_iter().next()
82    }
83
84    fn select(
85        &self,
86        map: &GraphMap<N, E, S>,
87        count: usize,
88        _stats: &SwarmStats,
89        provider: &dyn LearnedProvider,
90    ) -> Vec<MapNodeId> {
91        let frontiers = map.frontiers();
92        if frontiers.is_empty() || count == 0 {
93            return vec![];
94        }
95
96        // スコア計算してソート(親ノードの情報を活用)
97        let mut scored: Vec<_> = frontiers
98            .iter()
99            .filter_map(|&id| {
100                map.get(id).map(|node| {
101                    let (action, target) = node.data.extract();
102
103                    // 親ノードのアクションを取得(学習ボーナス用)
104                    let parent_id = map.parent(id);
105                    let prev_action = parent_id.and_then(|pid| {
106                        map.get(pid)
107                            .map(|parent_node| parent_node.data.action_name().to_string())
108                    });
109
110                    // 親の親のアクションを取得(N-gram ボーナス用)
111                    let prev_prev_action =
112                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
113                            map.get(ppid).map(|grandparent_node| {
114                                grandparent_node.data.action_name().to_string()
115                            })
116                        });
117
118                    let score = self.compute_score_with_context(
119                        action,
120                        target,
121                        provider,
122                        prev_action.as_deref(),
123                        prev_prev_action.as_deref(),
124                    );
125                    (id, score)
126                })
127            })
128            .collect();
129
130        // スコア降順でソート(高スコアを優先)
131        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132
133        scored.into_iter().take(count).map(|(id, _)| id).collect()
134    }
135
136    fn score(
137        &self,
138        action: &str,
139        target: Option<&str>,
140        _stats: &SwarmStats,
141        provider: &dyn LearnedProvider,
142    ) -> f64 {
143        self.compute_score(action, target, provider)
144    }
145
146    fn name(&self) -> &str {
147        "Greedy"
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::NullProvider;

    /// Score a candidate receives when the provider contributes no bonus.
    const BASE_SCORE: f64 = 0.5;

    #[test]
    fn test_greedy_default_score() {
        // NullProvider always answers NotAvailable, whose score() is 0.0,
        // so only the base score of 0.5 remains.
        let score = Greedy::new().compute_score("unknown", None, &NullProvider);
        assert_eq!(score, BASE_SCORE);
    }

    #[test]
    fn test_greedy_with_null_provider() {
        let greedy = Greedy::new();
        let cases = [("grep", None), ("glob", Some("svc1")), ("other", None)];

        // With NullProvider every candidate lands exactly on the base score.
        for &(action, target) in &cases {
            assert_eq!(greedy.compute_score(action, target, &NullProvider), BASE_SCORE);
        }
    }

    #[test]
    fn test_greedy_with_weight() {
        // A zero bonus stays zero regardless of weight, so the score is still 0.5.
        let score = Greedy::with_weight(2.0).compute_score("action", None, &NullProvider);
        assert_eq!(score, BASE_SCORE);
    }
}