Skip to main content

scud/attractor/
dot_parser.rs

1//! DOT subset parser for Attractor pipeline graphs.
2//!
3//! Parses the Attractor DOT dialect: one digraph per file, directed edges only,
4//! node/edge attributes, chained edges, subgraph scoping, default blocks.
5
6use anyhow::{bail, Context, Result};
7use std::collections::HashMap;
8
9/// A parsed DOT graph.
10#[derive(Debug, Clone)]
11pub struct DotGraph {
12    pub name: String,
13    pub graph_attrs: HashMap<String, AttrValue>,
14    pub nodes: Vec<DotNode>,
15    pub edges: Vec<DotEdge>,
16    pub subgraphs: Vec<DotSubgraph>,
17    pub node_defaults: HashMap<String, AttrValue>,
18    pub edge_defaults: HashMap<String, AttrValue>,
19}
20
21/// A node in the DOT graph.
22#[derive(Debug, Clone)]
23pub struct DotNode {
24    pub id: String,
25    pub attrs: HashMap<String, AttrValue>,
26}
27
28/// An edge in the DOT graph.
29#[derive(Debug, Clone)]
30pub struct DotEdge {
31    pub from: String,
32    pub to: String,
33    pub attrs: HashMap<String, AttrValue>,
34}
35
36/// A subgraph (cluster) in the DOT graph.
37#[derive(Debug, Clone)]
38pub struct DotSubgraph {
39    pub name: Option<String>,
40    pub attrs: HashMap<String, AttrValue>,
41    pub nodes: Vec<DotNode>,
42    pub edges: Vec<DotEdge>,
43}
44
/// An attribute value tagged with the type the parser recognised.
#[derive(Clone, Debug, PartialEq)]
pub enum AttrValue {
    /// Free-form text.
    Str(String),
    /// Signed integer.
    Int(i64),
    /// Floating-point number.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Time span.
    Duration(std::time::Duration),
}

impl AttrValue {
    /// Renders the value as an owned string, whatever its variant.
    ///
    /// Durations are written as their whole-second count followed by `s`.
    pub fn as_str(&self) -> String {
        match self {
            Self::Str(text) => text.to_owned(),
            Self::Int(n) => n.to_string(),
            Self::Float(x) => x.to_string(),
            Self::Bool(flag) => flag.to_string(),
            Self::Duration(span) => {
                let secs = span.as_secs();
                format!("{}s", secs)
            }
        }
    }

