smol_workflow_engine/agent_providers/
claude_code.rs1use super::common::*;
2use super::types::*;
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
/// Launch settings for the Claude Code CLI agent provider.
///
/// Every field may be left at its `Default` value, which yields an empty
/// configuration.
#[derive(Debug, Clone, Default)]
pub struct ClaudeCodeAgentProviderOptions {
    /// Executable to invoke; the provider uses `claude` when this is unset.
    pub command: Option<String>,
    /// Arguments placed first on the command line, ahead of `args`.
    pub subcommand: Vec<String>,
    /// Extra arguments placed after `subcommand` and before the provider's own flags.
    pub args: Vec<String>,
    /// Working directory used when the run input's context does not supply one.
    pub cwd: Option<PathBuf>,
    /// Environment map handed to the command runner.
    pub env: HashMap<String, String>,
    /// Timeout handed to the command runner (milliseconds, going by the field name).
    pub timeout_ms: Option<u64>,
}
17
18#[derive(Debug, Clone, Default)]
19pub struct ClaudeCodeAgentProvider {
20 options: ClaudeCodeAgentProviderOptions,
21}
22
23impl ClaudeCodeAgentProvider {
24 pub fn new(options: ClaudeCodeAgentProviderOptions) -> Self {
25 Self { options }
26 }
27}
28
29#[async_trait::async_trait]
30impl AgentProvider for ClaudeCodeAgentProvider {
31 fn name(&self) -> &str {
32 "claude-code"
33 }
34
35 fn schema_mode(&self) -> AgentProviderSchemaMode {
36 AgentProviderSchemaMode::Builtin
37 }
38
39 fn usage_mode(&self) -> AgentProviderUsageMode {
40 AgentProviderUsageMode::Builtin
41 }
42
43 async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
44 run_claude_code(input, &self.options).await
45 }
46}
47
48async fn run_claude_code(
49 input: AgentProviderRunInput,
50 options: &ClaudeCodeAgentProviderOptions,
51) -> anyhow::Result<AgentProviderResult> {
52 let command = options.command.as_deref().unwrap_or("claude");
53 let mut args = Vec::<String>::new();
54 args.extend(options.subcommand.clone());
55 args.extend(options.args.clone());
56 if let Some(model) = option_str(&input.options, "model") {
57 args.extend(["--model".into(), model]);
58 }
59 if let Some(thinking) = option_str(&input.options, "thinking") {
60 args.extend(["--effort".into(), thinking]);
61 }
62 if let Some(agent_type) = option_str(&input.options, "agentType") {
63 args.extend(["--agent".into(), agent_type]);
64 }
65 args.extend([
66 "--output-format".into(),
67 "stream-json".into(),
68 "--verbose".into(),
69 "--input-format".into(),
70 "text".into(),
71 ]);
72 if let Some(schema) = option_schema(&input.options) {
73 args.extend(["--json-schema".into(), serde_json::to_string(schema)?]);
74 }
75 args.push("--print".into());
76
77 let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
78 let (stdout, stderr) = run_command(
79 "Claude Code",
80 command,
81 &args,
82 Some(&input.prompt),
83 cwd,
84 &options.env,
85 options.timeout_ms,
86 )
87 .await?;
88 let events = parse_json_lines(&stdout);
89 let raw = if events.is_empty() {
90 parse_json_or_text(&stdout)
91 } else {
92 Value::Array(events.clone())
93 };
94 let structured = option_schema(&input.options).is_some();
95 let output = extract_output(&raw, structured)?;
96 let session_id = extract_session_id(&raw)
97 .context("Claude Code provider response did not include a session id")?;
98 let event_payloads = if events.is_empty() {
99 vec![raw.clone()]
100 } else {
101 events
102 };
103
104 Ok(AgentProviderResult {
105 output,
106 session_id: Some(session_id),
107 model: extract_model(&raw).or_else(|| option_model(&input.options)),
108 usage: extract_usage(&raw),
109 isolation: None,
110 raw: Some(to_json_value(
111 json!({ "events": event_payloads, "response": raw, "stderr": stderr }),
112 )),
113 })
114}
115
116fn extract_output(raw: &Value, structured: bool) -> anyhow::Result<Value> {
117 if structured {
118 if let Some(output) = extract_structured_output(raw) {
119 return Ok(output);
120 }
121 }
122
123 let candidate = extract_output_candidate(raw);
124 if !structured {
125 return Ok(match candidate {
126 Value::String(text) => Value::String(text.trim_end().to_string()),
127 value => value,
128 });
129 }
130
131 match candidate {
132 Value::String(text) => parse_structured_output(&text),
133 value => Ok(value),
134 }
135}
136
137fn extract_structured_output(raw: &Value) -> Option<Value> {
138 match raw {
139 Value::Array(items) => items.iter().find_map(extract_structured_output),
140 Value::Object(record) => record
141 .get("structured_output")
142 .or_else(|| record.get("structuredOutput"))
143 .cloned(),
144 _ => None,
145 }
146}
147
148fn extract_output_candidate(raw: &Value) -> Value {
149 match raw {
150 Value::String(_) => raw.clone(),
151 Value::Array(items) => items
152 .iter()
153 .rev()
154 .map(extract_output_candidate)
155 .find(|value| !value.is_null())
156 .unwrap_or_else(|| raw.clone()),
157 Value::Object(record) => {
158 for key in ["result", "output", "text", "content"] {
159 if let Some(value) = record.get(key) {
160 return extract_content_text(value);
161 }
162 }
163 if let Some(message) = record.get("message") {
164 if message.is_object() {
165 return extract_output_candidate(message);
166 }
167 }
168 raw.clone()
169 }
170 _ => raw.clone(),
171 }
172}
173
174fn extract_content_text(value: &Value) -> Value {
175 match value {
176 Value::Array(items) => Value::String(
177 items
178 .iter()
179 .map(|item| match item {
180 Value::String(text) => text.clone(),
181 Value::Object(record) => record
182 .get("text")
183 .and_then(Value::as_str)
184 .unwrap_or("")
185 .to_string(),
186 _ => String::new(),
187 })
188 .collect::<Vec<_>>()
189 .join(""),
190 ),
191 _ => value.clone(),
192 }
193}
194
195fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
196 let trimmed = text.trim();
197 if let Ok(value) = serde_json::from_str(trimmed) {
198 return Ok(value);
199 }
200 if let Some(value) = extract_fenced_json(trimmed) {
201 return serde_json::from_str(value)
202 .context("Claude Code provider did not return valid JSON for schema output");
203 }
204 bail!("Claude Code provider did not return valid JSON for schema output")
205}
206
/// Returns the trimmed contents of the first triple-backtick fenced block in `text`.
///
/// A `json` language tag directly after the opening fence is dropped. The tag is
/// compared ASCII case-insensitively, so fences opened with `JSON` or `Json` are
/// unwrapped the same way as lowercase `json` instead of leaving the tag inside
/// the payload.
///
/// Returns `None` when `text` has no opening fence, or no closing fence after it.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const TAG: &str = "json";

    let (_, after_open) = text.split_once(FENCE)?;

    // `get` yields `None` when the block is shorter than the tag or the cut would
    // land inside a multi-byte character, so the slice in the guarded arm below
    // is always on a valid boundary and cannot panic.
    let body = match after_open.get(..TAG.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(TAG) => &after_open[TAG.len()..],
        _ => after_open,
    };

    let (inner, _) = body.trim_start().split_once(FENCE)?;
    Some(inner.trim())
}
214
215fn extract_session_id(raw: &Value) -> Option<String> {
216 match raw {
217 Value::Array(items) => items.iter().find_map(extract_session_id),
218 Value::Object(record) => {
219 if let Some(value) = record
220 .get("session_id")
221 .or_else(|| record.get("sessionId"))
222 .or_else(|| record.get("sessionID"))
223 .and_then(Value::as_str)
224 {
225 return Some(value.to_string());
226 }
227 record.values().find_map(extract_session_id)
228 }
229 _ => None,
230 }
231}
232
233fn extract_usage(raw: &Value) -> Option<AgentUsage> {
234 let mut usage_objects = Vec::new();
235 find_usage_objects(raw, &mut usage_objects);
236 let usage = usage_objects.last()?;
237 let mut normalized = normalize_usage(usage);
238 if normalized.cost.is_none() {
239 if let Some(total) = find_total_cost_usd(raw) {
240 normalized.cost = Some(AgentUsageCost {
241 total: Some(total),
242 currency: Some("USD".into()),
243 ..AgentUsageCost::default()
244 });
245 }
246 }
247 Some(normalized)
248}
249
250fn find_total_cost_usd(value: &Value) -> Option<f64> {
251 match value {
252 Value::Array(items) => items.iter().find_map(find_total_cost_usd),
253 Value::Object(record) => {
254 number_field_f64(record, &["total_cost_usd", "costUSD", "cost_usd"])
255 .or_else(|| record.values().find_map(find_total_cost_usd))
256 }
257 _ => None,
258 }
259}