Skip to main content

simple_agents_workflow/
visualize.rs

1use crate::ir::{NodeKind, WorkflowDefinition};
2
3/// Renders a workflow definition as a Mermaid flowchart.
4pub fn workflow_to_mermaid(workflow: &WorkflowDefinition) -> String {
5    let mut lines = Vec::new();
6    lines.push("flowchart TD".to_string());
7
8    for node in &workflow.nodes {
9        let kind = node_kind_label(&node.kind);
10        lines.push(format!(
11            "  {}[\"{}\\n({})\"]",
12            sanitize_id(&node.id),
13            escape_label(&node.id),
14            kind
15        ));
16    }
17
18    for node in &workflow.nodes {
19        let from = sanitize_id(&node.id);
20        for (label, to) in edge_specs(&node.kind) {
21            let edge = if label.is_empty() {
22                format!("  {} --> {}", from, sanitize_id(&to))
23            } else {
24                format!(
25                    "  {} -- \"{}\" --> {}",
26                    from,
27                    escape_label(&label),
28                    sanitize_id(&to)
29                )
30            };
31            lines.push(edge);
32        }
33    }
34
35    lines.join("\n")
36}
37
38fn edge_specs(kind: &NodeKind) -> Vec<(String, String)> {
39    match kind {
40        NodeKind::Start { next } => vec![("".to_string(), next.clone())],
41        NodeKind::Llm { next, .. } => next
42            .as_ref()
43            .map(|n| vec![("".to_string(), n.clone())])
44            .unwrap_or_default(),
45        NodeKind::Tool { next, .. } => next
46            .as_ref()
47            .map(|n| vec![("".to_string(), n.clone())])
48            .unwrap_or_default(),
49        NodeKind::Condition {
50            on_true, on_false, ..
51        } => vec![
52            ("true".to_string(), on_true.clone()),
53            ("false".to_string(), on_false.clone()),
54        ],
55        NodeKind::Debounce {
56            next,
57            on_suppressed,
58            ..
59        } => {
60            let mut edges = vec![("emit".to_string(), next.clone())];
61            if let Some(target) = on_suppressed.as_ref() {
62                edges.push(("suppressed".to_string(), target.clone()));
63            }
64            edges
65        }
66        NodeKind::Throttle {
67            next, on_throttled, ..
68        } => {
69            let mut edges = vec![("emit".to_string(), next.clone())];
70            if let Some(target) = on_throttled.as_ref() {
71                edges.push(("throttled".to_string(), target.clone()));
72            }
73            edges
74        }
75        NodeKind::RetryCompensate {
76            next,
77            on_compensated,
78            ..
79        } => {
80            let mut edges = vec![("success".to_string(), next.clone())];
81            if let Some(target) = on_compensated.as_ref() {
82                edges.push(("compensated".to_string(), target.clone()));
83            }
84            edges
85        }
86        NodeKind::HumanInTheLoop {
87            on_approve,
88            on_reject,
89            ..
90        } => {
91            vec![
92                ("approve".to_string(), on_approve.clone()),
93                ("reject".to_string(), on_reject.clone()),
94            ]
95        }
96        NodeKind::CacheWrite { next, .. } => vec![("".to_string(), next.clone())],
97        NodeKind::CacheRead { next, on_miss, .. } => {
98            let mut edges = vec![("hit".to_string(), next.clone())];
99            if let Some(target) = on_miss.as_ref() {
100                edges.push(("miss".to_string(), target.clone()));
101            }
102            edges
103        }
104        NodeKind::EventTrigger {
105            next, on_mismatch, ..
106        } => {
107            let mut edges = vec![("match".to_string(), next.clone())];
108            if let Some(target) = on_mismatch.as_ref() {
109                edges.push(("mismatch".to_string(), target.clone()));
110            }
111            edges
112        }
113        NodeKind::Router { routes, default } => {
114            let mut edges: Vec<(String, String)> = routes
115                .iter()
116                .enumerate()
117                .map(|(i, route)| {
118                    let mut label = String::from("route");
119                    label.push_str(&(i + 1).to_string());
120                    (label, route.next.clone())
121                })
122                .collect();
123            edges.push(("default".to_string(), default.clone()));
124            edges
125        }
126        NodeKind::Transform { next, .. } => vec![("".to_string(), next.clone())],
127        NodeKind::Loop { body, next, .. } => vec![
128            ("continue".to_string(), body.clone()),
129            ("done".to_string(), next.clone()),
130        ],
131        NodeKind::Subgraph { next, .. } => next
132            .as_ref()
133            .map(|n| vec![("".to_string(), n.clone())])
134            .unwrap_or_default(),
135        NodeKind::Batch { next, .. } => vec![("".to_string(), next.clone())],
136        NodeKind::Filter { next, .. } => vec![("".to_string(), next.clone())],
137        NodeKind::Parallel { branches, next, .. } => {
138            let mut edges = branches
139                .iter()
140                .map(|branch| ("branch".to_string(), branch.clone()))
141                .collect::<Vec<(String, String)>>();
142            edges.push(("join".to_string(), next.clone()));
143            edges
144        }
145        NodeKind::Merge { sources, next, .. } => {
146            let mut edges = sources
147                .iter()
148                .map(|source| ("source".to_string(), source.clone()))
149                .collect::<Vec<(String, String)>>();
150            edges.push(("next".to_string(), next.clone()));
151            edges
152        }
153        NodeKind::Map { next, .. } => vec![("".to_string(), next.clone())],
154        NodeKind::Reduce { next, .. } => vec![("".to_string(), next.clone())],
155        NodeKind::End => Vec::new(),
156    }
157}
158
159fn node_kind_label(kind: &NodeKind) -> &'static str {
160    match kind {
161        NodeKind::Start { .. } => "start",
162        NodeKind::Llm { .. } => "llm",
163        NodeKind::Tool { .. } => "tool",
164        NodeKind::Condition { .. } => "condition",
165        NodeKind::Debounce { .. } => "debounce",
166        NodeKind::Throttle { .. } => "throttle",
167        NodeKind::RetryCompensate { .. } => "retry_compensate",
168        NodeKind::HumanInTheLoop { .. } => "human_in_the_loop",
169        NodeKind::CacheWrite { .. } => "cache_write",
170        NodeKind::CacheRead { .. } => "cache_read",
171        NodeKind::EventTrigger { .. } => "event_trigger",
172        NodeKind::Router { .. } => "router",
173        NodeKind::Transform { .. } => "transform",
174        NodeKind::Loop { .. } => "loop",
175        NodeKind::Subgraph { .. } => "subgraph",
176        NodeKind::Batch { .. } => "batch",
177        NodeKind::Filter { .. } => "filter",
178        NodeKind::Parallel { .. } => "parallel",
179        NodeKind::Merge { .. } => "merge",
180        NodeKind::Map { .. } => "map",
181        NodeKind::Reduce { .. } => "reduce",
182        NodeKind::End => "end",
183    }
184}
185
/// Converts a workflow node id into an identifier Mermaid accepts.
///
/// Every character other than an ASCII letter, ASCII digit, or `_` becomes
/// `_`. An `n_` prefix is added when the id does not start with an ASCII
/// letter or `_` (this includes the empty id), or when the id is exactly a
/// Mermaid flowchart keyword such as `end`. The prefix stops the parser from
/// misreading the id.
///
/// NOTE: the mapping is not injective. For example, `a-b` and `a.b` both
/// become `a_b`, and `end` maps onto the same identifier as a literal `n_end`
/// id. Distinct ids that collide are merged into one diagram node.
fn sanitize_id(id: &str) -> String {
    // Keywords of Mermaid's flowchart syntax. The Mermaid docs warn that a node
    // named `end` in all lowercase breaks the flowchart; the remaining words are
    // prefixed defensively. Matching is case-sensitive on purpose: capitalized
    // forms such as `End` are valid ids and are left alone.
    const RESERVED: &[&str] = &[
        "end",
        "subgraph",
        "graph",
        "flowchart",
        "direction",
        "style",
        "linkStyle",
        "classDef",
        "class",
        "click",
    ];

    let starts_like_identifier = id
        .chars()
        .next()
        .is_some_and(|ch| ch.is_ascii_alphabetic() || ch == '_');
    let prefix = if starts_like_identifier && !RESERVED.contains(&id) {
        ""
    } else {
        "n_"
    };

    let mut out = String::with_capacity(prefix.len() + id.len());
    out.push_str(prefix);
    out.extend(id.chars().map(|ch| {
        if ch.is_ascii_alphanumeric() || ch == '_' {
            ch
        } else {
            '_'
        }
    }));
    out
}
209
/// Escapes text for use inside a double-quoted Mermaid label, either a node
/// label `id["..."]` or an edge label `-- "..." -->`.
///
/// Mermaid does not support backslash escapes inside quoted labels: `\"`
/// still terminates the string and breaks the diagram. Characters are instead
/// written as Mermaid entity codes. `"` becomes `#quot;`, and `#` itself
/// becomes `#35;` so that literal text such as `#amp;` is not decoded as an
/// entity. All other characters are copied unchanged.
fn escape_label(label: &str) -> String {
    let mut out = String::with_capacity(label.len());
    for ch in label.chars() {
        match ch {
            // `#` must be handled so the `#` introducing `#quot;` is never re-escaped
            // and user text that looks like an entity stays literal.
            '#' => out.push_str("#35;"),
            '"' => out.push_str("#quot;"),
            other => out.push(other),
        }
    }
    out
}
213
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::workflow_to_mermaid;
    use crate::ir::{MergePolicy, Node, NodeKind, RouterRoute, WorkflowDefinition};

    /// Shorthand for building a node with the given id and kind.
    fn node(id: &str, kind: NodeKind) -> Node {
        Node {
            id: id.to_string(),
            kind,
        }
    }

    /// Asserts that every expected fragment occurs somewhere in the rendered chart.
    fn assert_renders(mermaid: &str, expected: &[&str]) {
        for fragment in expected {
            assert!(
                mermaid.contains(*fragment),
                "expected `{}` in rendered chart:\n{}",
                fragment,
                mermaid
            );
        }
    }

    #[test]
    fn renders_condition_edges_with_labels() {
        let workflow = WorkflowDefinition {
            version: "v0".to_string(),
            name: "cond".to_string(),
            nodes: vec![
                node(
                    "start",
                    NodeKind::Start {
                        next: "route".to_string(),
                    },
                ),
                node(
                    "route",
                    NodeKind::Condition {
                        expression: "input.ok == true".to_string(),
                        on_true: "yes".to_string(),
                        on_false: "no".to_string(),
                    },
                ),
                node("yes", NodeKind::End),
                node("no", NodeKind::End),
            ],
        };

        assert_renders(
            &workflow_to_mermaid(&workflow),
            &["route -- \"true\" --> yes", "route -- \"false\" --> no"],
        );
    }

    #[test]
    fn renders_parallel_and_router_shapes() {
        let workflow = WorkflowDefinition {
            version: "v0".to_string(),
            name: "advanced".to_string(),
            nodes: vec![
                node(
                    "fanout",
                    NodeKind::Parallel {
                        branches: vec!["a".to_string(), "b".to_string()],
                        next: "join".to_string(),
                        max_in_flight: Some(2),
                    },
                ),
                node(
                    "pick",
                    NodeKind::Router {
                        routes: vec![RouterRoute {
                            when: "input.x == 1".to_string(),
                            next: "a".to_string(),
                        }],
                        default: "b".to_string(),
                    },
                ),
                node(
                    "a",
                    NodeKind::Tool {
                        tool: "t".to_string(),
                        input: json!({}),
                        next: Some("join".to_string()),
                    },
                ),
                node("b", NodeKind::End),
                node(
                    "join",
                    NodeKind::Merge {
                        sources: vec!["a".to_string(), "b".to_string()],
                        policy: MergePolicy::All,
                        quorum: None,
                        next: "b".to_string(),
                    },
                ),
            ],
        };

        assert_renders(
            &workflow_to_mermaid(&workflow),
            &[
                "fanout -- \"branch\" --> a",
                "fanout -- \"join\" --> join",
                "pick -- \"route1\" --> a",
                "pick -- \"default\" --> b",
            ],
        );
    }
}