1use std::collections::{HashMap, HashSet};
43
44use serde::{Deserialize, Serialize};
45
/// A directed transition: after action `from`, action `to` may be executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyEdge {
    /// Source action of the transition.
    pub from: String,
    /// Action that may follow `from`.
    pub to: String,
    /// Confidence of the transition. `DependencyEdge::new` clamps it to `0.0..=1.0`; the
    /// field is public, so direct construction or deserialization can bypass that.
    pub confidence: f64,
}
63
64impl DependencyEdge {
65 pub fn new(from: impl Into<String>, to: impl Into<String>, confidence: f64) -> Self {
66 Self {
67 from: from.into(),
68 to: to.into(),
69 confidence: confidence.clamp(0.0, 1.0),
70 }
71 }
72}
73
/// A directed graph describing which actions may follow which for a task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
    /// Allowed transitions, kept in insertion order.
    edges: Vec<DependencyEdge>,
    /// Actions a plan may begin with.
    start_nodes: HashSet<String>,
    /// Actions whose completion marks the task as done (`GraphNavigator::is_task_complete`).
    terminal_nodes: HashSet<String>,
    /// Description of the task the graph was built for.
    task: String,
    /// Actions that edges may refer to; `validate` rejects edges naming anything else.
    available_actions: Vec<String>,
    /// Per-action parameter variants: action name -> (parameter key, values), as recorded by
    /// `DependencyGraphBuilder::param_variants`. `#[serde(default)]` lets serialized graphs
    /// that lack this field still deserialize (to an empty map).
    #[serde(default)]
    param_variants: HashMap<String, (String, Vec<String>)>,
}
96
97impl DependencyGraph {
98 pub fn new() -> Self {
100 Self::default()
101 }
102
103 pub fn builder() -> DependencyGraphBuilder {
105 DependencyGraphBuilder::new()
106 }
107
108 pub fn valid_next_actions(&self, current_action: &str) -> Vec<String> {
116 let mut edges: Vec<_> = self
117 .edges
118 .iter()
119 .filter(|e| e.from == current_action)
120 .collect();
121
122 edges.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
124
125 edges.iter().map(|e| e.to.clone()).collect()
126 }
127
128 pub fn start_actions(&self) -> Vec<String> {
132 self.start_nodes.iter().cloned().collect()
133 }
134
135 pub fn terminal_actions(&self) -> Vec<String> {
137 self.terminal_nodes.iter().cloned().collect()
138 }
139
140 pub fn is_terminal(&self, action: &str) -> bool {
142 self.terminal_nodes.contains(action)
143 }
144
145 pub fn is_start(&self, action: &str) -> bool {
147 self.start_nodes.contains(action)
148 }
149
150 pub fn can_transition(&self, from: &str, to: &str) -> bool {
152 self.edges.iter().any(|e| e.from == from && e.to == to)
153 }
154
155 pub fn transition_confidence(&self, from: &str, to: &str) -> Option<f64> {
157 self.edges
158 .iter()
159 .find(|e| e.from == from && e.to == to)
160 .map(|e| e.confidence)
161 }
162
163 pub fn edges(&self) -> &[DependencyEdge] {
165 &self.edges
166 }
167
168 pub fn task(&self) -> &str {
170 &self.task
171 }
172
173 pub fn available_actions(&self) -> &[String] {
175 &self.available_actions
176 }
177
178 pub fn param_variants(&self, action: &str) -> Option<(&str, &[String])> {
180 self.param_variants
181 .get(action)
182 .map(|(key, values)| (key.as_str(), values.as_slice()))
183 }
184
185 pub fn all_param_variants(&self) -> &HashMap<String, (String, Vec<String>)> {
187 &self.param_variants
188 }
189
190 pub fn validate(&self) -> Result<(), DependencyGraphError> {
200 if self.start_nodes.is_empty() {
201 return Err(DependencyGraphError::NoStartNodes);
202 }
203
204 if self.terminal_nodes.is_empty() {
205 return Err(DependencyGraphError::NoTerminalNodes);
206 }
207
208 for edge in &self.edges {
210 if !self.available_actions.contains(&edge.from) {
211 return Err(DependencyGraphError::UnknownAction(edge.from.clone()));
212 }
213 if !self.available_actions.contains(&edge.to) {
214 return Err(DependencyGraphError::UnknownAction(edge.to.clone()));
215 }
216 }
217
218 Ok(())
219 }
220
221 pub fn to_mermaid(&self) -> String {
227 let mut lines = vec!["graph LR".to_string()];
228
229 for edge in &self.edges {
230 let label = format!("{:.0}%", edge.confidence * 100.0);
231 lines.push(format!(" {} -->|{}| {}", edge.from, label, edge.to));
232 }
233
234 for start in &self.start_nodes {
236 lines.push(format!(" style {} fill:#9f9", start));
237 }
238
239 for terminal in &self.terminal_nodes {
241 lines.push(format!(" style {} fill:#f99", terminal));
242 }
243
244 lines.join("\n")
245 }
246}
247
/// Errors produced while validating, parsing or planning dependency graphs.
#[derive(Debug, Clone, thiserror::Error)]
pub enum DependencyGraphError {
    /// The graph has no start node. Returned by `DependencyGraph::validate`, and by
    /// `StaticDependencyPlanner::plan` when its linear fallback is given no actions.
    #[error("No start nodes defined")]
    NoStartNodes,

