Skip to main content

scud/attractor/
transforms.rs

1//! Pre-execution graph transforms.
2//!
3//! Applied in order after parsing, before validation:
4//! 1. Variable expansion: Replace `$goal` in node prompts
5//! 2. Stylesheet application: Apply model_stylesheet CSS-like rules
6
7use super::graph::PipelineGraph;
8use super::stylesheet::{parse_stylesheet, apply_stylesheet};
9
10/// Apply all transforms to a pipeline graph.
11pub fn apply_transforms(graph: &mut PipelineGraph) {
12    expand_goal_variables(graph);
13    apply_stylesheet_transform(graph);
14}
15
16/// Replace `$goal` in all node prompts with the graph-level goal attribute.
17fn expand_goal_variables(graph: &mut PipelineGraph) {
18    let goal = match &graph.graph_attrs.goal {
19        Some(g) => g.clone(),
20        None => return,
21    };
22
23    for node_idx in graph.graph.node_indices() {
24        let node = &mut graph.graph[node_idx];
25        if node.prompt.contains("$goal") {
26            node.prompt = node.prompt.replace("$goal", &goal);
27        }
28    }
29}
30
31/// Apply model_stylesheet rules to resolve llm_model, llm_provider, reasoning_effort.
32fn apply_stylesheet_transform(graph: &mut PipelineGraph) {
33    let stylesheet_src = match &graph.graph_attrs.model_stylesheet {
34        Some(s) => s.clone(),
35        None => return,
36    };
37
38    let rules = match parse_stylesheet(&stylesheet_src) {
39        Ok(r) => r,
40        Err(_) => return, // Silently skip invalid stylesheets (validator will catch it)
41    };
42
43    apply_stylesheet(graph, &rules);
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attractor::dot_parser::parse_dot;
    use crate::attractor::graph::PipelineGraph;

    /// A `$goal` placeholder in a node prompt is replaced by the graph's
    /// `goal` attribute once the transforms have run.
    #[test]
    fn test_goal_expansion() {
        let dot_source = r#"
        digraph test {
            graph [goal="Build a REST API"]
            start [shape=Mdiamond]
            task [shape=box, prompt="Your goal: $goal. Do it."]
            finish [shape=Msquare]
            start -> task -> finish
        }
        "#;

        // Parse the DOT text, build the pipeline graph, then transform it.
        let parsed = parse_dot(dot_source).unwrap();
        let mut pipeline = PipelineGraph::from_dot(&parsed).unwrap();
        apply_transforms(&mut pipeline);

        let task_node = pipeline.node("task").unwrap();
        assert_eq!(task_node.prompt, "Your goal: Build a REST API. Do it.");
    }
}