Skip to main content

scud/attractor/
dot_parser.rs

1//! DOT subset parser for Attractor pipeline graphs.
2//!
3//! Parses the Attractor DOT dialect: one digraph per file, directed edges only,
4//! node/edge attributes, chained edges, subgraph scoping, default blocks.
5
6use anyhow::{bail, Context, Result};
7use std::collections::HashMap;
8
/// A parsed DOT graph.
///
/// Node and edge statements are kept in source order. Top-level `node [...]` and
/// `edge [...]` default blocks are collected into `node_defaults` / `edge_defaults`
/// rather than being merged into the individual nodes and edges.
#[derive(Debug, Clone)]
pub struct DotGraph {
    /// Identifier following `digraph`; empty when the name is omitted.
    pub name: String,
    /// Graph-level attributes (e.g. from top-level `graph [...]` blocks).
    pub graph_attrs: HashMap<String, AttrValue>,
    /// Top-level node statements. Nodes that appear only as edge endpoints are not listed here.
    pub nodes: Vec<DotNode>,
    /// Top-level edges; a chain `a -> b -> c` contributes one edge per hop.
    pub edges: Vec<DotEdge>,
    /// Subgraphs declared at the top level of the digraph.
    pub subgraphs: Vec<DotSubgraph>,
    /// Attributes from top-level `node [...]` blocks; later blocks overwrite earlier keys.
    pub node_defaults: HashMap<String, AttrValue>,
    /// Attributes from top-level `edge [...]` blocks; later blocks overwrite earlier keys.
    pub edge_defaults: HashMap<String, AttrValue>,
}
20
/// A node in the DOT graph.
#[derive(Debug, Clone)]
pub struct DotNode {
    /// Node identifier: a bare identifier, or a quoted string with quotes removed and escapes processed.
    pub id: String,
    /// Attributes written on the node statement itself; empty when no `[...]` list is given.
    pub attrs: HashMap<String, AttrValue>,
}
27
/// An edge in the DOT graph.
///
/// Chained statements such as `a -> b -> c [attrs]` are expanded into one `DotEdge`
/// per hop, each carrying a copy of the same attribute list.
#[derive(Debug, Clone)]
pub struct DotEdge {
    /// Source node identifier.
    pub from: String,
    /// Target node identifier.
    pub to: String,
    /// Attributes from the edge statement's `[...]` list; empty when absent.
    pub attrs: HashMap<String, AttrValue>,
}
35
/// A subgraph (cluster) in the DOT graph.
///
/// Nodes and edges declared inside a subgraph are stored here, not in the
/// parent [`DotGraph`]'s `nodes` / `edges`.
#[derive(Debug, Clone)]
pub struct DotSubgraph {
    /// Subgraph identifier; `None` for an anonymous `subgraph { ... }`.
    pub name: Option<String>,
    /// Attributes from `graph [...]` blocks inside the subgraph body.
    pub attrs: HashMap<String, AttrValue>,
    /// Node statements declared in the subgraph body.
    pub nodes: Vec<DotNode>,
    /// Edges declared in the subgraph body (chains expanded per hop).
    pub edges: Vec<DotEdge>,
}
44
/// Typed attribute value.
///
/// The parser picks the variant from the literal's syntax:
/// - quoted strings become [`AttrValue::Duration`] when they look like `"30s"`, `"5m"`
///   or `"1h"`, and [`AttrValue::Str`] otherwise;
/// - unquoted numbers become [`AttrValue::Float`] if they contain a `.`, else [`AttrValue::Int`];
/// - unquoted `true`/`yes` and `false`/`no` (case-insensitive) become [`AttrValue::Bool`];
/// - any other unquoted word becomes [`AttrValue::Str`].
#[derive(Debug, Clone, PartialEq)]
pub enum AttrValue {
    /// Text value.
    Str(String),
    /// Integer literal.
    Int(i64),
    /// Numeric literal containing a decimal point.
    Float(f64),
    /// Boolean literal.
    Bool(bool),
    /// Whole-second duration parsed from a quoted `<N>s`, `<N>m` or `<N>h` string.
    Duration(std::time::Duration),
}
54
55impl AttrValue {
56    /// Get the value as a string, regardless of type.
57    pub fn as_str(&self) -> String {
58        match self {
59            AttrValue::Str(s) => s.clone(),
60            AttrValue::Int(i) => i.to_string(),
61            AttrValue::Float(f) => f.to_string(),
62            AttrValue::Bool(b) => b.to_string(),
63            AttrValue::Duration(d) => format!("{}s", d.as_secs()),
64        }
65    }
66
67    /// Get the value as a string reference if it is a string.
68    pub fn str_ref(&self) -> Option<&str> {
69        match self {
70            AttrValue::Str(s) => Some(s),
71            _ => None,
72        }
73    }
74
75    /// Get the value as an integer if possible.
76    pub fn as_int(&self) -> Option<i64> {
77        match self {
78            AttrValue::Int(i) => Some(*i),
79            _ => None,
80        }
81    }
82
83    /// Get the value as a bool if possible.
84    pub fn as_bool(&self) -> Option<bool> {
85        match self {
86            AttrValue::Bool(b) => Some(*b),
87            _ => None,
88        }
89    }
90}
91
92/// Parse a DOT file string into a DotGraph.
93pub fn parse_dot(input: &str) -> Result<DotGraph> {
94    let mut parser = DotParser::new(input);
95    parser.parse()
96}
97
/// Hand-written, cursor-based parser over a DOT source string.
struct DotParser<'a> {
    /// The complete source text being parsed.
    input: &'a str,
    /// Byte offset into `input` of the next unread character.
    pos: usize,
}
102
103impl<'a> DotParser<'a> {
104    fn new(input: &'a str) -> Self {
105        Self { input, pos: 0 }
106    }
107
108    fn parse(&mut self) -> Result<DotGraph> {
109        self.skip_ws();
110
111        // Expect "digraph"
112        self.expect_keyword("digraph")
113            .context("Expected 'digraph' keyword")?;
114        self.skip_ws();
115
116        // Graph name (optional)
117        let name = if self.peek_char() != Some('{') {
118            self.read_identifier()?
119        } else {
120            String::new()
121        };
122        self.skip_ws();
123
124        self.expect_char('{')?;
125
126        let mut graph = DotGraph {
127            name,
128            graph_attrs: HashMap::new(),
129            nodes: Vec::new(),
130            edges: Vec::new(),
131            subgraphs: Vec::new(),
132            node_defaults: HashMap::new(),
133            edge_defaults: HashMap::new(),
134        };
135
136        self.parse_body(&mut graph)?;
137
138        self.skip_ws();
139        self.expect_char('}')?;
140
141        Ok(graph)
142    }
143
144    fn parse_body(&mut self, graph: &mut DotGraph) -> Result<()> {
145        loop {
146            self.skip_ws();
147            if self.peek_char() == Some('}') || self.is_eof() {
148                break;
149            }
150
151            // Skip comments
152            if self.peek_str("//") {
153                self.skip_line();
154                continue;
155            }
156            if self.peek_str("/*") {
157                self.skip_block_comment();
158                continue;
159            }
160
161            // Check for keyword statements
162            if self.peek_keyword("node") {
163                self.advance(4);
164                self.skip_ws();
165                if self.peek_char() == Some('[') {
166                    let attrs = self.parse_attr_list()?;
167                    for (k, v) in attrs {
168                        graph.node_defaults.insert(k, v);
169                    }
170                }
171                self.skip_optional_semicolon();
172                continue;
173            }
174
175            if self.peek_keyword("edge") {
176                self.advance(4);
177                self.skip_ws();
178                if self.peek_char() == Some('[') {
179                    let attrs = self.parse_attr_list()?;
180                    for (k, v) in attrs {
181                        graph.edge_defaults.insert(k, v);
182                    }
183                }
184                self.skip_optional_semicolon();
185                continue;
186            }
187
188            if self.peek_keyword("graph") {
189                self.advance(5);
190                self.skip_ws();
191                if self.peek_char() == Some('[') {
192                    let attrs = self.parse_attr_list()?;
193                    for (k, v) in attrs {
194                        graph.graph_attrs.insert(k, v);
195                    }
196                }
197                self.skip_optional_semicolon();
198                continue;
199            }
200
201            // Check for subgraph
202            if self.peek_keyword("subgraph") {
203                let sg = self.parse_subgraph()?;
204                graph.subgraphs.push(sg);
205                self.skip_optional_semicolon();
206                continue;
207            }
208
209            // Must be a node or edge statement
210            let id = self.read_identifier_or_quoted()?;
211            self.skip_ws();
212
213            if self.peek_str("->") {
214                // Edge chain: a -> b -> c [attrs]
215                let mut chain = vec![id];
216                while self.peek_str("->") {
217                    self.advance(2);
218                    self.skip_ws();
219                    chain.push(self.read_identifier_or_quoted()?);
220                    self.skip_ws();
221                }
222
223                let attrs = if self.peek_char() == Some('[') {
224                    self.parse_attr_list()?
225                } else {
226                    HashMap::new()
227                };
228
229                // Expand chain into individual edges
230                for window in chain.windows(2) {
231                    graph.edges.push(DotEdge {
232                        from: window[0].clone(),
233                        to: window[1].clone(),
234                        attrs: attrs.clone(),
235                    });
236                }
237            } else {
238                // Node statement with optional attrs
239                let attrs = if self.peek_char() == Some('[') {
240                    self.parse_attr_list()?
241                } else {
242                    HashMap::new()
243                };
244
245                graph.nodes.push(DotNode {
246                    id,
247                    attrs,
248                });
249            }
250
251            self.skip_optional_semicolon();
252        }
253
254        Ok(())
255    }
256
257    fn parse_subgraph(&mut self) -> Result<DotSubgraph> {
258        self.expect_keyword("subgraph")?;
259        self.skip_ws();
260
261        let name = if self.peek_char() != Some('{') {
262            Some(self.read_identifier_or_quoted()?)
263        } else {
264            None
265        };
266        self.skip_ws();
267        self.expect_char('{')?;
268
269        let mut sg = DotSubgraph {
270            name,
271            attrs: HashMap::new(),
272            nodes: Vec::new(),
273            edges: Vec::new(),
274        };
275
276        // Parse subgraph body (simplified — no nested subgraphs)
277        loop {
278            self.skip_ws();
279            if self.peek_char() == Some('}') || self.is_eof() {
280                break;
281            }
282            if self.peek_str("//") {
283                self.skip_line();
284                continue;
285            }
286            if self.peek_str("/*") {
287                self.skip_block_comment();
288                continue;
289            }
290
291            // Check for graph/node/edge defaults
292            if self.peek_keyword("graph") {
293                self.advance(5);
294                self.skip_ws();
295                if self.peek_char() == Some('[') {
296                    let attrs = self.parse_attr_list()?;
297                    sg.attrs.extend(attrs);
298                }
299                self.skip_optional_semicolon();
300                continue;
301            }
302
303            let id = self.read_identifier_or_quoted()?;
304            self.skip_ws();
305
306            if self.peek_str("->") {
307                let mut chain = vec![id];
308                while self.peek_str("->") {
309                    self.advance(2);
310                    self.skip_ws();
311                    chain.push(self.read_identifier_or_quoted()?);
312                    self.skip_ws();
313                }
314                let attrs = if self.peek_char() == Some('[') {
315                    self.parse_attr_list()?
316                } else {
317                    HashMap::new()
318                };
319                for window in chain.windows(2) {
320                    sg.edges.push(DotEdge {
321                        from: window[0].clone(),
322                        to: window[1].clone(),
323                        attrs: attrs.clone(),
324                    });
325                }
326            } else {
327                let attrs = if self.peek_char() == Some('[') {
328                    self.parse_attr_list()?
329                } else {
330                    HashMap::new()
331                };
332                sg.nodes.push(DotNode { id, attrs });
333            }
334
335            self.skip_optional_semicolon();
336        }
337
338        self.expect_char('}')?;
339        Ok(sg)
340    }
341
342    fn parse_attr_list(&mut self) -> Result<HashMap<String, AttrValue>> {
343        self.expect_char('[')?;
344        let mut attrs = HashMap::new();
345
346        loop {
347            self.skip_ws();
348            if self.peek_char() == Some(']') {
349                self.advance(1);
350                break;
351            }
352
353            let key = self.read_identifier()?;
354            self.skip_ws();
355            self.expect_char('=')?;
356            self.skip_ws();
357            let value = self.read_attr_value()?;
358            attrs.insert(key, value);
359
360            self.skip_ws();
361            // Optional comma or semicolon separator
362            if self.peek_char() == Some(',') || self.peek_char() == Some(';') {
363                self.advance(1);
364            }
365        }
366
367        Ok(attrs)
368    }
369
370    fn read_attr_value(&mut self) -> Result<AttrValue> {
371        let ch = self.peek_char().context("Unexpected EOF in attribute value")?;
372
373        if ch == '"' {
374            let s = self.read_quoted_string()?;
375            // Try to parse as duration (e.g., "30s", "5m")
376            if let Some(d) = parse_duration_str(&s) {
377                return Ok(AttrValue::Duration(d));
378            }
379            Ok(AttrValue::Str(s))
380        } else if ch == '-' || ch.is_ascii_digit() {
381            let num_str = self.read_number_str();
382            if num_str.contains('.') {
383                Ok(AttrValue::Float(
384                    num_str.parse().context("Invalid float")?,
385                ))
386            } else {
387                Ok(AttrValue::Int(num_str.parse().context("Invalid integer")?))
388            }
389        } else {
390            // Bare word — could be bool or string
391            let word = self.read_identifier()?;
392            match word.to_lowercase().as_str() {
393                "true" | "yes" => Ok(AttrValue::Bool(true)),
394                "false" | "no" => Ok(AttrValue::Bool(false)),
395                _ => Ok(AttrValue::Str(word)),
396            }
397        }
398    }
399
400    fn read_quoted_string(&mut self) -> Result<String> {
401        self.expect_char('"')?;
402        let mut s = String::new();
403        loop {
404            match self.next_char() {
405                Some('\\') => match self.next_char() {
406                    Some('n') => s.push('\n'),
407                    Some('t') => s.push('\t'),
408                    Some('"') => s.push('"'),
409                    Some('\\') => s.push('\\'),
410                    Some(c) => {
411                        s.push('\\');
412                        s.push(c);
413                    }
414                    None => bail!("Unterminated escape in string"),
415                },
416                Some('"') => break,
417                Some(c) => s.push(c),
418                None => bail!("Unterminated string"),
419            }
420        }
421        Ok(s)
422    }
423
424    fn read_identifier(&mut self) -> Result<String> {
425        let start = self.pos;
426        while let Some(c) = self.peek_char() {
427            if c.is_alphanumeric() || c == '_' || c == '.' || c == '-' {
428                self.advance(1);
429            } else {
430                break;
431            }
432        }
433        if self.pos == start {
434            bail!(
435                "Expected identifier at position {}, got {:?}",
436                self.pos,
437                self.peek_char()
438            );
439        }
440        Ok(self.input[start..self.pos].to_string())
441    }
442
443    fn read_identifier_or_quoted(&mut self) -> Result<String> {
444        if self.peek_char() == Some('"') {
445            self.read_quoted_string()
446        } else {
447            self.read_identifier()
448        }
449    }
450
451    fn read_number_str(&mut self) -> String {
452        let start = self.pos;
453        if self.peek_char() == Some('-') {
454            self.advance(1);
455        }
456        while let Some(c) = self.peek_char() {
457            if c.is_ascii_digit() || c == '.' {
458                self.advance(1);
459            } else {
460                break;
461            }
462        }
463        self.input[start..self.pos].to_string()
464    }
465
466    // --- Utility methods ---
467
468    fn skip_ws(&mut self) {
469        loop {
470            match self.peek_char() {
471                Some(c) if c.is_whitespace() => {
472                    self.advance(1);
473                }
474                Some('/') if self.peek_str("//") => {
475                    self.skip_line();
476                }
477                Some('/') if self.peek_str("/*") => {
478                    self.skip_block_comment();
479                }
480                _ => break,
481            }
482        }
483    }
484
485    fn skip_line(&mut self) {
486        while let Some(c) = self.next_char() {
487            if c == '\n' {
488                break;
489            }
490        }
491    }
492
493    fn skip_block_comment(&mut self) {
494        self.advance(2); // skip /*
495        while !self.is_eof() {
496            if self.peek_str("*/") {
497                self.advance(2);
498                return;
499            }
500            self.advance(1);
501        }
502    }
503
504    fn skip_optional_semicolon(&mut self) {
505        self.skip_ws();
506        if self.peek_char() == Some(';') {
507            self.advance(1);
508        }
509    }
510
511    fn peek_char(&self) -> Option<char> {
512        self.input[self.pos..].chars().next()
513    }
514
515    fn next_char(&mut self) -> Option<char> {
516        let c = self.input[self.pos..].chars().next()?;
517        self.pos += c.len_utf8();
518        Some(c)
519    }
520
521    fn advance(&mut self, n: usize) {
522        self.pos = (self.pos + n).min(self.input.len());
523    }
524
525    fn is_eof(&self) -> bool {
526        self.pos >= self.input.len()
527    }
528
529    fn peek_str(&self, s: &str) -> bool {
530        self.input[self.pos..].starts_with(s)
531    }
532
533    fn peek_keyword(&self, kw: &str) -> bool {
534        if !self.input[self.pos..].starts_with(kw) {
535            return false;
536        }
537        // Must be followed by non-identifier char
538        let after = self.pos + kw.len();
539        if after >= self.input.len() {
540            return true;
541        }
542        let next = self.input[after..].chars().next().unwrap();
543        !next.is_alphanumeric() && next != '_'
544    }
545
546    fn expect_keyword(&mut self, kw: &str) -> Result<()> {
547        if !self.peek_keyword(kw) {
548            bail!(
549                "Expected '{}' at position {}, got '{}'",
550                kw,
551                self.pos,
552                &self.input[self.pos..self.pos + 10.min(self.input.len() - self.pos)]
553            );
554        }
555        self.advance(kw.len());
556        Ok(())
557    }
558
559    fn expect_char(&mut self, expected: char) -> Result<()> {
560        match self.next_char() {
561            Some(c) if c == expected => Ok(()),
562            Some(c) => bail!("Expected '{}', got '{}' at position {}", expected, c, self.pos - 1),
563            None => bail!("Expected '{}', got EOF", expected),
564        }
565    }
566}
567
/// Parse a duration string like "30s", "5m", "1h".
///
/// Surrounding whitespace is trimmed. Only a whole, non-negative count followed by a
/// single `s`, `m` or `h` suffix is recognised. Anything else yields `None`, and the
/// caller keeps the raw string. This includes `"500ms"`, `"1.5h"`, and counts whose
/// total seconds would overflow `u64`.
fn parse_duration_str(s: &str) -> Option<std::time::Duration> {
    let s = s.trim();
    let (count, secs_per_unit) = if let Some(n) = s.strip_suffix('s') {
        (n, 1)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3600)
    } else {
        return None;
    };
    let count: u64 = count.parse().ok()?;
    // checked_mul: a huge "<N>h" must not overflow (panic in debug, wrap in release).
    count
        .checked_mul(secs_per_unit)
        .map(std::time::Duration::from_secs)
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Parse `src`, failing the test if the parser rejects it.
    fn parse_ok(src: &str) -> DotGraph {
        parse_dot(src).expect("DOT input should parse")
    }

    /// The `(from, to)` endpoints of an edge as string slices.
    fn ends(edge: &DotEdge) -> (&str, &str) {
        (edge.from.as_str(), edge.to.as_str())
    }

    #[test]
    fn test_parse_simple_digraph() {
        let src = r#"
        digraph pipeline {
            start [shape=Mdiamond]
            task_a [shape=box, label="Do task A", prompt="Write code"]
            finish [shape=Msquare]

            start -> task_a -> finish
        }
        "#;
        let graph = parse_ok(src);
        assert_eq!(graph.name, "pipeline");
        assert_eq!((graph.nodes.len(), graph.edges.len()), (3, 2));
        assert_eq!(ends(&graph.edges[0]), ("start", "task_a"));
    }

    #[test]
    fn test_parse_graph_attrs() {
        let src = r#"
        digraph test {
            graph [goal="Build a feature", fidelity="full"]
            a -> b
        }
        "#;
        let graph = parse_ok(src);
        let goal = graph.graph_attrs.get("goal").and_then(AttrValue::str_ref);
        assert_eq!(goal, Some("Build a feature"));
    }

    #[test]
    fn test_parse_node_defaults() {
        let src = r#"
        digraph test {
            node [shape=box, reasoning_effort="high"]
            a
            b
            a -> b
        }
        "#;
        let graph = parse_ok(src);
        let shape = graph.node_defaults.get("shape").and_then(AttrValue::str_ref);
        assert_eq!(shape, Some("box"));
    }

    #[test]
    fn test_parse_edge_with_attrs() {
        let src = r#"
        digraph test {
            a -> b [label="success", condition="outcome=success", weight=10]
        }
        "#;
        let graph = parse_ok(src);
        assert_eq!(graph.edges.len(), 1);
        let attrs = &graph.edges[0].attrs;
        assert_eq!(attrs.get("label").and_then(AttrValue::str_ref), Some("success"));
        assert_eq!(attrs.get("weight").and_then(AttrValue::as_int), Some(10));
    }

    #[test]
    fn test_parse_chained_edges() {
        let src = r#"
        digraph test {
            a -> b -> c -> d [label="chain"]
        }
        "#;
        let graph = parse_ok(src);
        let pairs: Vec<(&str, &str)> = graph.edges.iter().map(ends).collect();
        assert_eq!(pairs.len(), 3);
        assert_eq!(pairs[0], ("a", "b"));
        assert_eq!(pairs[2], ("c", "d"));
    }

    #[test]
    fn test_parse_bool_and_int_attrs() {
        let src = r#"
        digraph test {
            a [goal_gate=true, max_retries=3, auto_status=false]
        }
        "#;
        let graph = parse_ok(src);
        let attrs = &graph.nodes[0].attrs;
        assert_eq!(attrs.get("goal_gate").and_then(AttrValue::as_bool), Some(true));
        assert_eq!(attrs.get("max_retries").and_then(AttrValue::as_int), Some(3));
        assert_eq!(attrs.get("auto_status").and_then(AttrValue::as_bool), Some(false));
    }

    #[test]
    fn test_parse_duration_attr() {
        let src = r#"
        digraph test {
            a [timeout="30s"]
        }
        "#;
        let graph = parse_ok(src);
        let timeout = graph.nodes[0].attrs.get("timeout").cloned();
        assert_eq!(timeout, Some(AttrValue::Duration(Duration::from_secs(30))));
    }

    #[test]
    fn test_parse_comments() {
        let src = r#"
        // This is a comment
        digraph test {
            /* block comment */
            a -> b // inline comment
        }
        "#;
        assert_eq!(parse_ok(src).edges.len(), 1);
    }

    #[test]
    fn test_parse_subgraph() {
        let src = r#"
        digraph test {
            subgraph cluster_parallel {
                graph [label="Parallel branch"]
                p1 [shape=box]
                p2 [shape=box]
            }
            start -> p1
        }
        "#;
        let graph = parse_ok(src);
        let node_counts: Vec<usize> = graph.subgraphs.iter().map(|sg| sg.nodes.len()).collect();
        assert_eq!(node_counts, vec![2]);
    }

    #[test]
    fn test_parse_quoted_identifiers() {
        let src = r#"
        digraph test {
            "node with spaces" [label="A node"]
            "node with spaces" -> b
        }
        "#;
        let graph = parse_ok(src);
        assert_eq!(graph.nodes[0].id, "node with spaces");
    }

    #[test]
    fn test_parse_empty_graph() {
        let graph = parse_ok("digraph empty {}");
        assert!(graph.nodes.is_empty() && graph.edges.is_empty());
    }
}