    /// The graph has no terminal node (from `DependencyGraph::validate`).
    #[error("No terminal nodes defined")]
    NoTerminalNodes,

    /// An edge endpoint is not listed in the graph's available actions; carries the name.
    #[error("Unknown action: {0}")]
    UnknownAction(String),

    /// An LLM response could not be parsed (from `LlmDependencyResponse::parse`).
    #[error("Parse error: {0}")]
    ParseError(String),

    /// NOTE(review): not constructed in this file; presumably reserved for failures of an
    /// LLM-backed planner — confirm.
    #[error("LLM error: {0}")]
    LlmError(String),
}
266
/// Source of ready-made dependency graphs.
///
/// The `Send + Sync` bound allows implementations to be shared across threads.
pub trait DependencyGraphProvider: Send + Sync {
    /// Returns a dependency graph for `task` over `available_actions`, or `None` when this
    /// provider has none to offer. NOTE(review): the precise meaning of `None` is defined by
    /// implementors outside this file — confirm.
    fn provide_graph(&self, task: &str, available_actions: &[String]) -> Option<DependencyGraph>;
}
305
/// Fluent builder for [`DependencyGraph`]; obtain one via [`DependencyGraph::builder`].
///
/// The fields mirror those of `DependencyGraph` and are moved into it by `build`.
#[derive(Debug, Clone, Default)]
pub struct DependencyGraphBuilder {
    // Edges in the order `edge` was called.
    edges: Vec<DependencyEdge>,
    start_nodes: HashSet<String>,
    terminal_nodes: HashSet<String>,
    task: String,
    available_actions: Vec<String>,
    // action -> (key, values), as passed to `param_variants`.
    param_variants: HashMap<String, (String, Vec<String>)>,
}
320
321impl DependencyGraphBuilder {
322 pub fn new() -> Self {
323 Self::default()
324 }
325
326 pub fn task(mut self, task: impl Into<String>) -> Self {
328 self.task = task.into();
329 self
330 }
331
332 pub fn available_actions<I, S>(mut self, actions: I) -> Self
334 where
335 I: IntoIterator<Item = S>,
336 S: Into<String>,
337 {
338 self.available_actions = actions.into_iter().map(|s| s.into()).collect();
339 self
340 }
341
342 pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>, confidence: f64) -> Self {
344 self.edges.push(DependencyEdge::new(from, to, confidence));
345 self
346 }
347
348 pub fn start_node(mut self, action: impl Into<String>) -> Self {
350 self.start_nodes.insert(action.into());
351 self
352 }
353
354 pub fn start_nodes<I, S>(mut self, actions: I) -> Self
356 where
357 I: IntoIterator<Item = S>,
358 S: Into<String>,
359 {
360 self.start_nodes
361 .extend(actions.into_iter().map(|s| s.into()));
362 self
363 }
364
365 pub fn terminal_node(mut self, action: impl Into<String>) -> Self {
367 self.terminal_nodes.insert(action.into());
368 self
369 }
370
371 pub fn terminal_nodes<I, S>(mut self, actions: I) -> Self
373 where
374 I: IntoIterator<Item = S>,
375 S: Into<String>,
376 {
377 self.terminal_nodes
378 .extend(actions.into_iter().map(|s| s.into()));
379 self
380 }
381
382 pub fn param_variants<I, S>(
386 mut self,
387 action: impl Into<String>,
388 key: impl Into<String>,
389 values: I,
390 ) -> Self
391 where
392 I: IntoIterator<Item = S>,
393 S: Into<String>,
394 {
395 self.param_variants.insert(
396 action.into(),
397 (key.into(), values.into_iter().map(|s| s.into()).collect()),
398 );
399 self
400 }
401
402 pub fn build(self) -> DependencyGraph {
404 DependencyGraph {
405 edges: self.edges,
406 start_nodes: self.start_nodes,
407 terminal_nodes: self.terminal_nodes,
408 task: self.task,
409 available_actions: self.available_actions,
410 param_variants: self.param_variants,
411 }
412 }
413
414 pub fn build_validated(self) -> Result<DependencyGraph, DependencyGraphError> {
416 let graph = self.build();
417 graph.validate()?;
418 Ok(graph)
419 }
420}
421
/// Dependency answer from an LLM: the JSON shape accepted by
/// [`LlmDependencyResponse::parse`], or one synthesized from arrow / numbered-list text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmDependencyResponse {
    /// Proposed transitions.
    pub edges: Vec<LlmEdge>,
    /// Actions the plan may start with.
    pub start: Vec<String>,
    /// Actions the plan may end with.
    pub terminal: Vec<String>,
    /// Optional justification; `#[serde(default)]` makes the key optional in JSON.
    #[serde(default)]
    pub reasoning: Option<String>,
}
441
/// One edge as reported by the LLM. The confidence is not range-checked here;
/// `LlmDependencyResponse::into_graph` clamps it when building the graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmEdge {
    pub from: String,
    pub to: String,
    pub confidence: f64,
}
449
450impl LlmDependencyResponse {
451 pub fn into_graph(
453 self,
454 task: impl Into<String>,
455 available_actions: Vec<String>,
456 ) -> DependencyGraph {
457 let mut builder = DependencyGraphBuilder::new()
458 .task(task)
459 .available_actions(available_actions)
460 .start_nodes(self.start)
461 .terminal_nodes(self.terminal);
462
463 for edge in self.edges {
464 builder = builder.edge(edge.from, edge.to, edge.confidence);
465 }
466
467 builder.build()
468 }
469
470 pub fn parse(text: &str) -> Result<Self, DependencyGraphError> {
476 if let Some(response) = Self::parse_arrow_format(text) {
478 return Ok(response);
479 }
480
481 if let Ok(parsed) = serde_json::from_str(text) {
483 return Ok(parsed);
484 }
485
486 if let Some(json) = Self::extract_json(text) {
488 serde_json::from_str(&json).map_err(|e| DependencyGraphError::ParseError(e.to_string()))
489 } else {
490 Err(DependencyGraphError::ParseError(format!(
491 "No valid format found in response: {}",
492 text.chars().take(200).collect::<String>()
493 )))
494 }
495 }
496
497 fn parse_arrow_format(text: &str) -> Option<Self> {
504 if let Some(result) = Self::parse_arrow_only(text) {
506 return Some(result);
507 }
508
509 if let Some(result) = Self::parse_numbered_list(text) {
511 return Some(result);
512 }
513
514 None
515 }
516
517 fn parse_arrow_only(text: &str) -> Option<Self> {
519 let normalized = text.replace('→', "->");
520
521 let arrow_line = normalized.lines().find(|line| line.contains("->"))?;
523
524 let parts: Vec<&str> = arrow_line.split("->").collect();
525 if parts.len() < 2 {
526 return None;
527 }
528
529 let actions_in_order: Vec<String> = parts
531 .iter()
532 .filter_map(|part| {
533 let trimmed = part.trim();
534 let last_word = trimmed.split_whitespace().last()?;
536 let action: String = last_word.chars().filter(|c| c.is_alphabetic()).collect();
538 if action.is_empty() {
539 None
540 } else {
541 Some(action)
542 }
543 })
544 .collect();
545
546 if actions_in_order.len() < 2 {
547 return None;
548 }
549
550 Self::build_response(actions_in_order)
551 }
552
553 fn parse_numbered_list(text: &str) -> Option<Self> {
555 let mut actions_in_order: Vec<String> = Vec::new();
556
557 for i in 1..=10 {
559 let pattern = format!("{}.", i);
560 if let Some(pos) = text.find(&pattern) {
561 let after = &text[pos + pattern.len()..];
563 if let Some(word) = after.split_whitespace().next() {
564 let action: String = word.chars().filter(|c| c.is_alphabetic()).collect();
566 if !action.is_empty() && !actions_in_order.contains(&action) {
567 actions_in_order.push(action);
568 }
569 }
570 }
571 }
572
573 if actions_in_order.len() < 2 {
574 return None;
575 }
576
577 Self::build_response(actions_in_order)
578 }
579
580 fn build_response(actions_in_order: Vec<String>) -> Option<Self> {
582 let mut edges = Vec::new();
583 for window in actions_in_order.windows(2) {
584 edges.push(LlmEdge {
585 from: window[0].clone(),
586 to: window[1].clone(),
587 confidence: 0.9,
588 });
589 }
590
591 Some(Self {
592 edges,
593 start: vec![actions_in_order.first()?.clone()],
594 terminal: vec![actions_in_order.last()?.clone()],
595 reasoning: Some("Parsed from text format".to_string()),
596 })
597 }
598
599 fn extract_json(text: &str) -> Option<String> {
601 let start = text.find('{')?;
603 let chars: Vec<char> = text[start..].chars().collect();
604 let mut depth = 0;
605 let mut in_string = false;
606 let mut escape_next = false;
607
608 for (i, &ch) in chars.iter().enumerate() {
609 if escape_next {
610 escape_next = false;
611 continue;
612 }
613
614 match ch {
615 '\\' if in_string => escape_next = true,
616 '"' => in_string = !in_string,
617 '{' if !in_string => depth += 1,
618 '}' if !in_string => {
619 depth -= 1;
620 if depth == 0 {
621 return Some(chars[..=i].iter().collect());
622 }
623 }
624 _ => {}
625 }
626 }
627
628 None
629 }
630}
631
/// Strategy that produces a [`DependencyGraph`] for a task.
///
/// The `Send + Sync` bound allows planners to be shared across threads.
pub trait DependencyPlanner: Send + Sync {
    /// Plans the dependency graph for `task` using `available_actions`.
    ///
    /// # Errors
    /// Implementation-defined. For example, [`StaticDependencyPlanner`] returns
    /// [`DependencyGraphError::NoStartNodes`] when its linear fallback has no actions.
    fn plan(
        &self,
        task: &str,
        available_actions: &[String],
    ) -> Result<DependencyGraph, DependencyGraphError>;

