//! Claude Code agent provider (`smol_workflow_engine/agent_providers/claude_code.rs`).

1use super::common::*;
2use super::types::*;
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
/// Configuration for the Claude Code agent provider.
///
/// Every field is optional; `Default` yields an empty configuration.
#[derive(Clone, Debug, Default)]
pub struct ClaudeCodeAgentProviderOptions {
    /// Executable to launch; `claude` is used when this is `None`.
    pub command: Option<String>,
    /// Leading arguments, placed before `args` on the command line.
    pub subcommand: Vec<String>,
    /// Extra arguments, placed after `subcommand` and before the provider-generated flags.
    pub args: Vec<String>,
    /// Working directory used when the run input's context does not specify one.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to the command runner.
    pub env: HashMap<String, String>,
    /// Timeout forwarded to the command runner (milliseconds, judging by the name — confirm
    /// against `run_command`).
    pub timeout_ms: Option<u64>,
}
17
18#[derive(Debug, Clone, Default)]
19pub struct ClaudeCodeAgentProvider {
20    options: ClaudeCodeAgentProviderOptions,
21}
22
23impl ClaudeCodeAgentProvider {
24    pub fn new(options: ClaudeCodeAgentProviderOptions) -> Self {
25        Self { options }
26    }
27}
28
29#[async_trait::async_trait]
30impl AgentProvider for ClaudeCodeAgentProvider {
31    fn name(&self) -> &str {
32        "claude-code"
33    }
34
35    fn schema_mode(&self) -> AgentProviderSchemaMode {
36        AgentProviderSchemaMode::Builtin
37    }
38
39    fn usage_mode(&self) -> AgentProviderUsageMode {
40        AgentProviderUsageMode::Builtin
41    }
42
43    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
44        run_claude_code(input, &self.options).await
45    }
46}
47
48async fn run_claude_code(
49    input: AgentProviderRunInput,
50    options: &ClaudeCodeAgentProviderOptions,
51) -> anyhow::Result<AgentProviderResult> {
52    let command = options.command.as_deref().unwrap_or("claude");
53    let mut args = Vec::<String>::new();
54    args.extend(options.subcommand.clone());
55    args.extend(options.args.clone());
56    if let Some(model) = option_str(&input.options, "model") {
57        args.extend(["--model".into(), model]);
58    }
59    if let Some(thinking) = option_str(&input.options, "thinking") {
60        args.extend(["--effort".into(), thinking]);
61    }
62    if let Some(agent_type) = option_str(&input.options, "agentType") {
63        args.extend(["--agent".into(), agent_type]);
64    }
65    args.extend([
66        "--output-format".into(),
67        "stream-json".into(),
68        "--verbose".into(),
69        "--input-format".into(),
70        "text".into(),
71    ]);
72    if let Some(schema) = option_schema(&input.options) {
73        args.extend(["--json-schema".into(), serde_json::to_string(schema)?]);
74    }
75    args.push("--print".into());
76
77    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
78    let (stdout, stderr) = run_command(
79        "Claude Code",
80        command,
81        &args,
82        Some(&input.prompt),
83        cwd,
84        &options.env,
85        options.timeout_ms,
86    )
87    .await?;
88    let events = parse_json_lines(&stdout);
89    let raw = if events.is_empty() {
90        parse_json_or_text(&stdout)
91    } else {
92        Value::Array(events.clone())
93    };
94    let structured = option_schema(&input.options).is_some();
95    let output = extract_output(&raw, structured)?;
96    let session_id = extract_session_id(&raw)
97        .context("Claude Code provider response did not include a session id")?;
98    let event_payloads = if events.is_empty() {
99        vec![raw.clone()]
100    } else {
101        events
102    };
103
104    Ok(AgentProviderResult {
105        output,
106        session_id: Some(session_id),
107        model: extract_model(&raw).or_else(|| option_model(&input.options)),
108        usage: extract_usage(&raw),
109        isolation: None,
110        raw: Some(to_json_value(
111            json!({ "events": event_payloads, "response": raw, "stderr": stderr }),
112        )),
113    })
114}
115
116fn extract_output(raw: &Value, structured: bool) -> anyhow::Result<Value> {
117    if structured {
118        if let Some(output) = extract_structured_output(raw) {
119            return Ok(output);
120        }
121    }
122
123    let candidate = extract_output_candidate(raw);
124    if !structured {
125        return Ok(match candidate {
126            Value::String(text) => Value::String(text.trim_end().to_string()),
127            value => value,
128        });
129    }
130
131    match candidate {
132        Value::String(text) => parse_structured_output(&text),
133        value => Ok(value),
134    }
135}
136
137fn extract_structured_output(raw: &Value) -> Option<Value> {
138    match raw {
139        Value::Array(items) => items.iter().find_map(extract_structured_output),
140        Value::Object(record) => record
141            .get("structured_output")
142            .or_else(|| record.get("structuredOutput"))
143            .cloned(),
144        _ => None,
145    }
146}
147
148fn extract_output_candidate(raw: &Value) -> Value {
149    match raw {
150        Value::String(_) => raw.clone(),
151        Value::Array(items) => items
152            .iter()
153            .rev()
154            .map(extract_output_candidate)
155            .find(|value| !value.is_null())
156            .unwrap_or_else(|| raw.clone()),
157        Value::Object(record) => {
158            for key in ["result", "output", "text", "content"] {
159                if let Some(value) = record.get(key) {
160                    return extract_content_text(value);
161                }
162            }
163            if let Some(message) = record.get("message") {
164                if message.is_object() {
165                    return extract_output_candidate(message);
166                }
167            }
168            raw.clone()
169        }
170        _ => raw.clone(),
171    }
172}
173
174fn extract_content_text(value: &Value) -> Value {
175    match value {
176        Value::Array(items) => Value::String(
177            items
178                .iter()
179                .map(|item| match item {
180                    Value::String(text) => text.clone(),
181                    Value::Object(record) => record
182                        .get("text")
183                        .and_then(Value::as_str)
184                        .unwrap_or("")
185                        .to_string(),
186                    _ => String::new(),
187                })
188                .collect::<Vec<_>>()
189                .join(""),
190        ),
191        _ => value.clone(),
192    }
193}
194
195fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
196    let trimmed = text.trim();
197    if let Ok(value) = serde_json::from_str(trimmed) {
198        return Ok(value);
199    }
200    if let Some(value) = extract_fenced_json(trimmed) {
201        return serde_json::from_str(value)
202            .context("Claude Code provider did not return valid JSON for schema output");
203    }
204    bail!("Claude Code provider did not return valid JSON for schema output")
205}
206
/// Extracts the contents of the first Markdown code fence in `text`.
///
/// An opening fence may be followed by an info string on the same line:
/// - If that info string is a plain language tag, the whole tag is skipped. A plain tag
///   starts with an ASCII letter and contains only ASCII alphanumerics or `-`, `_`, `+`, `.`
///   (for example `json`, `JSON`, `jsonc`).
/// - Otherwise a `json` prefix (any ASCII case) directly after the fence is stripped. This
///   keeps single-line fences such as three backticks + `json{"a":1}` + three backticks
///   working.
///
/// Returns the trimmed text up to the closing fence, or `None` when either fence is missing.
fn extract_fenced_json(text: &str) -> Option<&str> {
    let start = text.find("```")?;
    let after = &text[start + 3..];

    // The info string runs from the opening fence to the end of its line.
    let info_end = after.find('\n').unwrap_or(after.len());
    let info = after[..info_end].trim();
    let is_language_tag = info.starts_with(|c: char| c.is_ascii_alphabetic())
        && info
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'));

    let body = if is_language_tag {
        &after[info_end..]
    } else if matches!(after.get(..4), Some(prefix) if prefix.eq_ignore_ascii_case("json")) {
        // `get` rather than indexing, so a multi-byte character at byte 4 cannot panic.
        &after[4..]
    } else {
        after
    };
    let body = body.trim_start();
    let end = body.find("```")?;
    Some(body[..end].trim())
}
214
215fn extract_session_id(raw: &Value) -> Option<String> {
216    match raw {
217        Value::Array(items) => items.iter().find_map(extract_session_id),
218        Value::Object(record) => {
219            if let Some(value) = record
220                .get("session_id")
221                .or_else(|| record.get("sessionId"))
222                .or_else(|| record.get("sessionID"))
223                .and_then(Value::as_str)
224            {
225                return Some(value.to_string());
226            }
227            record.values().find_map(extract_session_id)
228        }
229        _ => None,
230    }
231}
232
233fn extract_usage(raw: &Value) -> Option<AgentUsage> {
234    let mut usage_objects = Vec::new();
235    find_usage_objects(raw, &mut usage_objects);
236    let usage = usage_objects.last()?;
237    let mut normalized = normalize_usage(usage);
238    if normalized.cost.is_none() {
239        if let Some(total) = find_total_cost_usd(raw) {
240            normalized.cost = Some(AgentUsageCost {
241                total: Some(total),
242                currency: Some("USD".into()),
243                ..AgentUsageCost::default()
244            });
245        }
246    }
247    Some(normalized)
248}
249
250fn find_total_cost_usd(value: &Value) -> Option<f64> {
251    match value {
252        Value::Array(items) => items.iter().find_map(find_total_cost_usd),
253        Value::Object(record) => {
254            number_field_f64(record, &["total_cost_usd", "costUSD", "cost_usd"])
255                .or_else(|| record.values().find_map(find_total_cost_usd))
256        }
257        _ => None,
258    }
259}