    /// Borrows the inner text when the value is a `Str`; `None` otherwise.
    pub fn str_ref(&self) -> Option<&str> {
        if let Self::Str(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Returns the inner integer when the value is an `Int`; `None` otherwise.
    pub fn as_int(&self) -> Option<i64> {
        match *self {
            Self::Int(n) => Some(n),
            _ => None,
        }
    }

    /// Returns the inner flag when the value is a `Bool`; `None` otherwise.
    pub fn as_bool(&self) -> Option<bool> {
        match *self {
            Self::Bool(flag) => Some(flag),
            _ => None,
        }
    }
}
91
92/// Parse a DOT file string into a DotGraph.
93pub fn parse_dot(input: &str) -> Result<DotGraph> {
94    let mut parser = DotParser::new(input);
95    parser.parse()
96}
97
/// Cursor-based parser state over the DOT source text.
struct DotParser<'a> {
    /// The complete source being parsed.
    input: &'a str,
    /// Byte offset into `input` of the next unread character.
    pos: usize,
}
102
103impl<'a> DotParser<'a> {
104    fn new(input: &'a str) -> Self {
105        Self { input, pos: 0 }
106    }
107
108    fn parse(&mut self) -> Result<DotGraph> {
109        self.skip_ws();
110
111        // Expect "digraph"
112        self.expect_keyword("digraph")
113            .context("Expected 'digraph' keyword")?;
114        self.skip_ws();
115
116        // Graph name (optional)
117        let name = if self.peek_char() != Some('{') {
118            self.read_identifier()?
119        } else {
120            String::new()
121        };
122        self.skip_ws();
123
124        self.expect_char('{')?;
125
126        let mut graph = DotGraph {
127            name,
128            graph_attrs: HashMap::new(),
129            nodes: Vec::new(),
130            edges: Vec::new(),
131            subgraphs: Vec::new(),
132            node_defaults: HashMap::new(),
133            edge_defaults: HashMap::new(),
134        };
135
136        self.parse_body(&mut graph)?;
137
138        self.skip_ws();
139        self.expect_char('}')?;
140
141        Ok(graph)
142    }
143
144    fn parse_body(&mut self, graph: &mut DotGraph) -> Result<()> {
145        loop {
146            self.skip_ws();
147            if self.peek_char() == Some('}') || self.is_eof() {
148                break;
149            }
150
151            // Skip comments
152            if self.peek_str("//") {
153                self.skip_line();
154                continue;
155            }
156            if self.peek_str("/*") {
157                self.skip_block_comment();
158                continue;
159            }
160
161            // Check for keyword statements
162            if self.peek_keyword("node") {
163                self.advance(4);
164                self.skip_ws();
165                if self.peek_char() == Some('[') {
166                    let attrs = self.parse_attr_list()?;
167                    for (k, v) in attrs {
168                        graph.node_defaults.insert(k, v);
169                    }
170                }
171                self.skip_optional_semicolon();
172                continue;
173            }
174
175            if self.peek_keyword("edge") {
176                self.advance(4);
177                self.skip_ws();
178                if self.peek_char() == Some('[') {
179                    let attrs = self.parse_attr_list()?;
180                    for (k, v) in attrs {
181                        graph.edge_defaults.insert(k, v);
182                    }
183                }
184                self.skip_optional_semicolon();
185                continue;
186            }
187
188            if self.peek_keyword("graph") {
189                self.advance(5);
190                self.skip_ws();
191                if self.peek_char() == Some('[') {
192                    let attrs = self.parse_attr_list()?;
193                    for (k, v) in attrs {
194                        graph.graph_attrs.insert(k, v);
195                    }
196                }
197                self.skip_optional_semicolon();
198                continue;
199            }
200
201            // Check for subgraph
202            if self.peek_keyword("subgraph") {
203                let sg = self.parse_subgraph()?;
204                graph.subgraphs.push(sg);
205                self.skip_optional_semicolon();
206                continue;
207            }
208
209            // Must be a node or edge statement
210            let id = self.read_identifier_or_quoted()?;
211            self.skip_ws();
212
213            if self.peek_str("->") {
214                // Edge chain: a -> b -> c [attrs]
215                let mut chain = vec![id];
216                while self.peek_str("->") {
217                    self.advance(2);
218                    self.skip_ws();
219                    chain.push(self.read_identifier_or_quoted()?);
220                    self.skip_ws();
221                }
222
223                let attrs = if self.peek_char() == Some('[') {
224                    self.parse_attr_list()?
225                } else {
226                    HashMap::new()
227                };
228
229                // Expand chain into individual edges
230                for window in chain.windows(2) {
231                    graph.edges.push(DotEdge {
232                        from: window[0].clone(),
233                        to: window[1].clone(),
234                        attrs: attrs.clone(),
235                    });
236                }
237            } else {
238                // Node statement with optional attrs
239                let attrs = if self.peek_char() == Some('[') {
240                    self.parse_attr_list()?
241                } else {
242                    HashMap::new()
243                };
244
245                graph.nodes.push(DotNode { id, attrs });
246            }
247
248            self.skip_optional_semicolon();
249        }
250
251        Ok(())
252    }
253
254    fn parse_subgraph(&mut self) -> Result<DotSubgraph> {
255        self.expect_keyword("subgraph")?;
256        self.skip_ws();
257
258        let name = if self.peek_char() != Some('{') {
259            Some(self.read_identifier_or_quoted()?)
260        } else {
261            None
262        };
263        self.skip_ws();
264        self.expect_char('{')?;
265
266        let mut sg = DotSubgraph {
267            name,
268            attrs: HashMap::new(),
269            nodes: Vec::new(),
270            edges: Vec::new(),
271        };
272
273        // Parse subgraph body (simplified — no nested subgraphs)
274        loop {
275            self.skip_ws();
276            if self.peek_char() == Some('}') || self.is_eof() {
277                break;
278            }
279            if self.peek_str("//") {
280                self.skip_line();
281                continue;
282            }
283            if self.peek_str("/*") {
284                self.skip_block_comment();
285                continue;
286            }
287
288            // Check for graph/node/edge defaults
289            if self.peek_keyword("graph") {
290                self.advance(5);
291                self.skip_ws();
292                if self.peek_char() == Some('[') {
293                    let attrs = self.parse_attr_list()?;
294                    sg.attrs.extend(attrs);
295                }
296                self.skip_optional_semicolon();
297                continue;
298            }
299
300            let id = self.read_identifier_or_quoted()?;
301            self.skip_ws();
302
303            if self.peek_str("->") {
304                let mut chain = vec![id];
305                while self.peek_str("->") {
306                    self.advance(2);
307                    self.skip_ws();
308                    chain.push(self.read_identifier_or_quoted()?);
309                    self.skip_ws();
310                }
311                let attrs = if self.peek_char() == Some('[') {
312                    self.parse_attr_list()?
313                } else {
314                    HashMap::new()
315                };
316                for window in chain.windows(2) {
317                    sg.edges.push(DotEdge {
318                        from: window[0].clone(),
319                        to: window[1].clone(),
320                        attrs: attrs.clone(),
321                    });
322                }
323            } else {
324                let attrs = if self.peek_char() == Some('[') {
325                    self.parse_attr_list()?
326                } else {
327                    HashMap::new()
328                };
329                sg.nodes.push(DotNode { id, attrs });
330            }
331
332            self.skip_optional_semicolon();
333        }
334
335        self.expect_char('}')?;
336        Ok(sg)
337    }
338
339    fn parse_attr_list(&mut self) -> Result<HashMap<String, AttrValue>> {
340        self.expect_char('[')?;
341        let mut attrs = HashMap::new();
342
343        loop {
344            self.skip_ws();
345            if self.peek_char() == Some(']') {
346                self.advance(1);
347                break;
348            }
349
350            let key = self.read_identifier()?;
351            self.skip_ws();
352            self.expect_char('=')?;
353            self.skip_ws();
354            let value = self.read_attr_value()?;
355            attrs.insert(key, value);
356
357            self.skip_ws();
358            // Optional comma or semicolon separator
359            if self.peek_char() == Some(',') || self.peek_char() == Some(';') {
360                self.advance(1);
361            }
362        }
363
364        Ok(attrs)
365    }
366
367    fn read_attr_value(&mut self) -> Result<AttrValue> {
368        let ch = self
369            .peek_char()
370            .context("Unexpected EOF in attribute value")?;
371
372        if ch == '"' {
373            let s = self.read_quoted_string()?;
374            // Try to parse as duration (e.g., "30s", "5m")
375            if let Some(d) = parse_duration_str(&s) {
376                return Ok(AttrValue::Duration(d));
377            }
378            Ok(AttrValue::Str(s))
379        } else if ch == '-' || ch.is_ascii_digit() {
380            let num_str = self.read_number_str();
381            if num_str.contains('.') {
382                Ok(AttrValue::Float(num_str.parse().context("Invalid float")?))
383            } else {
384                Ok(AttrValue::Int(num_str.parse().context("Invalid integer")?))
385            }
386        } else {
387            // Bare word — could be bool or string
388            let word = self.read_identifier()?;
389            match word.to_lowercase().as_str() {
390                "true" | "yes" => Ok(AttrValue::Bool(true)),
391                "false" | "no" => Ok(AttrValue::Bool(false)),
392                _ => Ok(AttrValue::Str(word)),
393            }
394        }
395    }
396
397    fn read_quoted_string(&mut self) -> Result<String> {
398        self.expect_char('"')?;
399        let mut s = String::new();
400        loop {
401            match self.next_char() {
402                Some('\\') => match self.next_char() {
403                    Some('n') => s.push('\n'),
404                    Some('t') => s.push('\t'),
405                    Some('"') => s.push('"'),
406                    Some('\\') => s.push('\\'),
407                    Some(c) => {
408                        s.push('\\');
409                        s.push(c);
410                    }
411                    None => bail!("Unterminated escape in string"),
412                },
413                Some('"') => break,
414                Some(c) => s.push(c),
415                None => bail!("Unterminated string"),
416            }
417        }
418        Ok(s)
419    }
420
421    fn read_identifier(&mut self) -> Result<String> {
422        let start = self.pos;
423        while let Some(c) = self.peek_char() {
424            if c.is_alphanumeric() || c == '_' || c == '.' || c == '-' {
425                self.advance(1);
426            } else {
427                break;
428            }
429        }
430        if self.pos == start {
431            bail!(
432                "Expected identifier at position {}, got {:?}",
433                self.pos,
434                self.peek_char()
435            );
436        }
437        Ok(self.input[start..self.pos].to_string())
438    }
439
440    fn read_identifier_or_quoted(&mut self) -> Result<String> {
441        if self.peek_char() == Some('"') {
442            self.read_quoted_string()
443        } else {
444            self.read_identifier()
445        }
446    }
447
448    fn read_number_str(&mut self) -> String {
449        let start = self.pos;
450        if self.peek_char() == Some('-') {
451            self.advance(1);
452        }
453        while let Some(c) = self.peek_char() {
454            if c.is_ascii_digit() || c == '.' {
455                self.advance(1);
456            } else {
457                break;
458            }
459        }
460        self.input[start..self.pos].to_string()
461    }
462
463    // --- Utility methods ---
464
465    fn skip_ws(&mut self) {
466        loop {
467            match self.peek_char() {
468                Some(c) if c.is_whitespace() => {
469                    self.advance(1);
470                }
471                Some('/') if self.peek_str("//") => {
472                    self.skip_line();
473                }
474                Some('/') if self.peek_str("/*") => {
475                    self.skip_block_comment();
476                }
477                _ => break,
478            }
479        }
480    }
481
482    fn skip_line(&mut self) {
483        while let Some(c) = self.next_char() {
484            if c == '\n' {
485                break;
486            }
487        }
488    }
489
490    fn skip_block_comment(&mut self) {
491        self.advance(2); // skip /*
492        while !self.is_eof() {
493            if self.peek_str("*/") {
494                self.advance(2);
495                return;
496            }
497            self.advance(1);
498        }
499    }
500
501    fn skip_optional_semicolon(&mut self) {
502        self.skip_ws();
503        if self.peek_char() == Some(';') {
504            self.advance(1);
505        }
506    }
507
508    fn peek_char(&self) -> Option<char> {
509        self.input[self.pos..].chars().next()
510    }
511
512    fn next_char(&mut self) -> Option<char> {
513        let c = self.input[self.pos..].chars().next()?;
514        self.pos += c.len_utf8();
515        Some(c)
516    }
517
518    fn advance(&mut self, n: usize) {
519        self.pos = (self.pos + n).min(self.input.len());
520    }
521
522    fn is_eof(&self) -> bool {
523        self.pos >= self.input.len()
524    }
525
526    fn peek_str(&self, s: &str) -> bool {
527        self.input[self.pos..].starts_with(s)
528    }
529
530    fn peek_keyword(&self, kw: &str) -> bool {
531        if !self.input[self.pos..].starts_with(kw) {
532            return false;
533        }
534        // Must be followed by non-identifier char
535        let after = self.pos + kw.len();
536        if after >= self.input.len() {
537            return true;
538        }
539        let next = self.input[after..].chars().next().unwrap();
540        !next.is_alphanumeric() && next != '_'
541    }
542
543    fn expect_keyword(&mut self, kw: &str) -> Result<()> {
544        if !self.peek_keyword(kw) {
545            bail!(
546                "Expected '{}' at position {}, got '{}'",
547                kw,
548                self.pos,
549                &self.input[self.pos..self.pos + 10.min(self.input.len() - self.pos)]
550            );
551        }
552        self.advance(kw.len());
553        Ok(())
554    }
555
556    fn expect_char(&mut self, expected: char) -> Result<()> {
557        match self.next_char() {
558            Some(c) if c == expected => Ok(()),
559            Some(c) => bail!(
560                "Expected '{}', got '{}' at position {}",
561                expected,
562                c,
563                self.pos - 1
564            ),
565            None => bail!("Expected '{}', got EOF", expected),
566        }
567    }
568}
569
/// Parses a duration written as a non-negative integer followed by a unit:
/// `s` (seconds), `m` (minutes) or `h` (hours), e.g. "30s", "5m", "1h".
///
/// Surrounding whitespace is ignored. Returns `None` when the unit is
/// missing or unknown, the number does not parse as `u64`, or the resulting
/// number of seconds would overflow `u64`.
fn parse_duration_str(s: &str) -> Option<std::time::Duration> {
    let s = s.trim();
    let (digits, secs_per_unit) = if let Some(rest) = s.strip_suffix('s') {
        (rest, 1u64)
    } else if let Some(rest) = s.strip_suffix('m') {
        (rest, 60)
    } else if let Some(rest) = s.strip_suffix('h') {
        (rest, 3600)
    } else {
        return None;
    };

    let count: u64 = digits.parse().ok()?;
    // checked_mul: a huge count must not panic (debug) or wrap (release).
    count
        .checked_mul(secs_per_unit)
        .map(std::time::Duration::from_secs)
}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking when the parser rejects it.
    fn must_parse(src: &str) -> DotGraph {
        parse_dot(src).unwrap()
    }