    /// Human-readable planner name (e.g. `"StaticDependencyPlanner"`).
    fn name(&self) -> &str;
}
652
/// Planner that serves pre-registered graph patterns, falling back to a linear chain over
/// the available actions.
#[derive(Debug, Clone, Default)]
pub struct StaticDependencyPlanner {
    /// Registered patterns keyed by name.
    patterns: HashMap<String, DependencyGraph>,
    /// Name of the pattern `plan` uses; may name a pattern that was never registered.
    default_pattern: Option<String>,
}
664
665impl StaticDependencyPlanner {
666 pub fn new() -> Self {
667 Self::default()
668 }
669
670 pub fn with_pattern(mut self, name: impl Into<String>, graph: DependencyGraph) -> Self {
672 let name = name.into();
673 if self.default_pattern.is_none() {
674 self.default_pattern = Some(name.clone());
675 }
676 self.patterns.insert(name, graph);
677 self
678 }
679
680 pub fn with_default_pattern(mut self, name: impl Into<String>) -> Self {
682 self.default_pattern = Some(name.into());
683 self
684 }
685
686 pub fn with_file_exploration_pattern(self) -> Self {
690 let graph = DependencyGraph::builder()
691 .task("File exploration")
692 .available_actions(["Grep", "List", "Read"])
693 .edge("Grep", "Read", 0.95)
694 .edge("List", "Grep", 0.60)
695 .edge("List", "Read", 0.40)
696 .start_nodes(["Grep", "List"])
697 .terminal_node("Read")
698 .build();
699
700 self.with_pattern("file_exploration", graph)
701 }
702
703 pub fn with_code_search_pattern(self) -> Self {
707 let graph = DependencyGraph::builder()
708 .task("Code search")
709 .available_actions(["Grep", "Read"])
710 .edge("Grep", "Read", 0.95)
711 .start_node("Grep")
712 .terminal_node("Read")
713 .build();
714
715 self.with_pattern("code_search", graph)
716 }
717}
718
719impl DependencyPlanner for StaticDependencyPlanner {
720 fn plan(
721 &self,
722 task: &str,
723 available_actions: &[String],
724 ) -> Result<DependencyGraph, DependencyGraphError> {
725 if let Some(pattern_name) = &self.default_pattern {
727 if let Some(graph) = self.patterns.get(pattern_name) {
728 let mut graph = graph.clone();
729 graph.task = task.to_string();
730 graph.available_actions = available_actions.to_vec();
731 return Ok(graph);
732 }
733 }
734
735 if available_actions.is_empty() {
738 return Err(DependencyGraphError::NoStartNodes);
739 }
740
741 let mut builder = DependencyGraphBuilder::new()
742 .task(task)
743 .available_actions(available_actions.to_vec())
744 .start_node(&available_actions[0]);
745
746 if available_actions.len() > 1 {
747 for window in available_actions.windows(2) {
748 builder = builder.edge(&window[0], &window[1], 0.80);
749 }
750 builder = builder.terminal_node(&available_actions[available_actions.len() - 1]);
751 } else {
752 builder = builder.terminal_node(&available_actions[0]);
753 }
754
755 Ok(builder.build())
756 }
757
758 fn name(&self) -> &str {
759 "StaticDependencyPlanner"
760 }
761}
762
763use crate::actions::ActionDef;
768
769pub struct DependencyPromptGenerator;
774
775impl DependencyPromptGenerator {
776 pub fn generate_prompt(task: &str, actions: &[ActionDef]) -> String {
780 let actions_list = actions
781 .iter()
782 .map(|a| a.name.as_str())
783 .collect::<Vec<_>>()
784 .join(", ");
785
786 format!(
787 r#"{task}
788Steps: {actions_list}
789The very first step is:"#
790 )
791 }
792
793 pub fn generate_first_prompt(_task: &str, actions: &[ActionDef]) -> String {
798 let mut sorted_actions: Vec<&ActionDef> = actions.iter().collect();
800 sorted_actions.sort_by(|a, b| a.name.cmp(&b.name));
801
802 let actions_list = sorted_actions
803 .iter()
804 .map(|a| a.name.as_str())
805 .collect::<Vec<_>>()
806 .join(", ");
807
808 let descriptions: Vec<String> = sorted_actions
810 .iter()
811 .map(|a| format!("- {}: {}", a.name, a.description))
812 .collect();
813 let descriptions_block = descriptions.join("\n");
814
815 let first_verb = sorted_actions
817 .first()
818 .map(|a| Self::extract_verb(&a.description))
819 .unwrap_or_else(|| "CHECK".to_string());
820
821 format!(
822 r#"Steps: {actions_list}
823{descriptions_block}
824Which step {first_verb}S first?
825Answer:"#
826 )
827 }
828
829 pub fn generate_last_prompt(_task: &str, actions: &[ActionDef]) -> String {
834 let mut sorted_actions: Vec<&ActionDef> = actions.iter().collect();
836 sorted_actions.sort_by(|a, b| a.name.cmp(&b.name));
837
838 let actions_list = sorted_actions
839 .iter()
840 .map(|a| a.name.as_str())
841 .collect::<Vec<_>>()
842 .join(", ");
843
844 let descriptions: Vec<String> = sorted_actions
845 .iter()
846 .map(|a| format!("- {}: {}", a.name, a.description))
847 .collect();
848 let descriptions_block = descriptions.join("\n");
849
850 format!(
851 r#"Steps: {actions_list}
852{descriptions_block}
853Which step should be done last?
854Answer:"#
855 )
856 }
857
858 pub fn generate_pair_prompt(task: &str, action_a: &str, action_b: &str) -> String {
860 format!(
861 r#"For {task}, which comes first: {action_a} or {action_b}?
862Answer (one word):"#
863 )
864 }
865
866 fn extract_verb(description: &str) -> String {
871 description
872 .split_whitespace()
873 .next()
874 .map(|w| {
875 let word = w.trim_end_matches('s').trim_end_matches('S');
877 word.to_uppercase()
878 })
879 .unwrap_or_else(|| "CHECK".to_string())
880 }
881}
882
883#[derive(Debug, Clone)]
892pub struct GraphNavigator {
893 graph: DependencyGraph,
894 completed_actions: HashSet<String>,
896}
897
898impl GraphNavigator {
899 pub fn new(graph: DependencyGraph) -> Self {
900 Self {
901 graph,
902 completed_actions: HashSet::new(),
903 }
904 }
905
906 pub fn mark_completed(&mut self, action: &str) {
908 self.completed_actions.insert(action.to_string());
909 }
910
911 pub fn suggest_next(&self) -> Vec<String> {
917 if self.completed_actions.is_empty() {
918 return self.graph.start_actions();
920 }
921
922 let mut candidates = Vec::new();
924 for completed in &self.completed_actions {
925 for next in self.graph.valid_next_actions(completed) {
926 if !self.completed_actions.contains(&next) && !candidates.contains(&next) {
927 candidates.push(next);
928 }
929 }
930 }
931
932 candidates
933 }
934
935 pub fn is_task_complete(&self) -> bool {
939 self.graph
940 .terminal_actions()
941 .iter()
942 .any(|t| self.completed_actions.contains(t))
943 }
944
945 pub fn progress(&self) -> f64 {
947 if self.graph.available_actions.is_empty() {
948 return 0.0;
949 }
950 self.completed_actions.len() as f64 / self.graph.available_actions.len() as f64
951 }
952
953 pub fn graph(&self) -> &DependencyGraph {
955 &self.graph
956 }
957}
958
// Unit tests for graph construction, navigation, LLM-response parsing and Mermaid output.
#[cfg(test)]
mod tests {
    use super::*;

