smol_workflow_engine/agent_providers/
claude_code.rs1use super::common::*;
2use super::types::*;
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
/// Static configuration for launching the Claude Code CLI.
///
/// Every field has a usable default (`Default` is derived), so callers only
/// need to set what they want to override.
#[derive(Debug, Clone, Default)]
pub struct ClaudeCodeAgentProviderOptions {
    /// Executable to spawn. `run_claude_code` falls back to `"claude"` when
    /// this is `None`.
    pub command: Option<String>,
    /// Arguments placed first on the command line, before `args`.
    pub subcommand: Vec<String>,
    /// Extra arguments placed after `subcommand` and before the flags that
    /// `run_claude_code` generates itself.
    pub args: Vec<String>,
    /// Working directory used only when the run input's context does not
    /// carry its own `cwd`.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to `run_command` for the child process.
    // NOTE(review): whether these replace or extend the inherited environment
    // is decided inside `run_command` — confirm there.
    pub env: HashMap<String, String>,
    /// Timeout forwarded to `run_command`; presumably milliseconds, going by
    /// the field name.
    pub timeout_ms: Option<u64>,
}
17
/// Agent provider backed by the Claude Code command-line tool.
///
/// Holds only the launch options; each run is delegated to `run_claude_code`.
#[derive(Debug, Clone, Default)]
pub struct ClaudeCodeAgentProvider {
    /// Launch configuration applied to every run of this provider.
    options: ClaudeCodeAgentProviderOptions,
}
22
23impl ClaudeCodeAgentProvider {
24 pub fn new(options: ClaudeCodeAgentProviderOptions) -> Self {
25 Self { options }
26 }
27}
28
// `AgentProvider` implementation: constant metadata plus a `run` that hands
// the request to `run_claude_code` together with this provider's options.
#[async_trait::async_trait]
impl AgentProvider for ClaudeCodeAgentProvider {
    /// Provider identifier; always the literal `"claude-code"`.
    fn name(&self) -> &str {
        "claude-code"
    }

    /// Always reports `AgentProviderSchemaMode::Builtin`.
    // NOTE(review): presumably signals that schema enforcement is done by the
    // CLI itself (see the `--json-schema` flag in `run_claude_code`) — confirm
    // against the enum's definition.
    fn schema_mode(&self) -> AgentProviderSchemaMode {
        AgentProviderSchemaMode::Builtin
    }

    /// Always reports `AgentProviderUsageMode::Builtin`.
    // NOTE(review): presumably means usage figures come from the provider's
    // own output (see `extract_usage`) — confirm against the enum's definition.
    fn usage_mode(&self) -> AgentProviderUsageMode {
        AgentProviderUsageMode::Builtin
    }

