// smol_workflow_engine/agent_providers/claude_code.rs

1use super::common::*;
2use super::types::*;
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
/// Settings that control how the Claude Code CLI process is launched.
#[derive(Clone, Debug, Default)]
pub struct ClaudeCodeAgentProviderOptions {
    /// Executable to invoke; `run_claude_code` falls back to `"claude"` when `None`.
    pub command: Option<String>,
    /// Leading arguments, placed first on the command line (before `args`).
    pub subcommand: Vec<String>,
    /// Extra arguments, placed after `subcommand` and before the flags the provider generates.
    pub args: Vec<String>,
    /// Working directory used when the run's own context does not supply one.
    pub cwd: Option<PathBuf>,
    /// Additional environment variables forwarded to `run_command`.
    pub env: HashMap<String, String>,
    /// Timeout in milliseconds forwarded to `run_command`.
    /// NOTE(review): how `None` is treated is decided by `run_command` — confirm.
    pub timeout_ms: Option<u64>,
}

/// Agent provider that runs prompts through the Claude Code CLI.
#[derive(Clone, Debug, Default)]
pub struct ClaudeCodeAgentProvider {
    options: ClaudeCodeAgentProviderOptions,
}

impl ClaudeCodeAgentProvider {
    /// Builds a provider that launches the CLI according to `options`.
    pub fn new(options: ClaudeCodeAgentProviderOptions) -> Self {
        ClaudeCodeAgentProvider { options }
    }
}
28
29#[async_trait::async_trait]
30impl AgentProvider for ClaudeCodeAgentProvider {
31    fn name(&self) -> &str {
32        "claude-code"
33    }
34
35    fn schema_mode(&self) -> AgentProviderSchemaMode {
36        AgentProviderSchemaMode::Builtin
37    }
38
39    fn usage_mode(&self) -> AgentProviderUsageMode {
40        AgentProviderUsageMode::Builtin
41    }
42
43    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
44        run_claude_code(input, &self.options).await
45    }
46}
47
48async fn run_claude_code(
49    input: AgentProviderRunInput,
50    options: &ClaudeCodeAgentProviderOptions,
51) -> anyhow::Result<AgentProviderResult> {
52    let command = options.command.as_deref().unwrap_or("claude");
53    let mut args = Vec::<String>::new();
54    args.extend(options.subcommand.clone());
55    args.extend(options.args.clone());
56    if let Some(model) = option_str(&input.options, "model") {
57        args.extend(["--model".into(), model]);
58    }
59    if let Some(thinking) = option_str(&input.options, "thinking") {
60        args.extend(["--effort".into(), thinking]);
61    }
62    if let Some(agent_type) = option_str(&input.options, "agentType") {
63        args.extend(["--agent".into(), agent_type]);
64    }
65    args.extend([
66        "--output-format".into(),
67        "stream-json".into(),
68        "--verbose".into(),
69        "--input-format".into(),
70        "text".into(),
71    ]);
72    if let Some(schema) = option_schema(&input.options) {
73        args.extend(["--json-schema".into(), serde_json::to_string(schema)?]);
74    }
75    args.push("--print".into());
76
77    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
78    let (stdout, stderr) = run_command(RunCommandRequest {
79        provider: "Claude Code",
80        command,
81        args: &args,
82        stdin: Some(&input.prompt),
83        cwd,
84        env: &options.env,
85        timeout_ms: options.timeout_ms,
86        environment: input.environment.as_ref(),
87    })
88    .await?;
89    let events = parse_json_lines(&stdout);
90    let raw = if events.is_empty() {
91        parse_json_or_text(&stdout)
92    } else {
93        Value::Array(events.clone())
94    };
95    let structured = option_schema(&input.options).is_some();
96    let output = extract_output(&raw, structured)?;
97    let session_id = extract_session_id(&raw)
98        .context("Claude Code provider response did not include a session id")?;
99    let event_payloads = if events.is_empty() {
100        vec![raw.clone()]
101    } else {
102        events
103    };
104
105    Ok(AgentProviderResult {
106        output,
107        session_id: Some(session_id),
108        model: extract_model(&raw).or_else(|| option_model(&input.options)),
109        usage: extract_usage(&raw),
110        isolation: None,
111        raw: Some(to_json_value(
112            json!({ "events": event_payloads, "response": raw, "stderr": stderr }),
113        )),
114    })
115}
116
117fn extract_output(raw: &Value, structured: bool) -> anyhow::Result<Value> {
118    if structured {
119        if let Some(output) = extract_structured_output(raw) {
120            return Ok(output);
121        }
122    }
123
124    let candidate = extract_output_candidate(raw);
125    if !structured {
126        return Ok(match candidate {
127            Value::String(text) => Value::String(text.trim_end().to_string()),
128            value => value,
129        });
130    }
131
132    match candidate {
133        Value::String(text) => parse_structured_output(&text),
134        value => Ok(value),
135    }
136}
137
138fn extract_structured_output(raw: &Value) -> Option<Value> {
139    match raw {
140        Value::Array(items) => items.iter().find_map(extract_structured_output),
141        Value::Object(record) => record
142            .get("structured_output")
143            .or_else(|| record.get("structuredOutput"))
144            .cloned(),
145        _ => None,
146    }
147}
148
149fn extract_output_candidate(raw: &Value) -> Value {
150    match raw {
151        Value::String(_) => raw.clone(),
152        Value::Array(items) => items
153            .iter()
154            .rev()
155            .map(extract_output_candidate)
156            .find(|value| !value.is_null())
157            .unwrap_or_else(|| raw.clone()),
158        Value::Object(record) => {
159            for key in ["result", "output", "text", "content"] {
160                if let Some(value) = record.get(key) {
161                    return extract_content_text(value);
162                }
163            }
164            if let Some(message) = record.get("message") {
165                if message.is_object() {
166                    return extract_output_candidate(message);
167                }
168            }
169            raw.clone()
170        }
171        _ => raw.clone(),
172    }
173}
174
175fn extract_content_text(value: &Value) -> Value {
176    match value {
177        Value::Array(items) => Value::String(
178            items
179                .iter()
180                .map(|item| match item {
181                    Value::String(text) => text.clone(),
182                    Value::Object(record) => record
183                        .get("text")
184                        .and_then(Value::as_str)
185                        .unwrap_or("")
186                        .to_string(),
187                    _ => String::new(),
188                })
189                .collect::<Vec<_>>()
190                .join(""),
191        ),
192        _ => value.clone(),
193    }
194}
195
196fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
197    let trimmed = text.trim();
198    if let Ok(value) = serde_json::from_str(trimmed) {
199        return Ok(value);
200    }
201    if let Some(value) = extract_fenced_json(trimmed) {
202        return serde_json::from_str(value)
203            .context("Claude Code provider did not return valid JSON for schema output");
204    }
205    bail!("Claude Code provider did not return valid JSON for schema output")
206}
207
/// Returns the trimmed contents of the first Markdown code fence in `text`.
///
/// Multi-line fences: the info string on the opening fence line is skipped
/// (e.g. `json`, `JSON`, `jsonc`, or ` json` with a leading space).
///
/// Inline fences (no newline after the opening fence, e.g.
/// ```` ```json {"a":1}``` ````): only a literal `json` prefix is stripped.
///
/// Returns `None` when there is no opening fence or it is never closed.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    let start = text.find(FENCE)?;
    let after = &text[start + FENCE.len()..];
    let body = match after.split_once('\n') {
        // Everything between the opening fence and the newline is the info string.
        Some((info, rest)) if is_fence_info_string(info) => rest,
        _ => after.strip_prefix("json").unwrap_or(after),
    };
    let end = body.find(FENCE)?;
    Some(body[..end].trim())
}