    // The builder records the task, start/terminal nodes and directed edges.
    #[test]
    fn test_dependency_graph_builder() {
        let graph = DependencyGraph::builder()
            .task("Find auth function")
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();

        assert_eq!(graph.task(), "Find auth function");
        assert!(graph.is_start("Grep"));
        assert!(graph.is_start("List"));
        assert!(graph.is_terminal("Read"));
        assert!(graph.can_transition("Grep", "Read"));
        // Edges are directed: the reverse transition is not implied.
        assert!(!graph.can_transition("Read", "Grep"));
    }

    // Next actions come back ordered by descending confidence.
    #[test]
    fn test_valid_next_actions() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .edge("List", "Read", 0.40)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();

        let next = graph.valid_next_actions("Grep");
        assert_eq!(next, vec!["Read"]);

        // 0.60 (Grep) ranks ahead of 0.40 (Read).
        let next = graph.valid_next_actions("List");
        assert_eq!(next, vec!["Grep", "Read"]);

        // A terminal node with no outgoing edges has no successors.
        let next = graph.valid_next_actions("Read");
        assert!(next.is_empty());
    }

    // The first registered pattern becomes the default and drives `plan`.
    #[test]
    fn test_static_planner_file_exploration() {
        let planner = StaticDependencyPlanner::new().with_file_exploration_pattern();

        let graph = planner
            .plan("Find auth.rs", &["Grep".to_string(), "Read".to_string()])
            .unwrap();

        assert!(graph.is_start("Grep"));
        assert!(graph.is_terminal("Read"));
    }

    // Walks a two-step graph: start suggestion, then successor, then completion.
    #[test]
    fn test_graph_navigator() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "Read"])
            .edge("Grep", "Read", 0.95)
            .start_node("Grep")
            .terminal_node("Read")
            .build();

        let mut nav = GraphNavigator::new(graph);

        // Nothing completed yet: suggestions are the start nodes.
        assert_eq!(nav.suggest_next(), vec!["Grep"]);
        assert!(!nav.is_task_complete());

        nav.mark_completed("Grep");
        assert_eq!(nav.suggest_next(), vec!["Read"]);
        assert!(!nav.is_task_complete());

        // Completing a terminal action finishes the task and leaves nothing to suggest.
        nav.mark_completed("Read");
        assert!(nav.is_task_complete());
        assert!(nav.suggest_next().is_empty());
    }

    // A pure JSON response parses field-for-field and converts into a graph.
    #[test]
    fn test_llm_response_parsing() {
        let json = r#"{
            "edges": [
                {"from": "Grep", "to": "Read", "confidence": 0.95}
            ],
            "start": ["Grep"],
            "terminal": ["Read"],
            "reasoning": "Search first, then read"
        }"#;

        let response = LlmDependencyResponse::parse(json).unwrap();
        assert_eq!(response.edges.len(), 1);
        assert_eq!(response.start, vec!["Grep"]);
        assert_eq!(response.terminal, vec!["Read"]);
        assert!(response.reasoning.is_some());

        let graph = response.into_graph(
            "Find function",
            vec!["Grep".to_string(), "Read".to_string()],
        );
        assert!(graph.can_transition("Grep", "Read"));
    }

    // Mermaid output has the header, percentage edge labels and terminal styling.
    #[test]
    fn test_mermaid_output() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();

        let mermaid = graph.to_mermaid();
        assert!(mermaid.contains("graph LR"));
        assert!(mermaid.contains("Grep -->|95%| Read"));
        assert!(mermaid.contains("style Read fill:#f99"));
    }
}