use anyhow::{Context, Result};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use std::collections::{HashMap, HashSet};
use std::time::Duration;

use super::dot_parser::{AttrValue, DotGraph};
10
/// A pipeline definition built from a parsed DOT graph (see `PipelineGraph::from_dot`).
#[derive(Debug)]
pub struct PipelineGraph {
    /// Name of the source DOT graph.
    pub name: String,
    /// Recognised graph-level attributes; anything unrecognised lands in `GraphAttrs::extra`.
    pub graph_attrs: GraphAttrs,
    /// Directed graph holding one `PipelineNode` per DOT node id and one `PipelineEdge` per
    /// DOT edge.
    pub graph: DiGraph<PipelineNode, PipelineEdge>,
    /// Maps each DOT node id to its index in `graph`.
    pub node_index: HashMap<String, NodeIndex>,
    /// A node whose `handler_type` is `"start"` (shape `Mdiamond` unless overridden by `type`).
    pub start_node: NodeIndex,
    /// A node whose `handler_type` is `"exit"` (shape `Msquare` unless overridden by `type`).
    pub exit_node: NodeIndex,
}
21
/// Graph-level DOT attributes (e.g. those set with `graph [...]`).
#[derive(Debug, Clone, Default)]
pub struct GraphAttrs {
    /// The `goal` attribute, as text.
    pub goal: Option<String>,
    /// The `fidelity` attribute, when it names a recognised `FidelityMode`; unrecognised
    /// values are dropped (`None`).
    pub fidelity: Option<FidelityMode>,
    /// The `model_stylesheet` attribute, kept as raw text.
    pub model_stylesheet: Option<String>,
    /// Every other graph attribute, stringified with `AttrValue::as_str`.
    pub extra: HashMap<String, String>,
}
30
/// One step of the pipeline, built from a DOT node statement merged over the `node [...]`
/// defaults.
#[derive(Debug, Clone)]
pub struct PipelineNode {
    /// The DOT node id.
    pub id: String,
    /// The `label` attribute; `from_dot` falls back to the node id when it is absent.
    pub label: String,
    /// The DOT `shape` attribute (`"box"` when absent).
    pub shape: String,
    /// Handler name: the explicit `type` attribute if present, otherwise derived from `shape`.
    pub handler_type: String,
    /// The `prompt` attribute; empty when absent.
    pub prompt: String,
    /// The `max_retries` attribute; 0 when absent.
    pub max_retries: u32,
    /// The `goal_gate` flag; `false` when absent.
    pub goal_gate: bool,
    /// The `retry_target` attribute (presumably a node id; verify against the executor).
    pub retry_target: Option<String>,
    /// The `fallback_retry_target` attribute (presumably a node id; verify against the executor).
    pub fallback_retry_target: Option<String>,
    /// The `fidelity` attribute, when it names a recognised `FidelityMode`.
    pub fidelity: Option<FidelityMode>,
    /// The `thread_id` attribute.
    pub thread_id: Option<String>,
    /// The whitespace-separated `class` attribute, split into individual class names.
    pub classes: Vec<String>,
    /// The `timeout` attribute; only populated when the parser produced `AttrValue::Duration`.
    pub timeout: Option<Duration>,
    /// The `llm_model` attribute.
    pub llm_model: Option<String>,
    /// The `llm_provider` attribute.
    pub llm_provider: Option<String>,
    /// The `reasoning_effort` attribute; `"high"` when absent.
    pub reasoning_effort: String,
    /// The `auto_status` flag; `true` when absent.
    pub auto_status: bool,
    /// The `allow_partial` flag; `false` when absent.
    pub allow_partial: bool,
    /// Every attribute not mapped to one of the fields above, kept verbatim.
    pub extra_attrs: HashMap<String, AttrValue>,
}
54
55impl Default for PipelineNode {
56 fn default() -> Self {
57 Self {
58 id: String::new(),
59 label: String::new(),
60 shape: "box".into(),
61 handler_type: "codergen".into(),
62 prompt: String::new(),
63 max_retries: 0,
64 goal_gate: false,
65 retry_target: None,
66 fallback_retry_target: None,
67 fidelity: None,
68 thread_id: None,
69 classes: vec![],
70 timeout: None,
71 llm_model: None,
72 llm_provider: None,
73 reasoning_effort: "high".into(),
74 auto_status: true,
75 allow_partial: false,
76 extra_attrs: HashMap::new(),
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct PipelineEdge {
84 pub label: String,
85 pub condition: String,
86 pub weight: i32,
87 pub fidelity: Option<FidelityMode>,
88 pub thread_id: Option<String>,
89 pub loop_restart: bool,
90}
91
92impl Default for PipelineEdge {
93 fn default() -> Self {
94 Self {
95 label: String::new(),
96 condition: String::new(),
97 weight: 0,
98 fidelity: None,
99 thread_id: None,
100 loop_restart: false,
101 }
102 }
103}
104
/// Fidelity setting parsed from a `fidelity` attribute (see `FidelityMode::from_str`).
#[derive(Debug, Clone, PartialEq)]
pub enum FidelityMode {
    /// Parsed from `full`.
    Full,
    /// Parsed from `truncate`.
    Truncate,
    /// Parsed from `compact`.
    Compact,
    /// Parsed from `summary` / `summary-low` / `summary-medium` / `summary-high`.
    Summary(SummaryLevel),
}

/// Detail level attached to `FidelityMode::Summary`.
#[derive(Debug, Clone, PartialEq)]
pub enum SummaryLevel {
    Low,
    Medium,
    High,
}

impl FidelityMode {
    /// Parses a fidelity name, ignoring case.
    ///
    /// Accepts `full`, `truncate`, `compact`, `summary-low`, `summary-medium`,
    /// `summary-high`, and bare `summary` (same as `summary-medium`). Any other input,
    /// including input with surrounding whitespace, yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        let name = s.to_lowercase();
        let mode = match name.as_str() {
            "full" => Self::Full,
            "truncate" => Self::Truncate,
            "compact" => Self::Compact,
            "summary-low" => Self::Summary(SummaryLevel::Low),
            "summary" | "summary-medium" => Self::Summary(SummaryLevel::Medium),
            "summary-high" => Self::Summary(SummaryLevel::High),
            _ => return None,
        };
        Some(mode)
    }
}
135
/// Maps a DOT node `shape` to the name of the handler that runs the node.
///
/// Matching ignores case (`Mdiamond`, `MDIAMOND` and `mdiamond` are equivalent). Shapes not
/// listed below fall back to `"codergen"`, the same handler used for `box`.
///
/// The result is always one of the string literals below, so it is `&'static str` and does
/// not borrow from `shape`. Callers may drop `shape` while keeping the result.
fn handler_type_from_shape(shape: &str) -> &'static str {
    match shape.to_lowercase().as_str() {
        "mdiamond" => "start",
        "msquare" => "exit",
        "box" | "rect" | "rectangle" => "codergen",
        "hexagon" => "wait.human",
        "diamond" => "conditional",
        "component" => "parallel",
        "tripleoctagon" => "parallel.fan_in",
        "parallelogram" => "tool",
        "house" => "stack.manager_loop",
        // Unrecognised shapes (ellipse, circle, ...) use the default handler.
        _ => "codergen",
    }
}
151
152impl PipelineGraph {
153 pub fn from_dot(dot: &DotGraph) -> Result<Self> {
155 let mut graph = DiGraph::new();
156 let mut node_index = HashMap::new();
157
158 let graph_attrs = GraphAttrs {
160 goal: dot
161 .graph_attrs
162 .get("goal")
163 .map(|v| v.as_str()),
164 fidelity: dot
165 .graph_attrs
166 .get("fidelity")
167 .and_then(|v| FidelityMode::from_str(&v.as_str())),
168 model_stylesheet: dot
169 .graph_attrs
170 .get("model_stylesheet")
171 .map(|v| v.as_str()),
172 extra: dot
173 .graph_attrs
174 .iter()
175 .filter(|(k, _)| !["goal", "fidelity", "model_stylesheet"].contains(&k.as_str()))
176 .map(|(k, v)| (k.clone(), v.as_str()))
177 .collect(),
178 };
179
180 let mut all_node_ids: Vec<String> = Vec::new();
182 for node in &dot.nodes {
183 if !all_node_ids.contains(&node.id) {
184 all_node_ids.push(node.id.clone());
185 }
186 }
187 for edge in &dot.edges {
188 if !all_node_ids.contains(&edge.from) {
189 all_node_ids.push(edge.from.clone());
190 }
191 if !all_node_ids.contains(&edge.to) {
192 all_node_ids.push(edge.to.clone());
193 }
194 }
195 for sg in &dot.subgraphs {
196 for node in &sg.nodes {
197 if !all_node_ids.contains(&node.id) {
198 all_node_ids.push(node.id.clone());
199 }
200 }
201 for edge in &sg.edges {
202 if !all_node_ids.contains(&edge.from) {
203 all_node_ids.push(edge.from.clone());
204 }
205 if !all_node_ids.contains(&edge.to) {
206 all_node_ids.push(edge.to.clone());
207 }
208 }
209 }
210
211 let mut node_attrs_map: HashMap<String, HashMap<String, AttrValue>> = HashMap::new();
213 for node in &dot.nodes {
214 node_attrs_map.insert(node.id.clone(), node.attrs.clone());
215 }
216 for sg in &dot.subgraphs {
217 for node in &sg.nodes {
218 node_attrs_map.insert(node.id.clone(), node.attrs.clone());
219 }
220 }
221
222 for id in &all_node_ids {
224 let attrs = node_attrs_map.get(id).cloned().unwrap_or_default();
225 let merged_attrs = merge_with_defaults(&attrs, &dot.node_defaults);
226 let pipeline_node = build_pipeline_node(id, &merged_attrs);
227 let idx = graph.add_node(pipeline_node);
228 node_index.insert(id.clone(), idx);
229 }
230
231 let all_edges: Vec<_> = dot
233 .edges
234 .iter()
235 .chain(dot.subgraphs.iter().flat_map(|sg| sg.edges.iter()))
236 .collect();
237
238 for edge in all_edges {
239 let from_idx = *node_index
240 .get(&edge.from)
241 .context(format!("Edge source '{}' not found", edge.from))?;
242 let to_idx = *node_index
243 .get(&edge.to)
244 .context(format!("Edge target '{}' not found", edge.to))?;
245
246 let merged = merge_with_defaults(&edge.attrs, &dot.edge_defaults);
247 let pipeline_edge = build_pipeline_edge(&merged);
248 graph.add_edge(from_idx, to_idx, pipeline_edge);
249 }
250
251 let start_node = find_node_by_handler(&graph, &node_index, "start")
253 .context("No start node found (need a node with shape=Mdiamond)")?;
254 let exit_node = find_node_by_handler(&graph, &node_index, "exit")
255 .context("No exit node found (need a node with shape=Msquare)")?;
256
257 Ok(PipelineGraph {
258 name: dot.name.clone(),
259 graph_attrs,
260 graph,
261 node_index,
262 start_node,
263 exit_node,
264 })
265 }
266
267 pub fn node(&self, id: &str) -> Option<&PipelineNode> {
269 self.node_index
270 .get(id)
271 .map(|idx| &self.graph[*idx])
272 }
273
274 pub fn outgoing_edges(&self, idx: NodeIndex) -> Vec<(NodeIndex, &PipelineEdge)> {
276 self.graph
277 .edges(idx)
278 .map(|e| (e.target(), e.weight()))
279 .collect()
280 }
281
282 pub fn topo_order(&self) -> Result<Vec<NodeIndex>> {
284 petgraph::algo::toposort(&self.graph, None)
285 .map_err(|_| anyhow::anyhow!("Pipeline graph contains a cycle"))
286 }
287}
288
289fn merge_with_defaults(
290 attrs: &HashMap<String, AttrValue>,
291 defaults: &HashMap<String, AttrValue>,
292) -> HashMap<String, AttrValue> {
293 let mut merged = defaults.clone();
294 for (k, v) in attrs {
295 merged.insert(k.clone(), v.clone());
296 }
297 merged
298}
299
300fn build_pipeline_node(id: &str, attrs: &HashMap<String, AttrValue>) -> PipelineNode {
301 let shape = attrs
302 .get("shape")
303 .map(|v| v.as_str())
304 .unwrap_or_else(|| "box".into());
305
306 let explicit_type = attrs.get("type").map(|v| v.as_str());
307 let handler_type = explicit_type
308 .unwrap_or_else(|| handler_type_from_shape(&shape).into());
309
310 let label = attrs
311 .get("label")
312 .map(|v| v.as_str())
313 .unwrap_or_else(|| id.to_string());
314
315 let classes = attrs
316 .get("class")
317 .map(|v| {
318 v.as_str()
319 .split_whitespace()
320 .map(String::from)
321 .collect()
322 })
323 .unwrap_or_default();
324
325 let mut extra_attrs = HashMap::new();
326 let known_keys = [
327 "shape", "type", "label", "prompt", "max_retries", "goal_gate",
328 "retry_target", "fallback_retry_target", "fidelity", "thread_id",
329 "class", "timeout", "llm_model", "llm_provider", "reasoning_effort",
330 "auto_status", "allow_partial",
331 ];
332 for (k, v) in attrs {
333 if !known_keys.contains(&k.as_str()) {
334 extra_attrs.insert(k.clone(), v.clone());
335 }
336 }
337
338 PipelineNode {
339 id: id.to_string(),
340 label,
341 shape,
342 handler_type,
343 prompt: attrs.get("prompt").map(|v| v.as_str()).unwrap_or_default(),
344 max_retries: attrs
345 .get("max_retries")
346 .and_then(|v| v.as_int())
347 .unwrap_or(0) as u32,
348 goal_gate: attrs
349 .get("goal_gate")
350 .and_then(|v| v.as_bool())
351 .unwrap_or(false),
352 retry_target: attrs.get("retry_target").map(|v| v.as_str()),
353 fallback_retry_target: attrs.get("fallback_retry_target").map(|v| v.as_str()),
354 fidelity: attrs
355 .get("fidelity")
356 .and_then(|v| FidelityMode::from_str(&v.as_str())),
357 thread_id: attrs.get("thread_id").map(|v| v.as_str()),
358 classes,
359 timeout: attrs.get("timeout").and_then(|v| match v {
360 AttrValue::Duration(d) => Some(*d),
361 _ => None,
362 }),
363 llm_model: attrs.get("llm_model").map(|v| v.as_str()),
364 llm_provider: attrs.get("llm_provider").map(|v| v.as_str()),
365 reasoning_effort: attrs
366 .get("reasoning_effort")
367 .map(|v| v.as_str())
368 .unwrap_or_else(|| "high".into()),
369 auto_status: attrs
370 .get("auto_status")
371 .and_then(|v| v.as_bool())
372 .unwrap_or(true),
373 allow_partial: attrs
374 .get("allow_partial")
375 .and_then(|v| v.as_bool())
376 .unwrap_or(false),
377 extra_attrs,
378 }
379}
380
381fn build_pipeline_edge(attrs: &HashMap<String, AttrValue>) -> PipelineEdge {
382 PipelineEdge {
383 label: attrs.get("label").map(|v| v.as_str()).unwrap_or_default(),
384 condition: attrs.get("condition").map(|v| v.as_str()).unwrap_or_default(),
385 weight: attrs
386 .get("weight")
387 .and_then(|v| v.as_int())
388 .unwrap_or(0) as i32,
389 fidelity: attrs
390 .get("fidelity")
391 .and_then(|v| FidelityMode::from_str(&v.as_str())),
392 thread_id: attrs.get("thread_id").map(|v| v.as_str()),
393 loop_restart: attrs
394 .get("loop_restart")
395 .and_then(|v| v.as_bool())
396 .unwrap_or(false),
397 }
398}
399
400fn find_node_by_handler(
401 graph: &DiGraph<PipelineNode, PipelineEdge>,
402 node_index: &HashMap<String, NodeIndex>,
403 handler: &str,
404) -> Option<NodeIndex> {
405 node_index
406 .values()
407 .copied()
408 .find(|idx| graph[*idx].handler_type == handler)
409}
410
// Unit tests for building a `PipelineGraph` from DOT text via `parse_dot`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attractor::dot_parser::parse_dot;

    // End-to-end: graph name, `goal` graph attribute, node/edge counts (a chained
    // `a -> b -> c` statement yields two edges), start/exit detection and node attributes.
    #[test]
    fn test_build_simple_pipeline() {
        let input = r#"
            digraph pipeline {
                graph [goal="Build feature X"]
                start [shape=Mdiamond]
                task_a [shape=box, label="Implement A", prompt="Write the code for A"]
                finish [shape=Msquare]
                start -> task_a -> finish
            }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();

        assert_eq!(pipeline.name, "pipeline");
        assert_eq!(pipeline.graph_attrs.goal, Some("Build feature X".into()));
        assert_eq!(pipeline.graph.node_count(), 3);
        assert_eq!(pipeline.graph.edge_count(), 2);

        let start = &pipeline.graph[pipeline.start_node];
        assert_eq!(start.handler_type, "start");

        let exit = &pipeline.graph[pipeline.exit_node];
        assert_eq!(exit.handler_type, "exit");

        let task = pipeline.node("task_a").unwrap();
        assert_eq!(task.handler_type, "codergen");
        assert_eq!(task.prompt, "Write the code for A");
    }

    // Every recognised shape maps to its handler; note mixed-case input ("Mdiamond").
    #[test]
    fn test_shape_to_handler_mapping() {
        assert_eq!(handler_type_from_shape("Mdiamond"), "start");
        assert_eq!(handler_type_from_shape("Msquare"), "exit");
        assert_eq!(handler_type_from_shape("box"), "codergen");
        assert_eq!(handler_type_from_shape("hexagon"), "wait.human");
        assert_eq!(handler_type_from_shape("diamond"), "conditional");
        assert_eq!(handler_type_from_shape("component"), "parallel");
        assert_eq!(handler_type_from_shape("tripleoctagon"), "parallel.fan_in");
        assert_eq!(handler_type_from_shape("parallelogram"), "tool");
        assert_eq!(handler_type_from_shape("house"), "stack.manager_loop");
    }

    // A node with two outgoing edges reports both from `outgoing_edges`.
    #[test]
    fn test_outgoing_edges() {
        let input = r#"
            digraph test {
                start [shape=Mdiamond]
                a [shape=box]
                b [shape=box]
                finish [shape=Msquare]
                start -> a [label="go"]
                start -> b [label="alt"]
                a -> finish
                b -> finish
            }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();

        let edges = pipeline.outgoing_edges(pipeline.start_node);
        assert_eq!(edges.len(), 2);
    }

    // Attributes from a `node [...]` default statement reach nodes that don't set them.
    #[test]
    fn test_node_defaults_applied() {
        let input = r#"
            digraph test {
                node [reasoning_effort="medium"]
                start [shape=Mdiamond]
                a [shape=box]
                finish [shape=Msquare]
                start -> a -> finish
            }
        "#;
        let dot = parse_dot(input).unwrap();
        let pipeline = PipelineGraph::from_dot(&dot).unwrap();
        let a = pipeline.node("a").unwrap();
        assert_eq!(a.reasoning_effort, "medium");
    }
}