swarm_engine_core/exploration/selection/
greedy.rs1use std::fmt::Debug;
6
7use super::SelectionLogic;
8use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
9use crate::exploration::mutation::ActionExtractor;
10use crate::learn::{LearnedProvider, LearningQuery};
11use crate::online_stats::SwarmStats;
12
/// Greedy frontier selection: ranks candidates purely by the learned
/// confidence reported by a [`LearnedProvider`], with no exploration term.
///
/// A candidate's score is `0.5 + bonus * weight`, where `bonus` is the
/// provider's confidence for the action in its context.
#[derive(Debug, Clone)]
pub struct Greedy {
    /// Multiplier applied to the provider's confidence bonus.
    weight: f64,
}

impl Default for Greedy {
    /// Returns a selector with `weight = 1.0`, the same configuration as `Greedy::new()`.
    ///
    /// Implemented by hand because a derived `Default` would set `weight` to
    /// `0.0`, which silently discards every learned bonus.
    fn default() -> Self {
        Self { weight: 1.0 }
    }
}
22
23impl Greedy {
24 pub fn new() -> Self {
25 Self { weight: 1.0 }
26 }
27
28 pub fn with_weight(weight: f64) -> Self {
30 Self { weight }
31 }
32
33 pub fn compute_score(
35 &self,
36 action: &str,
37 target: Option<&str>,
38 provider: &dyn LearnedProvider,
39 ) -> f64 {
40 self.compute_score_with_context(action, target, provider, None, None)
41 }
42
43 pub fn compute_score_with_context(
47 &self,
48 action: &str,
49 target: Option<&str>,
50 provider: &dyn LearnedProvider,
51 prev_action: Option<&str>,
52 prev_prev_action: Option<&str>,
53 ) -> f64 {
54 let bonus = provider
56 .query(LearningQuery::confidence_with_context(
57 action,
58 target,
59 prev_action,
60 prev_prev_action,
61 ))
62 .score();
63
64 0.5 + bonus * self.weight
66 }
67}
68
69impl<N, E, S> SelectionLogic<N, E, S> for Greedy
70where
71 N: Debug + Clone + ActionExtractor,
72 E: Debug + Clone,
73 S: MapState,
74{
75 fn next(
76 &self,
77 map: &GraphMap<N, E, S>,
78 stats: &SwarmStats,
79 provider: &dyn LearnedProvider,
80 ) -> Option<MapNodeId> {
81 self.select(map, 1, stats, provider).into_iter().next()
82 }
83
84 fn select(
85 &self,
86 map: &GraphMap<N, E, S>,
87 count: usize,
88 _stats: &SwarmStats,
89 provider: &dyn LearnedProvider,
90 ) -> Vec<MapNodeId> {
91 let frontiers = map.frontiers();
92 if frontiers.is_empty() || count == 0 {
93 return vec![];
94 }
95
96 let mut scored: Vec<_> = frontiers
98 .iter()
99 .filter_map(|&id| {
100 map.get(id).map(|node| {
101 let (action, target) = node.data.extract();
102
103 let parent_id = map.parent(id);
105 let prev_action = parent_id.and_then(|pid| {
106 map.get(pid)
107 .map(|parent_node| parent_node.data.action_name().to_string())
108 });
109
110 let prev_prev_action =
112 parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
113 map.get(ppid).map(|grandparent_node| {
114 grandparent_node.data.action_name().to_string()
115 })
116 });
117
118 let score = self.compute_score_with_context(
119 action,
120 target,
121 provider,
122 prev_action.as_deref(),
123 prev_prev_action.as_deref(),
124 );
125 (id, score)
126 })
127 })
128 .collect();
129
130 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132
133 scored.into_iter().take(count).map(|(id, _)| id).collect()
134 }
135
136 fn score(
137 &self,
138 action: &str,
139 target: Option<&str>,
140 _stats: &SwarmStats,
141 provider: &dyn LearnedProvider,
142 ) -> f64 {
143 self.compute_score(action, target, provider)
144 }
145
146 fn name(&self) -> &str {
147 "Greedy"
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::NullProvider;

    /// `NullProvider` contributes no bonus, so an unknown action gets the 0.5 baseline.
    #[test]
    fn test_greedy_default_score() {
        let score = Greedy::new().compute_score("unknown", None, &NullProvider);
        assert_eq!(score, 0.5);
    }

    /// The baseline holds regardless of action name or target.
    #[test]
    fn test_greedy_with_null_provider() {
        let greedy = Greedy::new();
        let cases: [(&str, Option<&str>); 3] =
            [("grep", None), ("glob", Some("svc1")), ("other", None)];

        for &(action, target) in &cases {
            assert_eq!(greedy.compute_score(action, target, &NullProvider), 0.5);
        }
    }

    /// Scaling a zero bonus by any weight still yields the baseline.
    #[test]
    fn test_greedy_with_weight() {
        let weighted = Greedy::with_weight(2.0);
        assert_eq!(weighted.compute_score("action", None, &NullProvider), 0.5);
    }
}