scud/attractor/
stylesheet.rs1use anyhow::{bail, Result};
12use std::collections::HashMap;
13
14use super::graph::PipelineGraph;
15
16#[derive(Debug, Clone)]
18pub struct StyleRule {
19 pub selector: Selector,
20 pub properties: HashMap<String, String>,
21}
22
/// Selects which nodes a stylesheet rule applies to.
///
/// Selectors compare by value (`PartialEq`/`Eq`) and can be used as hash-map
/// keys (`Hash`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Selector {
    /// Matches every node.
    Universal,
    /// Matches nodes whose class list contains the given class name.
    Class(String),
    /// Matches the node whose id equals the given name.
    Id(String),
}

impl Selector {
    /// Returns the precedence rank of this selector.
    ///
    /// Larger values are more specific: `Universal` is 0, `Class` is 1 and
    /// `Id` is 2.
    pub fn specificity(&self) -> u8 {
        match self {
            Selector::Universal => 0,
            Selector::Class(_) => 1,
            Selector::Id(_) => 2,
        }
    }

    /// Reports whether this selector applies to a node with the given id and
    /// class list.
    ///
    /// * `Universal` always matches.
    /// * `Class` matches when `node_classes` contains the class name.
    /// * `Id` matches when `node_id` equals the id exactly.
    pub fn matches(&self, node_id: &str, node_classes: &[String]) -> bool {
        match self {
            Selector::Universal => true,
            Selector::Class(class) => node_classes.contains(class),
            Selector::Id(id) => node_id == id,
        }
    }
}
53
54pub fn parse_stylesheet(input: &str) -> Result<Vec<StyleRule>> {
63 let mut rules = Vec::new();
64 let mut chars = input.chars().peekable();
65
66 loop {
67 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
69 chars.next();
70 }
71
72 if chars.peek().is_none() {
73 break;
74 }
75
76 let selector = parse_selector(&mut chars)?;
78
79 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
81 chars.next();
82 }
83
84 match chars.next() {
86 Some('{') => {}
87 _ => bail!("Expected '{{' after selector"),
88 }
89
90 let mut properties = HashMap::new();
92 loop {
93 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
95 chars.next();
96 }
97
98 if chars.peek() == Some(&'}') {
99 chars.next();
100 break;
101 }
102
103 if chars.peek().is_none() {
104 bail!("Unterminated rule block");
105 }
106
107 let mut name = String::new();
109 while let Some(&c) = chars.peek() {
110 if c == ':' || c.is_whitespace() {
111 break;
112 }
113 name.push(c);
114 chars.next();
115 }
116
117 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
119 chars.next();
120 }
121 if chars.peek() == Some(&':') {
122 chars.next();
123 }
124 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
125 chars.next();
126 }
127
128 let value = if chars.peek() == Some(&'"') {
130 chars.next(); let mut v = String::new();
132 while let Some(c) = chars.next() {
133 if c == '"' {
134 break;
135 }
136 v.push(c);
137 }
138 v
139 } else {
140 let mut v = String::new();
141 while let Some(&c) = chars.peek() {
142 if c == ';' || c == '}' || c.is_whitespace() {
143 break;
144 }
145 v.push(c);
146 chars.next();
147 }
148 v
149 };
150
151 if !name.is_empty() {
152 properties.insert(name, value);
153 }
154
155 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
157 chars.next();
158 }
159 if chars.peek() == Some(&';') {
160 chars.next();
161 }
162 }
163
164 rules.push(StyleRule {
165 selector,
166 properties,
167 });
168 }
169
170 Ok(rules)
171}
172
173fn parse_selector(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<Selector> {
174 match chars.peek() {
175 Some('*') => {
176 chars.next();
177 Ok(Selector::Universal)
178 }
179 Some('.') => {
180 chars.next();
181 let mut name = String::new();
182 while let Some(&c) = chars.peek() {
183 if c.is_alphanumeric() || c == '_' || c == '-' {
184 name.push(c);
185 chars.next();
186 } else {
187 break;
188 }
189 }
190 Ok(Selector::Class(name))
191 }
192 Some('#') => {
193 chars.next();
194 let mut name = String::new();
195 while let Some(&c) = chars.peek() {
196 if c.is_alphanumeric() || c == '_' || c == '-' {
197 name.push(c);
198 chars.next();
199 } else {
200 break;
201 }
202 }
203 Ok(Selector::Id(name))
204 }
205 Some(c) => bail!("Invalid selector start: '{}'", c),
206 None => bail!("Expected selector, got EOF"),
207 }
208}
209
210pub fn apply_stylesheet(graph: &mut PipelineGraph, rules: &[StyleRule]) {
215 let mut sorted_rules: Vec<_> = rules.iter().collect();
217 sorted_rules.sort_by_key(|r| r.selector.specificity());
218
219 for node_idx in graph.graph.node_indices() {
220 let (node_id, node_classes, has_model, has_provider, has_effort) = {
221 let node = &graph.graph[node_idx];
222 (
223 node.id.clone(),
224 node.classes.clone(),
225 node.llm_model.is_some(),
226 node.llm_provider.is_some(),
227 node.reasoning_effort != "high", )
229 };
230
231 for rule in &sorted_rules {
232 if rule.selector.matches(&node_id, &node_classes) {
233 let node = &mut graph.graph[node_idx];
234
235 if let Some(model) = rule.properties.get("model") {
237 if !has_model {
238 node.llm_model = Some(model.clone());
239 }
240 }
241 if let Some(provider) = rule.properties.get("provider") {
242 if !has_provider {
243 node.llm_provider = Some(provider.clone());
244 }
245 }
246 if let Some(effort) = rule.properties.get("reasoning_effort") {
247 if !has_effort {
248 node.reasoning_effort = effort.clone();
249 }
250 }
251 }
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Three rules, one per selector kind, come back in source order.
    #[test]
    fn test_parse_stylesheet() {
        let rules = parse_stylesheet(
            r#"
            * { model: "claude-3-haiku"; reasoning_effort: "medium" }
            .critical { model: "claude-3-opus" }
            #special_node { provider: "anthropic" }
            "#,
        )
        .unwrap();

        assert_eq!(rules.len(), 3);
        assert!(matches!(&rules[0].selector, Selector::Universal));
        assert!(matches!(&rules[1].selector, Selector::Class(name) if name == "critical"));
        assert!(matches!(&rules[2].selector, Selector::Id(name) if name == "special_node"));
    }

    /// Specificity ranks universal < class < id.
    #[test]
    fn test_selector_specificity() {
        let universal = Selector::Universal;
        let class = Selector::Class("x".into());
        let id = Selector::Id("x".into());

        assert_eq!(universal.specificity(), 0);
        assert_eq!(class.specificity(), 1);
        assert_eq!(id.specificity(), 2);
    }

    /// Each selector kind matches only the nodes it should.
    #[test]
    fn test_selector_matches() {
        let fast = Selector::Class("fast".into());
        let x = Selector::Id("x".into());

        assert!(Selector::Universal.matches("any", &[]));

        assert!(fast.matches("x", &["fast".into()]));
        assert!(!fast.matches("x", &["slow".into()]));

        assert!(x.matches("x", &[]));
        assert!(!x.matches("y", &[]));
    }

    /// A universal rule fills in a missing model but never replaces an
    /// explicitly configured one.
    #[test]
    fn test_apply_stylesheet() {
        use crate::attractor::dot_parser::parse_dot;
        use crate::attractor::graph::PipelineGraph;

        let dot_source = r#"
            digraph test {
                graph [model_stylesheet="* { model: \"haiku\" }"]
                start [shape=Mdiamond]
                a [shape=box, class="fast"]
                b [shape=box, llm_model="opus"]
                finish [shape=Msquare]
                start -> a -> b -> finish
            }
        "#;
        let parsed = parse_dot(dot_source).unwrap();
        let mut graph = PipelineGraph::from_dot(&parsed).unwrap();

        let rules = parse_stylesheet("* { model: \"haiku\" }").unwrap();
        apply_stylesheet(&mut graph, &rules);

        // `a` has no model of its own, so the universal rule supplies one.
        let styled = graph.node("a").unwrap();
        assert_eq!(styled.llm_model, Some("haiku".into()));

        // `b` sets `llm_model` explicitly, which the stylesheet must not touch.
        let explicit = graph.node("b").unwrap();
        assert_eq!(explicit.llm_model, Some("opus".into()));
    }
}