scud/attractor/handlers/
codergen.rs1use anyhow::Result;
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use crate::attractor::context::Context;
10use crate::attractor::graph::{PipelineGraph, PipelineNode};
11use crate::attractor::outcome::{Outcome, StageStatus};
12use crate::attractor::run_directory::RunDirectory;
13use crate::backend::{AgentBackend, AgentRequest, AgentStatus};
14
15use super::Handler;
16
17pub struct CodergenHandler {
18 backend: Arc<dyn AgentBackend>,
19}
20
21impl CodergenHandler {
22 pub fn new(backend: Arc<dyn AgentBackend>) -> Self {
24 Self { backend }
25 }
26
27 pub fn simulated() -> Self {
29 Self {
30 backend: Arc::new(crate::backend::simulated::SimulatedBackend),
31 }
32 }
33}
34
35#[async_trait]
36impl Handler for CodergenHandler {
37 async fn execute(
38 &self,
39 node: &PipelineNode,
40 context: &Context,
41 graph: &PipelineGraph,
42 run_dir: &RunDirectory,
43 ) -> Result<Outcome> {
44 let prompt = expand_variables(&node.prompt, graph, context).await;
46
47 run_dir.write_prompt(&node.id, &prompt)?;
49
50 let request = AgentRequest {
52 prompt,
53 model: node.llm_model.clone(),
54 provider: node.llm_provider.clone(),
55 reasoning_effort: Some(node.reasoning_effort.clone()),
56 working_dir: std::env::current_dir().unwrap_or_default(),
57 timeout: node.timeout,
58 ..Default::default()
59 };
60
61 let handle = self.backend.execute(request).await?;
63 let result = handle.result().await?;
64
65 run_dir.write_response(&node.id, &result.text)?;
67
68 let status_json = serde_json::json!({
70 "node_id": node.id,
71 "status": match &result.status {
72 AgentStatus::Completed => "success",
73 AgentStatus::Failed(_) => "failure",
74 AgentStatus::Cancelled => "cancelled",
75 AgentStatus::Timeout => "timeout",
76 },
77 "tool_calls": result.tool_calls.len(),
78 });
79 run_dir.write_status(&node.id, &status_json)?;
80
81 let status = match result.status {
83 AgentStatus::Completed => StageStatus::Success,
84 AgentStatus::Failed(msg) => {
85 return Ok(Outcome::failure(msg).with_response(result.text));
86 }
87 AgentStatus::Cancelled => StageStatus::Cancelled,
88 AgentStatus::Timeout => StageStatus::Timeout,
89 };
90
91 Ok(Outcome {
92 status,
93 preferred_label: None,
94 suggested_next: vec![],
95 context_updates: std::collections::HashMap::new(),
96 response_text: Some(result.text),
97 summary: None,
98 })
99 }
100}
101
102async fn expand_variables(
104 prompt: &str,
105 graph: &PipelineGraph,
106 context: &Context,
107) -> String {
108 let mut result = prompt.to_string();
109
110 if let Some(ref goal) = graph.graph_attrs.goal {
112 result = result.replace("$goal", goal);
113 }
114
115 let snapshot = context.snapshot().await;
117 for (key, value) in &snapshot {
118 let pattern = format!("$context.{}", key);
119 let replacement = match value {
120 serde_json::Value::String(s) => s.clone(),
121 other => other.to_string(),
122 };
123 result = result.replace(&pattern, &replacement);
124 }
125
126 result
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attractor::context::Context;
    use crate::attractor::dot_parser::parse_dot;
    use crate::attractor::graph::PipelineGraph;
    use crate::attractor::run_directory::RunDirectory;

    /// Parses `source` as DOT and builds a pipeline graph, panicking on failure.
    fn graph_from_dot(source: &str) -> PipelineGraph {
        let dot = parse_dot(source).unwrap();
        PipelineGraph::from_dot(&dot).unwrap()
    }

    /// The simulated backend succeeds, returns a response, and the response
    /// is persisted to the run directory.
    #[tokio::test]
    async fn test_codergen_simulated() {
        let tmp = tempfile::tempdir().unwrap();
        let run_dir = RunDirectory::create(tmp.path(), "test").unwrap();

        let graph = graph_from_dot(
            r#"
            digraph test {
                graph [goal="Test goal"]
                start [shape=Mdiamond]
                task [shape=box, prompt="Do $goal"]
                finish [shape=Msquare]
                start -> task -> finish
            }
            "#,
        );
        let context = Context::new();
        let node = graph.node("task").unwrap();

        let handler = CodergenHandler::simulated();
        let outcome = handler
            .execute(node, &context, &graph, &run_dir)
            .await
            .unwrap();
        assert!(outcome.status.is_success());
        assert!(outcome.response_text.is_some());

        let saved = run_dir.read_response("task").unwrap();
        assert!(saved.contains("Simulated"));
    }

    /// `$goal` is replaced with the graph's `goal` attribute.
    #[tokio::test]
    async fn test_expand_goal() {
        let graph = graph_from_dot(
            r#"
            digraph test {
                graph [goal="Build a widget"]
                start [shape=Mdiamond]
                finish [shape=Msquare]
                start -> finish
            }
            "#,
        );
        let context = Context::new();

        let expanded = expand_variables("Your goal is: $goal", &graph, &context).await;
        assert_eq!(expanded, "Your goal is: Build a widget");
    }

    /// `$context.<key>` is replaced with the unquoted string stored in context.
    #[tokio::test]
    async fn test_expand_context() {
        let graph = graph_from_dot(
            r#"
            digraph test {
                start [shape=Mdiamond]
                finish [shape=Msquare]
                start -> finish
            }
            "#,
        );
        let context = Context::new();
        context.set("name", serde_json::json!("Alice")).await;

        let expanded = expand_variables("Hello $context.name", &graph, &context).await;
        assert_eq!(expanded, "Hello Alice");
    }
}