    /// Runs one request through the Claude Code CLI and returns its result,
    /// propagating any error from `run_claude_code` unchanged.
    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
        run_claude_code(input, &self.options).await
    }
}
47
48async fn run_claude_code(
49 input: AgentProviderRunInput,
50 options: &ClaudeCodeAgentProviderOptions,
51) -> anyhow::Result<AgentProviderResult> {
52 let command = options.command.as_deref().unwrap_or("claude");
53 let mut args = Vec::<String>::new();
54 args.extend(options.subcommand.clone());
55 args.extend(options.args.clone());
56 if let Some(model) = option_str(&input.options, "model") {
57 args.extend(["--model".into(), model]);
58 }
59 if let Some(thinking) = option_str(&input.options, "thinking") {
60 args.extend(["--effort".into(), thinking]);
61 }
62 if let Some(agent_type) = option_str(&input.options, "agentType") {
63 args.extend(["--agent".into(), agent_type]);
64 }
65 args.extend([
66 "--output-format".into(),
67 "stream-json".into(),
68 "--verbose".into(),
69 "--input-format".into(),
70 "text".into(),
71 ]);
72 if let Some(schema) = option_schema(&input.options) {
73 args.extend(["--json-schema".into(), serde_json::to_string(schema)?]);
74 }
75 args.push("--print".into());
76
77 let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
78 let (stdout, stderr) = run_command(RunCommandRequest {
79 provider: "Claude Code",
80 command,
81 args: &args,
82 stdin: Some(&input.prompt),
83 cwd,
84 env: &options.env,
85 timeout_ms: options.timeout_ms,
86 environment: input.environment.as_ref(),
87 })
88 .await?;
89 let events = parse_json_lines(&stdout);
90 let raw = if events.is_empty() {
91 parse_json_or_text(&stdout)
92 } else {
93 Value::Array(events.clone())
94 };
95 let structured = option_schema(&input.options).is_some();
96 let output = extract_output(&raw, structured)?;
97 let session_id = extract_session_id(&raw)
98 .context("Claude Code provider response did not include a session id")?;
99 let event_payloads = if events.is_empty() {
100 vec![raw.clone()]
101 } else {
102 events
103 };
104
105 Ok(AgentProviderResult {
106 output,
107 session_id: Some(session_id),
108 model: extract_model(&raw).or_else(|| option_model(&input.options)),
109 usage: extract_usage(&raw),
110 isolation: None,
111 raw: Some(to_json_value(
112 json!({ "events": event_payloads, "response": raw, "stderr": stderr }),
113 )),
114 })
115}
116
117fn extract_output(raw: &Value, structured: bool) -> anyhow::Result<Value> {
118 if structured {
119 if let Some(output) = extract_structured_output(raw) {
120 return Ok(output);
121 }
122 }
123
124 let candidate = extract_output_candidate(raw);
125 if !structured {
126 return Ok(match candidate {
127 Value::String(text) => Value::String(text.trim_end().to_string()),
128 value => value,
129 });
130 }
131
132 match candidate {
133 Value::String(text) => parse_structured_output(&text),
134 value => Ok(value),
135 }
136}
137
138fn extract_structured_output(raw: &Value) -> Option<Value> {
139 match raw {
140 Value::Array(items) => items.iter().find_map(extract_structured_output),
141 Value::Object(record) => record
142 .get("structured_output")
143 .or_else(|| record.get("structuredOutput"))
144 .cloned(),
145 _ => None,
146 }
147}
148
149fn extract_output_candidate(raw: &Value) -> Value {
150 match raw {
151 Value::String(_) => raw.clone(),
152 Value::Array(items) => items
153 .iter()
154 .rev()
155 .map(extract_output_candidate)
156 .find(|value| !value.is_null())
157 .unwrap_or_else(|| raw.clone()),
158 Value::Object(record) => {
159 for key in ["result", "output", "text", "content"] {
160 if let Some(value) = record.get(key) {
161 return extract_content_text(value);
162 }
163 }
164 if let Some(message) = record.get("message") {
165 if message.is_object() {
166 return extract_output_candidate(message);
167 }
168 }
169 raw.clone()
170 }
171 _ => raw.clone(),
172 }
173}
174
175fn extract_content_text(value: &Value) -> Value {
176 match value {
177 Value::Array(items) => Value::String(
178 items
179 .iter()
180 .map(|item| match item {
181 Value::String(text) => text.clone(),
182 Value::Object(record) => record
183 .get("text")
184 .and_then(Value::as_str)
185 .unwrap_or("")
186 .to_string(),
187 _ => String::new(),
188 })
189 .collect::<Vec<_>>()
190 .join(""),
191 ),
192 _ => value.clone(),
193 }
194}
195
196fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
197 let trimmed = text.trim();
198 if let Ok(value) = serde_json::from_str(trimmed) {
199 return Ok(value);
200 }
201 if let Some(value) = extract_fenced_json(trimmed) {
202 return serde_json::from_str(value)
203 .context("Claude Code provider did not return valid JSON for schema output");
204 }
205 bail!("Claude Code provider did not return valid JSON for schema output")
206}
207
/// Returns the trimmed contents of the first triple-backtick fenced block in
/// `text`, or `None` when there is no opening fence or no closing fence.
///
/// A `json` language tag directly after the opening fence is dropped. The tag
/// is matched ASCII-case-insensitively, so blocks opened with "```JSON" or
/// "```Json" are handled the same as "```json"; leaving the tag in place
/// would make the returned slice unparseable.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const TAG: &str = "json";

    let open = text.find(FENCE)?;
    let after_fence = &text[open + FENCE.len()..];

    // `get` yields `None` when the slice is too short or would split a
    // multi-byte character, so this cannot panic on non-ASCII input.
    let untagged = match after_fence.get(..TAG.len()) {
        Some(tag) if tag.eq_ignore_ascii_case(TAG) => &after_fence[TAG.len()..],
        _ => after_fence,
    };
    let body = untagged.trim_start();

    let close = body.find(FENCE)?;
    Some(body[..close].trim())
}
215
216fn extract_session_id(raw: &Value) -> Option<String> {
217 match raw {
218 Value::Array(items) => items.iter().find_map(extract_session_id),
219 Value::Object(record) => {
220 if let Some(value) = record
221 .get("session_id")
222 .or_else(|| record.get("sessionId"))
223 .or_else(|| record.get("sessionID"))
224 .and_then(Value::as_str)
225 {
226 return Some(value.to_string());
227 }
228 record.values().find_map(extract_session_id)
229 }
230 _ => None,
231 }
232}
233
234fn extract_usage(raw: &Value) -> Option<AgentUsage> {
235 let mut usage_objects = Vec::new();
236 find_usage_objects(raw, &mut usage_objects);
237 let usage = usage_objects.last()?;
238 let mut normalized = normalize_usage(usage);
239 if normalized.cost.is_none() {
240 if let Some(total) = find_total_cost_usd(raw) {
241 normalized.cost = Some(AgentUsageCost {
242 total: Some(total),
243 currency: Some("USD".into()),
244 ..AgentUsageCost::default()
245 });
246 }
247 }
248 Some(normalized)
249}
250
251fn find_total_cost_usd(value: &Value) -> Option<f64> {
252 match value {
253 Value::Array(items) => items.iter().find_map(find_total_cost_usd),
254 Value::Object(record) => {
255 number_field_f64(record, &["total_cost_usd", "costUSD", "cost_usd"])
256 .or_else(|| record.values().find_map(find_total_cost_usd))
257 }
258 _ => None,
259 }
260}