swarm_engine_core/exploration/
node_rules.rs1use std::collections::{HashMap, HashSet};
32
33use super::DependencyGraph;
34
/// Read-only view of the transition rules that drive exploration.
///
/// The `Send + Sync` supertraits allow one rule set to be shared between threads.
pub trait Rules: Send + Sync {
    /// Node types that act as entry points (roots) of the rule graph.
    fn roots(&self) -> Vec<&str>;

    /// Node types that may directly follow `node_type`.
    fn successors(&self, node_type: &str) -> Vec<&str>;

    /// Whether `node_type` is marked as terminal.
    fn is_terminal(&self, node_type: &str) -> bool;

    /// Whether the rule set is empty; what counts as "empty" is decided by the implementor.
    fn is_empty(&self) -> bool;

    /// Parameter variants for `node_type`: a parameter key together with its candidate values.
    ///
    /// Implementors without parameter variants can rely on this default, which always
    /// reports `None`.
    fn param_variants(&self, _node_type: &str) -> Option<(&str, &[String])> {
        None
    }
}
74
/// Rule set describing which node types may follow which during exploration.
///
/// Built with the consuming builder methods (`add_rule`, `add_roots`,
/// `add_terminals`, ...) or converted from a [`DependencyGraph`] via `From`.
#[derive(Debug, Clone, Default)]
pub struct NodeRules {
    /// Adjacency map: source node type -> node types allowed to follow it.
    successors: HashMap<String, HashSet<String>>,

    /// Node types reported by `roots()` / `is_root()`.
    roots: HashSet<String>,

    /// Node types reported by `terminals()` / `is_terminal()`.
    terminals: HashSet<String>,

    /// Per node type: a parameter key and the candidate values for that key.
    param_variants: HashMap<String, (String, Vec<String>)>,

    /// Confidence per `(from, to)` edge. Values written by
    /// `add_rule_with_confidence` always lie in `[0.0, 1.0]` (NaN becomes `0.0`).
    edge_confidence: HashMap<(String, String), f64>,
}

