Skip to main content

scud/attractor/
stylesheet.rs

//! CSS-like model/provider stylesheet parser and applicator.
//!
//! Selectors:
//! - `*` (specificity 0) — matches all nodes
//! - `.class` (specificity 1) — matches nodes with that class
//! - `#node_id` (specificity 2) — matches a specific node
//!
//! Higher specificity wins; among rules of equal specificity, the later rule wins.
//! Explicit node attributes override stylesheet values. Caveat: a `reasoning_effort`
//! of `"high"` is treated as the unset default, so an explicit `"high"` can still be
//! replaced by the stylesheet.
10
11use anyhow::{bail, Result};
12use std::collections::HashMap;
13
14use super::graph::PipelineGraph;
15
16/// A single stylesheet rule.
17#[derive(Debug, Clone)]
18pub struct StyleRule {
19    pub selector: Selector,
20    pub properties: HashMap<String, String>,
21}
22
/// A CSS-like selector.
///
/// Derives `PartialEq`/`Eq`/`Hash` so selectors can be compared directly
/// (e.g. with `assert_eq!`) and used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Selector {
    /// `*` — matches all nodes.
    Universal,
    /// `.classname` — matches nodes with that class.
    Class(String),
    /// `#node_id` — matches a specific node by ID.
    Id(String),
}
33
34impl Selector {
35    /// Specificity value (higher = more specific).
36    pub fn specificity(&self) -> u8 {
37        match self {
38            Selector::Universal => 0,
39            Selector::Class(_) => 1,
40            Selector::Id(_) => 2,
41        }
42    }
43
44    /// Check if this selector matches a node.
45    pub fn matches(&self, node_id: &str, node_classes: &[String]) -> bool {
46        match self {
47            Selector::Universal => true,
48            Selector::Class(class) => node_classes.iter().any(|c| c == class),
49            Selector::Id(id) => node_id == id,
50        }
51    }
52}
53
54/// Parse a stylesheet string into rules.
55///
56/// Format:
57/// ```text
58/// * { model: "claude-3-opus"; reasoning_effort: "high" }
59/// .fast { model: "claude-3-haiku"; reasoning_effort: "low" }
60/// #critical_node { model: "claude-3-opus"; provider: "anthropic" }
61/// ```
62pub fn parse_stylesheet(input: &str) -> Result<Vec<StyleRule>> {
63    let mut rules = Vec::new();
64    let mut chars = input.chars().peekable();
65
66    loop {
67        // Skip whitespace
68        while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
69            chars.next();
70        }
71
72        if chars.peek().is_none() {
73            break;
74        }
75
76        // Parse selector
77        let selector = parse_selector(&mut chars)?;
78
79        // Skip whitespace
80        while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
81            chars.next();
82        }
83
84        // Expect {
85        match chars.next() {
86            Some('{') => {}
87            _ => bail!("Expected '{{' after selector"),
88        }
89
90        // Parse properties
91        let mut properties = HashMap::new();
92        loop {
93            // Skip whitespace
94            while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
95                chars.next();
96            }
97
98            if chars.peek() == Some(&'}') {
99                chars.next();
100                break;
101            }
102
103            if chars.peek().is_none() {
104                bail!("Unterminated rule block");
105            }
106
107            // Read property name
108            let mut name = String::new();
109            while let Some(&c) = chars.peek() {
110                if c == ':' || c.is_whitespace() {
111                    break;
112                }
113                name.push(c);
114                chars.next();
115            }
116
117            // Skip whitespace and colon
118            while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
119                chars.next();
120            }
121            if chars.peek() == Some(&':') {
122                chars.next();
123            }
124            while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
125                chars.next();
126            }
127
128            // Read value (quoted or bare)
129            let value = if chars.peek() == Some(&'"') {
130                chars.next(); // skip opening quote
131                let mut v = String::new();
132                while let Some(c) = chars.next() {
133                    if c == '"' {
134                        break;
135                    }
136                    v.push(c);
137                }
138                v
139            } else {
140                let mut v = String::new();
141                while let Some(&c) = chars.peek() {
142                    if c == ';' || c == '}' || c.is_whitespace() {
143                        break;
144                    }
145                    v.push(c);
146                    chars.next();
147                }
148                v
149            };
150
151            if !name.is_empty() {
152                properties.insert(name, value);
153            }
154
155            // Skip optional semicolon
156            while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
157                chars.next();
158            }
159            if chars.peek() == Some(&';') {
160                chars.next();
161            }
162        }
163
164        rules.push(StyleRule {
165            selector,
166            properties,
167        });
168    }
169
170    Ok(rules)
171}
172
173fn parse_selector(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<Selector> {
174    match chars.peek() {
175        Some('*') => {
176            chars.next();
177            Ok(Selector::Universal)
178        }
179        Some('.') => {
180            chars.next();
181            let mut name = String::new();
182            while let Some(&c) = chars.peek() {
183                if c.is_alphanumeric() || c == '_' || c == '-' {
184                    name.push(c);
185                    chars.next();
186                } else {
187                    break;
188                }
189            }
190            Ok(Selector::Class(name))
191        }
192        Some('#') => {
193            chars.next();
194            let mut name = String::new();
195            while let Some(&c) = chars.peek() {
196                if c.is_alphanumeric() || c == '_' || c == '-' {
197                    name.push(c);
198                    chars.next();
199                } else {
200                    break;
201                }
202            }
203            Ok(Selector::Id(name))
204        }
205        Some(c) => bail!("Invalid selector start: '{}'", c),
206        None => bail!("Expected selector, got EOF"),
207    }
208}
209
210/// Apply stylesheet rules to a pipeline graph.
211///
212/// Rules are applied in order of specificity (lowest first).
213/// Explicit node attributes always override stylesheet values.
214pub fn apply_stylesheet(graph: &mut PipelineGraph, rules: &[StyleRule]) {
215    // Sort rules by specificity (stable sort preserves declaration order)
216    let mut sorted_rules: Vec<_> = rules.iter().collect();
217    sorted_rules.sort_by_key(|r| r.selector.specificity());
218
219    for node_idx in graph.graph.node_indices() {
220        let (node_id, node_classes, has_model, has_provider, has_effort) = {
221            let node = &graph.graph[node_idx];
222            (
223                node.id.clone(),
224                node.classes.clone(),
225                node.llm_model.is_some(),
226                node.llm_provider.is_some(),
227                node.reasoning_effort != "high", // "high" is default
228            )
229        };
230
231        for rule in &sorted_rules {
232            if rule.selector.matches(&node_id, &node_classes) {
233                let node = &mut graph.graph[node_idx];
234
235                // Only apply if the node doesn't have an explicit value
236                if let Some(model) = rule.properties.get("model") {
237                    if !has_model {
238                        node.llm_model = Some(model.clone());
239                    }
240                }
241                if let Some(provider) = rule.properties.get("provider") {
242                    if !has_provider {
243                        node.llm_provider = Some(provider.clone());
244                    }
245                }
246                if let Some(effort) = rule.properties.get("reasoning_effort") {
247                    if !has_effort {
248                        node.reasoning_effort = effort.clone();
249                    }
250                }
251            }
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Look up a property of a parsed rule as `&str`.
    fn prop<'a>(rule: &'a StyleRule, name: &str) -> Option<&'a str> {
        rule.properties.get(name).map(String::as_str)
    }

    #[test]
    fn test_parse_stylesheet() {
        let input = r#"
            * { model: "claude-3-haiku"; reasoning_effort: "medium" }
            .critical { model: "claude-3-opus" }
            #special_node { provider: "anthropic" }
        "#;
        let rules = parse_stylesheet(input).unwrap();
        assert_eq!(rules.len(), 3);
        assert!(matches!(rules[0].selector, Selector::Universal));
        assert!(matches!(rules[1].selector, Selector::Class(ref c) if c == "critical"));
        assert!(matches!(rules[2].selector, Selector::Id(ref id) if id == "special_node"));

        // Quoted values are stored without their quotes.
        assert_eq!(prop(&rules[0], "model"), Some("claude-3-haiku"));
        assert_eq!(prop(&rules[0], "reasoning_effort"), Some("medium"));
        assert_eq!(prop(&rules[1], "model"), Some("claude-3-opus"));
        assert_eq!(prop(&rules[2], "provider"), Some("anthropic"));
        assert_eq!(rules[2].properties.len(), 1);
    }

    #[test]
    fn test_parse_bare_values_and_compact_whitespace() {
        let rules = parse_stylesheet(".x{model:bare;provider : \"p\" ;}").unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(prop(&rules[0], "model"), Some("bare"));
        assert_eq!(prop(&rules[0], "provider"), Some("p"));
    }

    #[test]
    fn test_parse_empty_and_invalid() {
        assert!(parse_stylesheet("").unwrap().is_empty());
        assert!(parse_stylesheet("  \n\t ").unwrap().is_empty());
        // Invalid selector start.
        assert!(parse_stylesheet("foo { model: x }").is_err());
        // Missing '{' after the selector.
        assert!(parse_stylesheet("* model: x }").is_err());
        // Block never closed.
        assert!(parse_stylesheet("* { model: x").is_err());
    }

    #[test]
    fn test_selector_specificity() {
        assert_eq!(Selector::Universal.specificity(), 0);
        assert_eq!(Selector::Class("x".into()).specificity(), 1);
        assert_eq!(Selector::Id("x".into()).specificity(), 2);
    }

    #[test]
    fn test_selector_matches() {
        assert!(Selector::Universal.matches("any", &[]));
        assert!(Selector::Class("fast".into()).matches("x", &["fast".into()]));
        assert!(!Selector::Class("fast".into()).matches("x", &["slow".into()]));
        assert!(Selector::Id("x".into()).matches("x", &[]));
        assert!(!Selector::Id("x".into()).matches("y", &[]));
    }

    #[test]
    fn test_apply_stylesheet() {
        use crate::attractor::dot_parser::parse_dot;
        use crate::attractor::graph::PipelineGraph;

        let input = r#"
        digraph test {
            graph [model_stylesheet="* { model: \"haiku\" }"]
            start [shape=Mdiamond]
            a [shape=box, class="fast"]
            b [shape=box, llm_model="opus"]
            finish [shape=Msquare]
            start -> a -> b -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let mut graph = PipelineGraph::from_dot(&dot).unwrap();

        let rules = parse_stylesheet("* { model: \"haiku\" }").unwrap();
        apply_stylesheet(&mut graph, &rules);

        // Node 'a' should get model from stylesheet
        let a = graph.node("a").unwrap();
        assert_eq!(a.llm_model, Some("haiku".into()));

        // Node 'b' already has explicit model, should keep it
        let b = graph.node("b").unwrap();
        assert_eq!(b.llm_model, Some("opus".into()));
    }

    #[test]
    fn test_apply_stylesheet_specificity_and_order() {
        use crate::attractor::dot_parser::parse_dot;
        use crate::attractor::graph::PipelineGraph;

        let input = r#"
        digraph order {
            start [shape=Mdiamond]
            a [shape=box]
            b [shape=box, llm_model="opus"]
            c [shape=box]
            finish [shape=Msquare]
            start -> a -> b -> c -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let mut graph = PipelineGraph::from_dot(&dot).unwrap();

        // The id rule is declared first but must still beat both universal rules;
        // of the two universal rules, the later one wins.
        let rules = parse_stylesheet(
            "#a { model: \"by-id\" } * { model: \"first\" } * { model: \"second\" }",
        )
        .unwrap();
        apply_stylesheet(&mut graph, &rules);

        assert_eq!(graph.node("a").unwrap().llm_model, Some("by-id".into()));
        assert_eq!(graph.node("c").unwrap().llm_model, Some("second".into()));
        // Explicit attributes still win over every rule.
        assert_eq!(graph.node("b").unwrap().llm_model, Some("opus".into()));
    }
}