scud/attractor/handlers/
codergen.rs1use anyhow::Result;
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use crate::attractor::context::Context;
10use crate::attractor::graph::{PipelineGraph, PipelineNode};
11use crate::attractor::outcome::{Outcome, StageStatus};
12use crate::attractor::run_directory::RunDirectory;
13use crate::backend::{AgentBackend, AgentRequest, AgentStatus};
14
15use super::Handler;
16
17pub struct CodergenHandler {
18 backend: Arc<dyn AgentBackend>,
19}
20
21impl CodergenHandler {
22 pub fn new(backend: Arc<dyn AgentBackend>) -> Self {
24 Self { backend }
25 }
26
27 pub fn simulated() -> Self {
29 Self {
30 backend: Arc::new(crate::backend::simulated::SimulatedBackend),
31 }
32 }
33}
34
35#[async_trait]
36impl Handler for CodergenHandler {
37 async fn execute(
38 &self,
39 node: &PipelineNode,
40 context: &Context,
41 graph: &PipelineGraph,
42 run_dir: &RunDirectory,
43 ) -> Result<Outcome> {
44 let prompt = expand_variables(&node.prompt, graph, context).await;
46
47 run_dir.write_prompt(&node.id, &prompt)?;
49
50 let request = AgentRequest {
52 prompt,
53 model: node.llm_model.clone(),
54 provider: node.llm_provider.clone(),
55 reasoning_effort: Some(node.reasoning_effort.clone()),
56 working_dir: std::env::current_dir().unwrap_or_default(),
57 timeout: node.timeout,
58 ..Default::default()
59 };
60
61 let handle = self.backend.execute(request).await?;
63 let result = handle.result().await?;
64
65 run_dir.write_response(&node.id, &result.text)?;
67
68 let status_json = serde_json::json!({
70 "node_id": node.id,
71 "status": match &result.status {
72 AgentStatus::Completed => "success",
73 AgentStatus::Failed(_) => "failure",
74 AgentStatus::Cancelled => "cancelled",
75 AgentStatus::Timeout => "timeout",
76 },
77 "tool_calls": result.tool_calls.len(),
78 });
79 run_dir.write_status(&node.id, &status_json)?;
80
81 let status = match result.status {
83 AgentStatus::Completed => StageStatus::Success,
84 AgentStatus::Failed(msg) => {
85 return Ok(Outcome::failure(msg).with_response(result.text));
86 }
87 AgentStatus::Cancelled => StageStatus::Cancelled,
88 AgentStatus::Timeout => StageStatus::Timeout,
89 };
90
91 Ok(Outcome {
92 status,
93 preferred_label: None,
94 suggested_next: vec![],
95 context_updates: std::collections::HashMap::new(),
96 response_text: Some(result.text),
97 summary: None,
98 })
99 }
100}
101
102async fn expand_variables(prompt: &str, graph: &PipelineGraph, context: &Context) -> String {
104 let mut result = prompt.to_string();
105
106 if let Some(ref goal) = graph.graph_attrs.goal {
108 result = result.replace("$goal", goal);
109 }
110
111 let snapshot = context.snapshot().await;
113 for (key, value) in &snapshot {
114 let pattern = format!("$context.{}", key);
115 let replacement = match value {
116 serde_json::Value::String(s) => s.clone(),
117 other => other.to_string(),
118 };
119 result = result.replace(&pattern, &replacement);
120 }
121
122 result
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attractor::context::Context;
    use crate::attractor::dot_parser::parse_dot;
    use crate::attractor::graph::PipelineGraph;
    use crate::attractor::run_directory::RunDirectory;

    /// Parses DOT source into a pipeline graph, panicking on malformed input.
    fn graph_from(src: &str) -> PipelineGraph {
        let dot = parse_dot(src).unwrap();
        PipelineGraph::from_dot(&dot).unwrap()
    }

    /// A stage run against the simulated backend succeeds and persists its response.
    #[tokio::test]
    async fn test_codergen_simulated() {
        let tmp = tempfile::tempdir().unwrap();
        let run_dir = RunDirectory::create(tmp.path(), "test").unwrap();
        let graph = graph_from(
            r#"
            digraph test {
                graph [goal="Test goal"]
                start [shape=Mdiamond]
                task [shape=box, prompt="Do $goal"]
                finish [shape=Msquare]
                start -> task -> finish
            }
            "#,
        );
        let context = Context::new();
        let task = graph.node("task").unwrap();

        let outcome = CodergenHandler::simulated()
            .execute(task, &context, &graph, &run_dir)
            .await
            .unwrap();
        assert!(outcome.status.is_success());
        assert!(outcome.response_text.is_some());

        let stored = run_dir.read_response("task").unwrap();
        assert!(stored.contains("Simulated"));
    }

    /// `$goal` is replaced by the graph-level goal attribute.
    #[tokio::test]
    async fn test_expand_goal() {
        let graph = graph_from(
            r#"
            digraph test {
                graph [goal="Build a widget"]
                start [shape=Mdiamond]
                finish [shape=Msquare]
                start -> finish
            }
            "#,
        );
        let context = Context::new();

        let expanded = expand_variables("Your goal is: $goal", &graph, &context).await;
        assert_eq!(expanded, "Your goal is: Build a widget");
    }

    /// `$context.<key>` is replaced by the unquoted string stored in the context.
    #[tokio::test]
    async fn test_expand_context() {
        let graph = graph_from(
            r#"
            digraph test {
                start [shape=Mdiamond]
                finish [shape=Msquare]
                start -> finish
            }
            "#,
        );
        let context = Context::new();
        context.set("name", serde_json::json!("Alice")).await;

        let expanded = expand_variables("Hello $context.name", &graph, &context).await;
        assert_eq!(expanded, "Hello Alice");
    }
}