impl NodeRules {
    /// Creates an empty rule set (same as `NodeRules::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Allows `to` to directly follow `from`. Adding the same pair twice is a no-op.
    pub fn add_rule(mut self, from: &str, to: &str) -> Self {
        self.successors
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());
        self
    }

    /// Allows every node type in `tos` to directly follow `from`.
    ///
    /// `from` is registered in the successor map even when `tos` is empty, which
    /// makes it count for `has_node_type` and `is_empty`.
    pub fn add_rules(mut self, from: &str, tos: &[&str]) -> Self {
        self.successors
            .entry(from.to_string())
            .or_default()
            .extend(tos.iter().map(|to| to.to_string()));
        self
    }

    /// Marks `node_type` as a root.
    pub fn add_root(mut self, node_type: &str) -> Self {
        self.roots.insert(node_type.to_string());
        self
    }

    /// Marks every entry of `node_types` as a root.
    pub fn add_roots(mut self, node_types: &[&str]) -> Self {
        self.roots
            .extend(node_types.iter().map(|node_type| node_type.to_string()));
        self
    }

    /// Marks `node_type` as terminal.
    pub fn add_terminal(mut self, node_type: &str) -> Self {
        self.terminals.insert(node_type.to_string());
        self
    }

    /// Marks every entry of `node_types` as terminal.
    pub fn add_terminals(mut self, node_types: &[&str]) -> Self {
        self.terminals
            .extend(node_types.iter().map(|node_type| node_type.to_string()));
        self
    }

    /// Registers parameter variants for `node_type`: the parameter `key` with
    /// candidate `values`. Replaces any variants previously set for `node_type`.
    pub fn add_param_variants(mut self, node_type: &str, key: &str, values: &[&str]) -> Self {
        self.param_variants.insert(
            node_type.to_string(),
            (
                key.to_string(),
                values.iter().map(|s| s.to_string()).collect(),
            ),
        );
        self
    }

    /// Allows `to` to directly follow `from` and records a confidence for the edge.
    ///
    /// `confidence` is clamped into `[0.0, 1.0]`; a NaN confidence is stored as
    /// `0.0`. Re-adding the same edge overwrites its confidence.
    pub fn add_rule_with_confidence(mut self, from: &str, to: &str, confidence: f64) -> Self {
        // `f64::clamp` returns NaN unchanged, which would leave an unordered
        // value outside the documented range; treat it as "no confidence".
        let confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self.successors
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());
        self.edge_confidence
            .insert((from.to_string(), to.to_string()), confidence);
        self
    }

    /// Node types that may directly follow `node_type`.
    ///
    /// Empty when `node_type` has no rules. Order is unspecified (backed by a `HashSet`).
    pub fn successors(&self, node_type: &str) -> Vec<&str> {
        self.successors
            .get(node_type)
            .map(|set| set.iter().map(String::as_str).collect())
            .unwrap_or_default()
    }

    /// All root node types, in unspecified order.
    pub fn roots(&self) -> Vec<&str> {
        self.roots.iter().map(String::as_str).collect()
    }

    /// All terminal node types, in unspecified order.
    pub fn terminals(&self) -> Vec<&str> {
        self.terminals.iter().map(String::as_str).collect()
    }

    /// Whether a rule allows `to` to directly follow `from`.
    pub fn can_transition(&self, from: &str, to: &str) -> bool {
        self.successors
            .get(from)
            .map(|set| set.contains(to))
            .unwrap_or(false)
    }

    /// Whether `node_type` is a rule source, a root, or a terminal.
    ///
    /// A type that only ever appears as the *target* of a rule is not counted.
    pub fn has_node_type(&self, node_type: &str) -> bool {
        self.successors.contains_key(node_type)
            || self.roots.contains(node_type)
            || self.terminals.contains(node_type)
    }

    /// Whether `node_type` is marked as terminal.
    pub fn is_terminal(&self, node_type: &str) -> bool {
        self.terminals.contains(node_type)
    }

    /// Whether `node_type` is marked as a root.
    pub fn is_root(&self, node_type: &str) -> bool {
        self.roots.contains(node_type)
    }

    /// Recorded confidence for the edge `from -> to`, if one was set.
    pub fn get_confidence(&self, from: &str, to: &str) -> Option<f64> {
        self.edge_confidence
            .get(&(from.to_string(), to.to_string()))
            .copied()
    }

    /// For every target node type, the highest confidence among edges leading into it.
    ///
    /// Each target starts at `0.0`, so the result is never below that.
    pub fn confidence_map(&self) -> HashMap<String, f64> {
        let mut best: HashMap<String, f64> = HashMap::new();
        for ((_, to), &conf) in &self.edge_confidence {
            let entry = best.entry(to.clone()).or_insert(0.0);
            if conf > *entry {
                *entry = conf;
            }
        }
        best
    }

    /// True when there are neither successor rules nor roots.
    ///
    /// Terminals and parameter variants alone do not make the rule set non-empty.
    pub fn is_empty(&self) -> bool {
        self.successors.is_empty() && self.roots.is_empty()
    }

    /// Small fixed rule set for tests: roots `grep`/`glob`, `grep -> read|summary`,
    /// `glob -> grep`, terminals `read`/`summary`.
    #[cfg(test)]
    pub fn for_testing() -> Self {
        Self::new()
            .add_roots(&["grep", "glob"])
            .add_rules("grep", &["read", "summary"])
            .add_rule("glob", "grep")
            .add_terminals(&["read", "summary"])
    }
}
293
294impl Rules for NodeRules {
299 fn successors(&self, node_type: &str) -> Vec<&str> {
300 self.successors(node_type)
301 }
302
303 fn roots(&self) -> Vec<&str> {
304 self.roots()
305 }
306
307 fn is_terminal(&self, node_type: &str) -> bool {
308 self.is_terminal(node_type)
309 }
310
311 fn is_empty(&self) -> bool {
312 self.is_empty()
313 }
314
315 fn param_variants(&self, node_type: &str) -> Option<(&str, &[String])> {
316 self.param_variants
317 .get(node_type)
318 .map(|(key, values)| (key.as_str(), values.as_slice()))
319 }
320}
321
322impl From<DependencyGraph> for NodeRules {
327 fn from(graph: DependencyGraph) -> Self {
334 let mut rules = NodeRules::new();
335
336 for start in graph.start_actions() {
338 rules.roots.insert(start);
339 }
340
341 for terminal in graph.terminal_actions() {
343 rules.terminals.insert(terminal);
344 }
345
346 for edge in graph.edges() {
348 rules
349 .successors
350 .entry(edge.from.clone())
351 .or_default()
352 .insert(edge.to.clone());
353 rules
354 .edge_confidence
355 .insert((edge.from.clone(), edge.to.clone()), edge.confidence);
356 }
357
358 for (action, (key, values)) in graph.all_param_variants() {
360 rules
361 .param_variants
362 .insert(action.clone(), (key.clone(), values.clone()));
363 }
364
365 rules
366 }
367}
368
369impl From<&DependencyGraph> for NodeRules {
370 fn from(graph: &DependencyGraph) -> Self {
371 let mut rules = NodeRules::new();
372
373 for start in graph.start_actions() {
374 rules.roots.insert(start);
375 }
376
377 for terminal in graph.terminal_actions() {
378 rules.terminals.insert(terminal);
379 }
380
381 for edge in graph.edges() {
383 rules
384 .successors
385 .entry(edge.from.clone())
386 .or_default()
387 .insert(edge.to.clone());
388 rules
389 .edge_confidence
390 .insert((edge.from.clone(), edge.to.clone()), edge.confidence);
391 }
392
393 for (action, (key, values)) in graph.all_param_variants() {
395 rules
396 .param_variants
397 .insert(action.clone(), (key.clone(), values.clone()));
398 }
399
400 rules
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraphBuilder;

    #[test]
    fn test_node_rules_basic() {
        let rules = NodeRules::new()
            .add_roots(&["grep", "glob"])
            .add_rules("grep", &["read", "summary", "grep"])
            .add_rules("read", &["analyze", "extract"])
            .add_rule("summary", "report")
            .add_terminals(&["report", "extract"]);

        let roots = rules.roots();
        assert!(roots.contains(&"grep"));
        assert!(roots.contains(&"glob"));

        // Self-loop "grep -> grep" counts as a successor.
        let grep_successors = rules.successors("grep");
        assert_eq!(grep_successors.len(), 3);
        assert!(grep_successors.contains(&"read"));
        assert!(grep_successors.contains(&"summary"));

        // Transitions are direct only; "grep -> summary -> report" is not "grep -> report".
        assert!(rules.can_transition("grep", "read"));
        assert!(!rules.can_transition("grep", "report"));
        assert!(rules.can_transition("summary", "report"));

        assert!(rules.is_terminal("report"));
        assert!(rules.is_terminal("extract"));
        assert!(!rules.is_terminal("grep"));
    }

    #[test]
    fn test_node_rules_empty() {
        let rules = NodeRules::new();
        assert!(rules.is_empty());
        assert!(rules.successors("anything").is_empty());
        assert!(rules.roots().is_empty());
    }

    #[test]
    fn test_node_rules_has_node_type() {
        let rules = NodeRules::new()
            .add_root("start")
            .add_rule("middle", "end")
            .add_terminal("end");

        assert!(rules.has_node_type("start"));
        assert!(rules.has_node_type("middle"));
        assert!(rules.has_node_type("end"));
        assert!(!rules.has_node_type("unknown"));
    }

    #[test]
    fn test_from_dependency_graph() {
        let graph = DependencyGraphBuilder::new()
            .task("Find auth function")
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .edge("List", "Read", 0.40)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();

        let rules: NodeRules = graph.into();

        assert!(rules.is_root("Grep"));
        assert!(rules.is_root("List"));
        assert!(!rules.is_root("Read"));

        assert!(rules.is_terminal("Read"));
        assert!(!rules.is_terminal("Grep"));

        assert!(rules.can_transition("Grep", "Read"));
        assert!(rules.can_transition("List", "Grep"));
        assert!(rules.can_transition("List", "Read"));
        assert!(!rules.can_transition("Read", "Grep"));
    }

    #[test]
    fn test_from_dependency_graph_ref() {
        let graph = DependencyGraphBuilder::new()
            .edge("A", "B", 0.9)
            .start_node("A")
            .terminal_node("B")
            .build();

        let rules: NodeRules = (&graph).into();

        assert!(rules.is_root("A"));
        assert!(rules.is_terminal("B"));
        assert!(rules.can_transition("A", "B"));
    }

    #[test]
    fn test_confidence_is_clamped() {
        let rules = NodeRules::new()
            .add_rule_with_confidence("a", "b", 1.5)
            .add_rule_with_confidence("a", "c", -0.5);

        assert!(rules.can_transition("a", "b"));
        assert_eq!(rules.get_confidence("a", "b"), Some(1.0));
        assert_eq!(rules.get_confidence("a", "c"), Some(0.0));
        assert_eq!(rules.get_confidence("b", "a"), None);
    }

    #[test]
    fn test_confidence_map_keeps_max_per_target() {
        let rules = NodeRules::new()
            .add_rule_with_confidence("a", "c", 0.3)
            .add_rule_with_confidence("b", "c", 0.8)
            .add_rule_with_confidence("a", "d", 0.25);

        let map = rules.confidence_map();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("c"), Some(&0.8));
        assert_eq!(map.get("d"), Some(&0.25));
    }

    #[test]
    fn test_param_variants_via_trait() {
        let rules = NodeRules::new().add_param_variants("read", "path", &["a.rs", "b.rs"]);

        let (key, values) =
            Rules::param_variants(&rules, "read").expect("variants were registered");
        assert_eq!(key, "path");
        assert_eq!(values.to_vec(), vec!["a.rs".to_string(), "b.rs".to_string()]);
        assert!(Rules::param_variants(&rules, "grep").is_none());
    }

    #[test]
    fn test_rules_trait_object_matches_inherent() {
        let rules = NodeRules::for_testing();
        let dyn_rules: &dyn Rules = &rules;

        let mut successors = dyn_rules.successors("grep");
        successors.sort_unstable();
        assert_eq!(successors, vec!["read", "summary"]);
        assert_eq!(dyn_rules.roots().len(), 2);
        assert!(dyn_rules.is_terminal("summary"));
        assert!(!dyn_rules.is_empty());
    }
}