/// True when `info` (the rest of an opening-fence line) looks like a language
/// tag rather than content.
///
/// A language tag here is empty, or a word that starts with an ASCII letter and
/// contains only ASCII alphanumerics and `-`, `_`, `+`, `.`.
fn is_fence_info_string(info: &str) -> bool {
    let info = info.trim();
    info.is_empty()
        || (info.starts_with(|c: char| c.is_ascii_alphabetic())
            && info
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.')))
}
215
216fn extract_session_id(raw: &Value) -> Option<String> {
217    match raw {
218        Value::Array(items) => items.iter().find_map(extract_session_id),
219        Value::Object(record) => {
220            if let Some(value) = record
221                .get("session_id")
222                .or_else(|| record.get("sessionId"))
223                .or_else(|| record.get("sessionID"))
224                .and_then(Value::as_str)
225            {
226                return Some(value.to_string());
227            }
228            record.values().find_map(extract_session_id)
229        }
230        _ => None,
231    }
232}
233
234fn extract_usage(raw: &Value) -> Option<AgentUsage> {
235    let mut usage_objects = Vec::new();
236    find_usage_objects(raw, &mut usage_objects);
237    let usage = usage_objects.last()?;
238    let mut normalized = normalize_usage(usage);
239    if normalized.cost.is_none() {
240        if let Some(total) = find_total_cost_usd(raw) {
241            normalized.cost = Some(AgentUsageCost {
242                total: Some(total),
243                currency: Some("USD".into()),
244                ..AgentUsageCost::default()
245            });
246        }
247    }
248    Some(normalized)
249}
250
251fn find_total_cost_usd(value: &Value) -> Option<f64> {
252    match value {
253        Value::Array(items) => items.iter().find_map(find_total_cost_usd),
254        Value::Object(record) => {
255            number_field_f64(record, &["total_cost_usd", "costUSD", "cost_usd"])
256                .or_else(|| record.values().find_map(find_total_cost_usd))
257        }
258        _ => None,
259    }
260}