    /// A node statement per node and one edge per hop of a chain.
    #[test]
    fn test_parse_simple_digraph() {
        let g = must_parse(
            r#"
        digraph pipeline {
            start [shape=Mdiamond]
            task_a [shape=box, label="Do task A", prompt="Write code"]
            finish [shape=Msquare]

            start -> task_a -> finish
        }
        "#,
        );
        assert_eq!(g.name, "pipeline");
        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.edges.len(), 2);
        let first = &g.edges[0];
        assert_eq!(first.from, "start");
        assert_eq!(first.to, "task_a");
    }

    /// `graph [...]` entries land in `graph_attrs`.
    #[test]
    fn test_parse_graph_attrs() {
        let g = must_parse(
            r#"
        digraph test {
            graph [goal="Build a feature", fidelity="full"]
            a -> b
        }
        "#,
        );
        let expected = AttrValue::Str("Build a feature".into());
        assert_eq!(g.graph_attrs.get("goal"), Some(&expected));
    }

    /// `node [...]` entries land in `node_defaults`.
    #[test]
    fn test_parse_node_defaults() {
        let g = must_parse(
            r#"
        digraph test {
            node [shape=box, reasoning_effort="high"]
            a
            b
            a -> b
        }
        "#,
        );
        let expected = AttrValue::Str("box".into());
        assert_eq!(g.node_defaults.get("shape"), Some(&expected));
    }

    /// Edge attribute lists keep string and integer types apart.
    #[test]
    fn test_parse_edge_with_attrs() {
        let g = must_parse(
            r#"
        digraph test {
            a -> b [label="success", condition="outcome=success", weight=10]
        }
        "#,
        );
        assert_eq!(g.edges.len(), 1);
        let attrs = &g.edges[0].attrs;
        assert_eq!(attrs.get("label"), Some(&AttrValue::Str("success".into())));
        assert_eq!(attrs.get("weight"), Some(&AttrValue::Int(10)));
    }

    /// A four-node chain expands to three edges in source order.
    #[test]
    fn test_parse_chained_edges() {
        let g = must_parse(
            r#"
        digraph test {
            a -> b -> c -> d [label="chain"]
        }
        "#,
        );
        assert_eq!(g.edges.len(), 3);
        let (head, tail) = (&g.edges[0], &g.edges[2]);
        assert_eq!(head.from, "a");
        assert_eq!(head.to, "b");
        assert_eq!(tail.from, "c");
        assert_eq!(tail.to, "d");
    }

    /// Bare `true`/`false` become booleans and bare digits become integers.
    #[test]
    fn test_parse_bool_and_int_attrs() {
        let g = must_parse(
            r#"
        digraph test {
            a [goal_gate=true, max_retries=3, auto_status=false]
        }
        "#,
        );
        let attrs = &g.nodes[0].attrs;
        assert_eq!(attrs.get("goal_gate"), Some(&AttrValue::Bool(true)));
        assert_eq!(attrs.get("max_retries"), Some(&AttrValue::Int(3)));
        assert_eq!(attrs.get("auto_status"), Some(&AttrValue::Bool(false)));
    }

    /// A quoted duration string is promoted to `AttrValue::Duration`.
    #[test]
    fn test_parse_duration_attr() {
        let g = must_parse(
            r#"
        digraph test {
            a [timeout="30s"]
        }
        "#,
        );
        let thirty_seconds = AttrValue::Duration(std::time::Duration::from_secs(30));
        assert_eq!(g.nodes[0].attrs.get("timeout"), Some(&thirty_seconds));
    }

    /// Line and block comments are ignored wherever they appear.
    #[test]
    fn test_parse_comments() {
        let g = must_parse(
            r#"
        // This is a comment
        digraph test {
            /* block comment */
            a -> b // inline comment
        }
        "#,
        );
        assert_eq!(g.edges.len(), 1);
    }

    /// Nodes declared inside a subgraph are recorded on that subgraph.
    #[test]
    fn test_parse_subgraph() {
        let g = must_parse(
            r#"
        digraph test {
            subgraph cluster_parallel {
                graph [label="Parallel branch"]
                p1 [shape=box]
                p2 [shape=box]
            }
            start -> p1
        }
        "#,
        );
        assert_eq!(g.subgraphs.len(), 1);
        assert_eq!(g.subgraphs[0].nodes.len(), 2);
    }

    /// Quoted identifiers may contain spaces.
    #[test]
    fn test_parse_quoted_identifiers() {
        let g = must_parse(
            r#"
        digraph test {
            "node with spaces" [label="A node"]
            "node with spaces" -> b
        }
        "#,
        );
        assert_eq!(g.nodes[0].id, "node with spaces");
    }

    /// An empty body yields a graph with no nodes and no edges.
    #[test]
    fn test_parse_empty_graph() {
        let g = must_parse("digraph empty {}");
        assert!(g.nodes.is_empty());
        assert!(g.edges.is_empty());
    }
}