1use anyhow::{Context, Result};
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::EdgeRef;
6use std::collections::HashMap;
7use std::time::Duration;
8
9use super::dot_parser::{AttrValue, DotGraph};
10
/// A pipeline definition built from a parsed DOT graph (see `PipelineGraph::from_dot`).
#[derive(Debug)]
pub struct PipelineGraph {
    /// Name of the source DOT graph.
    pub name: String,
    /// Graph-level attributes (`goal`, `fidelity`, `model_stylesheet`, everything else).
    pub graph_attrs: GraphAttrs,
    /// The directed graph itself; node weights are `PipelineNode`s, edge weights are
    /// `PipelineEdge`s.
    pub graph: DiGraph<PipelineNode, PipelineEdge>,
    /// Maps each DOT node id to its index in `graph`.
    pub node_index: HashMap<String, NodeIndex>,
    /// Index of the node whose `handler_type` is `"start"`.
    pub start_node: NodeIndex,
    /// Index of the node whose `handler_type` is `"exit"`.
    pub exit_node: NodeIndex,
}
21
/// Graph-level attributes of a pipeline, taken from the DOT graph's attribute list.
#[derive(Debug, Clone, Default)]
pub struct GraphAttrs {
    /// The `goal` attribute, if present.
    pub goal: Option<String>,
    /// The `fidelity` attribute, parsed; `None` if absent or not a recognised mode.
    pub fidelity: Option<FidelityMode>,
    /// The `model_stylesheet` attribute, if present.
    pub model_stylesheet: Option<String>,
    /// All other graph attributes, stringified.
    pub extra: HashMap<String, String>,
}
30
/// One executable step of a pipeline, built from a DOT node and its attributes.
///
/// The defaults noted below are the ones applied when a node is built from DOT and
/// match `PipelineNode::default()`.
#[derive(Debug, Clone)]
pub struct PipelineNode {
    /// The DOT node id.
    pub id: String,
    /// The `label` attribute; defaults to the node id.
    pub label: String,
    /// The `shape` attribute; defaults to `"box"`.
    pub shape: String,
    /// Handler name: the explicit `type` attribute if given, otherwise derived from `shape`.
    pub handler_type: String,
    /// The `prompt` attribute; empty when absent.
    pub prompt: String,
    /// The `max_retries` attribute; defaults to 0.
    pub max_retries: u32,
    /// The `goal_gate` attribute; defaults to `false`.
    pub goal_gate: bool,
    /// The `retry_target` attribute (presumably a node id — confirm with the executor).
    pub retry_target: Option<String>,
    /// The `fallback_retry_target` attribute (presumably a node id — confirm).
    pub fallback_retry_target: Option<String>,
    /// The `fidelity` attribute, parsed; `None` if absent or unrecognised.
    pub fidelity: Option<FidelityMode>,
    /// The `thread_id` attribute, if present.
    pub thread_id: Option<String>,
    /// The `class` attribute split on whitespace.
    pub classes: Vec<String>,
    /// The `timeout` attribute; set only when it was parsed as a duration value.
    pub timeout: Option<Duration>,
    /// The `llm_model` attribute, if present.
    pub llm_model: Option<String>,
    /// The `llm_provider` attribute, if present.
    pub llm_provider: Option<String>,
    /// The `reasoning_effort` attribute; defaults to `"high"`.
    pub reasoning_effort: String,
    /// The `auto_status` attribute; defaults to `true`.
    pub auto_status: bool,
    /// The `allow_partial` attribute; defaults to `false`.
    pub allow_partial: bool,
    /// Every attribute that has no dedicated field above.
    pub extra_attrs: HashMap<String, AttrValue>,
}
54
55impl Default for PipelineNode {
56 fn default() -> Self {
57 Self {
58 id: String::new(),
59 label: String::new(),
60 shape: "box".into(),
61 handler_type: "codergen".into(),
62 prompt: String::new(),
63 max_retries: 0,
64 goal_gate: false,
65 retry_target: None,
66 fallback_retry_target: None,
67 fidelity: None,
68 thread_id: None,
69 classes: vec![],
70 timeout: None,
71 llm_model: None,
72 llm_provider: None,
73 reasoning_effort: "high".into(),
74 auto_status: true,
75 allow_partial: false,
76 extra_attrs: HashMap::new(),
77 }
78 }
79}
80
/// A transition between two pipeline nodes, built from a DOT edge and its attributes.
#[derive(Debug, Clone)]
pub struct PipelineEdge {
    /// The `label` attribute; empty when absent.
    pub label: String,
    /// The raw `condition` attribute; empty when absent.
    pub condition: String,
    /// The `weight` attribute; defaults to 0.
    pub weight: i32,
    /// The `fidelity` attribute, parsed; `None` if absent or unrecognised.
    pub fidelity: Option<FidelityMode>,
    /// The `thread_id` attribute, if present.
    pub thread_id: Option<String>,
    /// The `loop_restart` attribute; defaults to `false`.
    pub loop_restart: bool,
}
91
92impl Default for PipelineEdge {
93 fn default() -> Self {
94 Self {
95 label: String::new(),
96 condition: String::new(),
97 weight: 0,
98 fidelity: None,
99 thread_id: None,
100 loop_restart: false,
101 }
102 }
103}
104
/// Fidelity setting parsed from a `fidelity` attribute on the graph, a node, or an edge.
///
/// Accepted spellings (case-insensitive) are listed on `FidelityMode::from_str`.
/// NOTE(review): what each mode does is decided by the pipeline executor, not here.
#[derive(Debug, Clone, PartialEq)]
pub enum FidelityMode {
    /// Parsed from `full`.
    Full,
    /// Parsed from `truncate`.
    Truncate,
    /// Parsed from `compact`.
    Compact,
    /// Parsed from `summary` (medium) or `summary-low` / `summary-medium` / `summary-high`.
    Summary(SummaryLevel),
}
113
/// Detail level carried by `FidelityMode::Summary`; a bare `summary` parses as `Medium`.
#[derive(Debug, Clone, PartialEq)]
pub enum SummaryLevel {
    Low,
    Medium,
    High,
}
121
122impl FidelityMode {
123 pub fn from_str(s: &str) -> Option<Self> {
124 match s.to_lowercase().as_str() {
125 "full" => Some(FidelityMode::Full),
126 "truncate" => Some(FidelityMode::Truncate),
127 "compact" => Some(FidelityMode::Compact),
128 "summary" | "summary-medium" => Some(FidelityMode::Summary(SummaryLevel::Medium)),
129 "summary-low" => Some(FidelityMode::Summary(SummaryLevel::Low)),
130 "summary-high" => Some(FidelityMode::Summary(SummaryLevel::High)),
131 _ => None,
132 }
133 }
134}
135
/// Maps a DOT node `shape` (case-insensitive) to the name of the handler that runs it.
///
/// Unrecognised shapes fall back to `"codergen"`, the same handler as `box`.
///
/// The result is `&'static str`: every value is a string literal, so the result does
/// not borrow from `shape` and may outlive it.
fn handler_type_from_shape(shape: &str) -> &'static str {
    match shape.to_lowercase().as_str() {
        "mdiamond" => "start",
        "msquare" => "exit",
        "box" | "rect" | "rectangle" => "codergen",
        "hexagon" => "wait.human",
        "diamond" => "conditional",
        "component" => "parallel",
        "tripleoctagon" => "parallel.fan_in",
        "parallelogram" => "tool",
        "house" => "stack.manager_loop",
        // Unknown shapes are treated like plain boxes.
        _ => "codergen",
    }
}
151
152impl PipelineGraph {
153 pub fn from_dot(dot: &DotGraph) -> Result<Self> {
155 let mut graph = DiGraph::new();
156 let mut node_index = HashMap::new();
157
158 let graph_attrs = GraphAttrs {
160 goal: dot.graph_attrs.get("goal").map(|v| v.as_str()),
161 fidelity: dot
162 .graph_attrs
163 .get("fidelity")
164 .and_then(|v| FidelityMode::from_str(&v.as_str())),
165 model_stylesheet: dot.graph_attrs.get("model_stylesheet").map(|v| v.as_str()),
166 extra: dot
167 .graph_attrs
168 .iter()
169 .filter(|(k, _)| !["goal", "fidelity", "model_stylesheet"].contains(&k.as_str()))
170 .map(|(k, v)| (k.clone(), v.as_str()))
171 .collect(),
172 };
173
174 let mut all_node_ids: Vec<String> = Vec::new();
176 for node in &dot.nodes {
177 if !all_node_ids.contains(&node.id) {
178 all_node_ids.push(node.id.clone());
179 }
180 }
181 for edge in &dot.edges {
182 if !all_node_ids.contains(&edge.from) {
183 all_node_ids.push(edge.from.clone());
184 }
185 if !all_node_ids.contains(&edge.to) {
186 all_node_ids.push(edge.to.clone());
187 }
188 }
189 for sg in &dot.subgraphs {
190 for node in &sg.nodes {
191 if !all_node_ids.contains(&node.id) {
192 all_node_ids.push(node.id.clone());
193 }
194 }
195 for edge in &sg.edges {
196 if !all_node_ids.contains(&edge.from) {
197 all_node_ids.push(edge.from.clone());
198 }
199 if !all_node_ids.contains(&edge.to) {
200 all_node_ids.push(edge.to.clone());
201 }
202 }
203 }
204
205 let mut node_attrs_map: HashMap<String, HashMap<String, AttrValue>> = HashMap::new();
207 for node in &dot.nodes {
208 node_attrs_map.insert(node.id.clone(), node.attrs.clone());
209 }
210 for sg in &dot.subgraphs {
211 for node in &sg.nodes {
212 node_attrs_map.insert(node.id.clone(), node.attrs.clone());
213 }
214 }
215
216 for id in &all_node_ids {
218 let attrs = node_attrs_map.get(id).cloned().unwrap_or_default();
219 let merged_attrs = merge_with_defaults(&attrs, &dot.node_defaults);
220 let pipeline_node = build_pipeline_node(id, &merged_attrs);
221 let idx = graph.add_node(pipeline_node);
222 node_index.insert(id.clone(), idx);
223 }
224
225 let all_edges: Vec<_> = dot
227 .edges
228 .iter()
229 .chain(dot.subgraphs.iter().flat_map(|sg| sg.edges.iter()))
230 .collect();
231
232 for edge in all_edges {
233 let from_idx = *node_index
234 .get(&edge.from)
235 .context(format!("Edge source '{}' not found", edge.from))?;
236 let to_idx = *node_index
237 .get(&edge.to)
238 .context(format!("Edge target '{}' not found", edge.to))?;
239
240 let merged = merge_with_defaults(&edge.attrs, &dot.edge_defaults);
241 let pipeline_edge = build_pipeline_edge(&merged);
242 graph.add_edge(from_idx, to_idx, pipeline_edge);
243 }
244
245 let start_node = find_node_by_handler(&graph, &node_index, "start")
247 .context("No start node found (need a node with shape=Mdiamond)")?;
248 let exit_node = find_node_by_handler(&graph, &node_index, "exit")
249 .context("No exit node found (need a node with shape=Msquare)")?;
250
251 Ok(PipelineGraph {
252 name: dot.name.clone(),
253 graph_attrs,
254 graph,
255 node_index,
256 start_node,
257 exit_node,
258 })
259 }
260
261 pub fn node(&self, id: &str) -> Option<&PipelineNode> {
263 self.node_index.get(id).map(|idx| &self.graph[*idx])
264 }
265
266 pub fn outgoing_edges(&self, idx: NodeIndex) -> Vec<(NodeIndex, &PipelineEdge)> {
268 self.graph
269 .edges(idx)
270 .map(|e| (e.target(), e.weight()))
271 .collect()
272 }
273
274 pub fn topo_order(&self) -> Result<Vec<NodeIndex>> {
276 petgraph::algo::toposort(&self.graph, None)
277 .map_err(|_| anyhow::anyhow!("Pipeline graph contains a cycle"))
278 }
279}
280
/// Overlays `attrs` on top of `defaults` and returns the combined map.
///
/// Every key from either input appears in the result. Where both contain a key, the
/// value from `attrs` wins. Neither input is modified.
///
/// Generic over the value type, so any cloneable attribute value works (this module
/// uses it with `AttrValue`).
fn merge_with_defaults<V: Clone>(
    attrs: &HashMap<String, V>,
    defaults: &HashMap<String, V>,
) -> HashMap<String, V> {
    let mut merged = HashMap::with_capacity(defaults.len() + attrs.len());
    merged.extend(defaults.iter().map(|(k, v)| (k.clone(), v.clone())));
    // Inserted second so explicit attributes override defaults.
    merged.extend(attrs.iter().map(|(k, v)| (k.clone(), v.clone())));
    merged
}
291
292fn build_pipeline_node(id: &str, attrs: &HashMap<String, AttrValue>) -> PipelineNode {
293 let shape = attrs
294 .get("shape")
295 .map(|v| v.as_str())
296 .unwrap_or_else(|| "box".into());
297
298 let explicit_type = attrs.get("type").map(|v| v.as_str());
299 let handler_type = explicit_type.unwrap_or_else(|| handler_type_from_shape(&shape).into());
300
301 let label = attrs
302 .get("label")
303 .map(|v| v.as_str())
304 .unwrap_or_else(|| id.to_string());
305
306 let classes = attrs
307 .get("class")
308 .map(|v| v.as_str().split_whitespace().map(String::from).collect())
309 .unwrap_or_default();
310
311 let mut extra_attrs = HashMap::new();
312 let known_keys = [
313 "shape",
314 "type",
315 "label",
316 "prompt",
317 "max_retries",
318 "goal_gate",
319 "retry_target",
320 "fallback_retry_target",
321 "fidelity",
322 "thread_id",
323 "class",
324 "timeout",
325 "llm_model",
326 "llm_provider",
327 "reasoning_effort",
328 "auto_status",
329 "allow_partial",
330 ];
331 for (k, v) in attrs {
332 if !known_keys.contains(&k.as_str()) {
333 extra_attrs.insert(k.clone(), v.clone());
334 }
335 }
336
337 PipelineNode {
338 id: id.to_string(),
339 label,
340 shape,
341 handler_type,
342 prompt: attrs.get("prompt").map(|v| v.as_str()).unwrap_or_default(),
343 max_retries: attrs
344 .get("max_retries")
345 .and_then(|v| v.as_int())
346 .unwrap_or(0) as u32,
347 goal_gate: attrs
348 .get("goal_gate")
349 .and_then(|v| v.as_bool())
350 .unwrap_or(false),
351 retry_target: attrs.get("retry_target").map(|v| v.as_str()),
352 fallback_retry_target: attrs.get("fallback_retry_target").map(|v| v.as_str()),
353 fidelity: attrs
354 .get("fidelity")
355 .and_then(|v| FidelityMode::from_str(&v.as_str())),
356 thread_id: attrs.get("thread_id").map(|v| v.as_str()),
357 classes,
358 timeout: attrs.get("timeout").and_then(|v| match v {
359 AttrValue::Duration(d) => Some(*d),
360 _ => None,
361 }),
362 llm_model: attrs.get("llm_model").map(|v| v.as_str()),
363 llm_provider: attrs.get("llm_provider").map(|v| v.as_str()),
364 reasoning_effort: attrs
365 .get("reasoning_effort")
366 .map(|v| v.as_str())
367 .unwrap_or_else(|| "high".into()),
368 auto_status: attrs
369 .get("auto_status")
370 .and_then(|v| v.as_bool())
371 .unwrap_or(true),
372 allow_partial: attrs
373 .get("allow_partial")
374 .and_then(|v| v.as_bool())
375 .unwrap_or(false),
376 extra_attrs,
377 }
378}
379
380fn build_pipeline_edge(attrs: &HashMap<String, AttrValue>) -> PipelineEdge {
381 PipelineEdge {
382 label: attrs.get("label").map(|v| v.as_str()).unwrap_or_default(),
383 condition: attrs
384 .get("condition")
385 .map(|v| v.as_str())
386 .unwrap_or_default(),
387 weight: attrs.get("weight").and_then(|v| v.as_int()).unwrap_or(0) as i32,
388 fidelity: attrs
389 .get("fidelity")
390 .and_then(|v| FidelityMode::from_str(&v.as_str())),
391 thread_id: attrs.get("thread_id").map(|v| v.as_str()),
392 loop_restart: attrs
393 .get("loop_restart")
394 .and_then(|v| v.as_bool())
395 .unwrap_or(false),
396 }
397}
398
399fn find_node_by_handler(
400 graph: &DiGraph<PipelineNode, PipelineEdge>,
401 node_index: &HashMap<String, NodeIndex>,
402 handler: &str,
403) -> Option<NodeIndex> {
404 node_index
405 .values()
406 .copied()
407 .find(|idx| graph[*idx].handler_type == handler)
408}
409
410#[cfg(test)]
411mod tests {
412 use super::*;
413 use crate::attractor::dot_parser::parse_dot;
414
415 #[test]
416 fn test_build_simple_pipeline() {
417 let input = r#"
418 digraph pipeline {
419 graph [goal="Build feature X"]
420 start [shape=Mdiamond]
421 task_a [shape=box, label="Implement A", prompt="Write the code for A"]
422 finish [shape=Msquare]
423 start -> task_a -> finish
424 }
425 "#;
426 let dot = parse_dot(input).unwrap();
427 let pipeline = PipelineGraph::from_dot(&dot).unwrap();
428
429 assert_eq!(pipeline.name, "pipeline");
430 assert_eq!(pipeline.graph_attrs.goal, Some("Build feature X".into()));
431 assert_eq!(pipeline.graph.node_count(), 3);
432 assert_eq!(pipeline.graph.edge_count(), 2);
433
434 let start = &pipeline.graph[pipeline.start_node];
435 assert_eq!(start.handler_type, "start");
436
437 let exit = &pipeline.graph[pipeline.exit_node];
438 assert_eq!(exit.handler_type, "exit");
439
440 let task = pipeline.node("task_a").unwrap();
441 assert_eq!(task.handler_type, "codergen");
442 assert_eq!(task.prompt, "Write the code for A");
443 }
444
445 #[test]
446 fn test_shape_to_handler_mapping() {
447 assert_eq!(handler_type_from_shape("Mdiamond"), "start");
448 assert_eq!(handler_type_from_shape("Msquare"), "exit");
449 assert_eq!(handler_type_from_shape("box"), "codergen");
450 assert_eq!(handler_type_from_shape("hexagon"), "wait.human");
451 assert_eq!(handler_type_from_shape("diamond"), "conditional");
452 assert_eq!(handler_type_from_shape("component"), "parallel");
453 assert_eq!(handler_type_from_shape("tripleoctagon"), "parallel.fan_in");
454 assert_eq!(handler_type_from_shape("parallelogram"), "tool");
455 assert_eq!(handler_type_from_shape("house"), "stack.manager_loop");
456 }
457
458 #[test]
459 fn test_outgoing_edges() {
460 let input = r#"
461 digraph test {
462 start [shape=Mdiamond]
463 a [shape=box]
464 b [shape=box]
465 finish [shape=Msquare]
466 start -> a [label="go"]
467 start -> b [label="alt"]
468 a -> finish
469 b -> finish
470 }
471 "#;
472 let dot = parse_dot(input).unwrap();
473 let pipeline = PipelineGraph::from_dot(&dot).unwrap();
474
475 let edges = pipeline.outgoing_edges(pipeline.start_node);
476 assert_eq!(edges.len(), 2);
477 }
478
479 #[test]
480 fn test_node_defaults_applied() {
481 let input = r#"
482 digraph test {
483 node [reasoning_effort="medium"]
484 start [shape=Mdiamond]
485 a [shape=box]
486 finish [shape=Msquare]
487 start -> a -> finish
488 }
489 "#;
490 let dot = parse_dot(input).unwrap();
491 let pipeline = PipelineGraph::from_dot(&dot).unwrap();
492 let a = pipeline.node("a").unwrap();
493 assert_eq!(a.reasoning_effort, "medium");
494 }
495}