Skip to main content

smol_workflow_engine/agent_providers/
codex.rs

1use super::common::*;
2use super::types::*;
3use anyhow::{bail, Context};
4use serde_json::{json, Map, Value};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
/// Launch settings for the Codex CLI used by [`CodexAgentProvider`].
#[derive(Debug, Clone)]
pub struct CodexAgentProviderOptions {
    /// Executable to launch; `run_codex` uses `codex` when this is `None`.
    pub command: Option<String>,
    /// Leading arguments placed before everything else; `run_codex`
    /// substitutes `exec` when this list is empty.
    pub subcommand: Vec<String>,
    /// Extra arguments appended right after the subcommand.
    pub args: Vec<String>,
    /// Working directory used when the run input does not supply its own.
    pub cwd: Option<PathBuf>,
    /// Environment map handed to `run_command` together with the arguments.
    pub env: HashMap<String, String>,
    /// Timeout value forwarded to `run_command` (presumably milliseconds,
    /// going by the field name).
    pub timeout_ms: Option<u64>,
}

impl Default for CodexAgentProviderOptions {
    /// Builds options that run `exec` with no overrides of any kind.
    fn default() -> Self {
        let subcommand = vec![String::from("exec")];
        Self {
            command: None,
            subcommand,
            args: Vec::new(),
            cwd: None,
            env: HashMap::new(),
            timeout_ms: None,
        }
    }
}
31
32#[derive(Debug, Clone, Default)]
33pub struct CodexAgentProvider {
34    options: CodexAgentProviderOptions,
35}
36
37impl CodexAgentProvider {
38    pub fn new(options: CodexAgentProviderOptions) -> Self {
39        Self { options }
40    }
41}
42
43#[async_trait::async_trait]
44impl AgentProvider for CodexAgentProvider {
45    fn name(&self) -> &str {
46        "codex"
47    }
48
49    fn schema_mode(&self) -> AgentProviderSchemaMode {
50        AgentProviderSchemaMode::Builtin
51    }
52
53    fn usage_mode(&self) -> AgentProviderUsageMode {
54        AgentProviderUsageMode::Builtin
55    }
56
57    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
58        run_codex(input, &self.options).await
59    }
60}
61
62async fn run_codex(
63    input: AgentProviderRunInput,
64    options: &CodexAgentProviderOptions,
65) -> anyhow::Result<AgentProviderResult> {
66    let temp = temp_dir("smol-wf-codex-")?;
67    let output_path = temp.path().join("last-message.txt");
68    let schema_path = temp.path().join("schema.json");
69    let command = options.command.as_deref().unwrap_or("codex");
70    let mut args = Vec::new();
71    args.extend(options.subcommand.clone());
72    if args.is_empty() {
73        args.push("exec".into());
74    }
75    args.extend(options.args.clone());
76    if cfg!(feature = "integration-test") && !args.iter().any(|arg| arg == "--skip-git-repo-check")
77    {
78        args.push("--skip-git-repo-check".into());
79    }
80    if let Some(model) = option_str(&input.options, "model") {
81        args.extend(["--model".into(), model]);
82    }
83    args.extend([
84        "--json".into(),
85        "--output-last-message".into(),
86        output_path.to_string_lossy().into_owned(),
87    ]);
88    let has_schema = option_schema(&input.options).is_some();
89    if let Some(schema) = option_schema(&input.options) {
90        let schema = to_codex_output_schema(schema);
91        fs::write(&schema_path, serde_json::to_string_pretty(&schema)?)?;
92        args.extend([
93            "--output-schema".into(),
94            schema_path.to_string_lossy().into_owned(),
95        ]);
96    }
97    args.push("-".into());
98
99    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
100    let (stdout, stderr) = run_command(
101        "Codex",
102        command,
103        &args,
104        Some(&input.prompt),
105        cwd,
106        &options.env,
107        options.timeout_ms,
108    )
109    .await?;
110    let events = parse_json_lines(&stdout);
111    let session_id = extract_session_id(&events)
112        .context("Codex provider response did not include a session id")?;
113    let final_message = read_final_message(&output_path, &events)?;
114    let output = if has_schema {
115        parse_structured_output(&final_message)?
116    } else {
117        Value::String(final_message.trim_end().to_string())
118    };
119
120    Ok(AgentProviderResult {
121        output,
122        session_id: Some(session_id),
123        model: extract_model(&Value::Array(events.clone()))
124            .or_else(|| option_model(&input.options)),
125        usage: extract_usage(&events),
126        isolation: None,
127        raw: Some(to_json_value(json!({ "events": events, "stderr": stderr }))),
128    })
129}
130
131fn read_final_message(path: &PathBuf, events: &[Value]) -> anyhow::Result<String> {
132    match fs::read_to_string(path) {
133        Ok(message) if !message.trim().is_empty() => return Ok(message),
134        Ok(_) => {}
135        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
136        Err(error) => bail!("Failed to read codex output file: {error}"),
137    }
138    if let Some(text) = extract_last_assistant_text(events) {
139        Ok(text)
140    } else {
141        bail!("Codex provider did not return a final assistant message")
142    }
143}
144
145fn to_codex_output_schema(schema: &Value) -> Value {
146    match schema {
147        Value::Array(items) => Value::Array(items.iter().map(to_codex_output_schema).collect()),
148        Value::Object(record) => {
149            let mut output = Map::new();
150            for (key, value) in record {
151                output.insert(key.clone(), to_codex_output_schema(value));
152            }
153            if is_object_schema(&output) {
154                let properties = output
155                    .get("properties")
156                    .and_then(Value::as_object)
157                    .cloned()
158                    .unwrap_or_default();
159                output.insert(
160                    "properties".into(),
161                    to_codex_output_schema(&Value::Object(properties)),
162                );
163                output.insert(
164                    "required".into(),
165                    record
166                        .get("required")
167                        .filter(|v| v.is_array())
168                        .cloned()
169                        .unwrap_or_else(|| json!([])),
170                );
171                output.insert("additionalProperties".into(), Value::Bool(false));
172            }
173            Value::Object(output)
174        }
175        _ => schema.clone(),
176    }
177}
178
179fn is_object_schema(schema: &Map<String, Value>) -> bool {
180    schema.get("type") == Some(&Value::String("object".into())) || schema.contains_key("properties")
181}
182
183fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
184    parse_structured_output_seen(text.trim(), &mut Vec::new())
185}
186
187fn parse_structured_output_seen(text: &str, seen: &mut Vec<String>) -> anyhow::Result<Value> {
188    let trimmed = text.trim();
189    if seen.iter().any(|item| item == trimmed) {
190        bail!("Codex provider did not return valid JSON for schema output");
191    }
192    seen.push(trimmed.to_string());
193
194    if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
195        if let Value::String(inner) = parsed {
196            return parse_structured_output_seen(&inner, seen);
197        }
198        return Ok(parsed);
199    }
200
201    if let Some(fenced) = extract_fenced_json(trimmed) {
202        return parse_structured_output_seen(fenced, seen);
203    }
204    if let Some(unescaped) = try_unescape_json_like_text(trimmed) {
205        return parse_structured_output_seen(&unescaped, seen);
206    }
207    if let Some(object_text) = extract_likely_json_text(trimmed) {
208        return parse_structured_output_seen(object_text, seen);
209    }
210    bail!("Codex provider did not return valid JSON for schema output")
211}
212
/// Returns the trimmed contents of the first triple-backtick fenced block in
/// `text`, dropping a leading `json` tag if one directly follows the opening
/// fence. Yields `None` when there is no opening or no closing fence.
fn extract_fenced_json(text: &str) -> Option<&str> {
    let (_, rest) = text.split_once("```")?;
    let body = match rest.strip_prefix("json") {
        Some(untagged) => untagged,
        None => rest,
    }
    .trim_start();
    let (inner, _) = body.split_once("```")?;
    Some(inner.trim())
}
220
221fn try_unescape_json_like_text(text: &str) -> Option<String> {
222    if !text.contains("\\n") && !text.contains("\\\"") {
223        return None;
224    }
225    serde_json::from_str::<String>(&format!("\"{text}\""))
226        .ok()
227        .or_else(|| {
228            Some(
229                text.replace("\\n", "\n")
230                    .replace("\\t", "\t")
231                    .replace("\\\"", "\""),
232            )
233        })
234}
235
/// Returns the span from the first `{` to the last `}` in `text`, or, when no
/// such span exists, from the first `[` to the last `]`. A span only counts
/// when its closing character comes after its opening character.
fn extract_likely_json_text(text: &str) -> Option<&str> {
    let span = |open: char, close: char| -> Option<(usize, usize)> {
        let first = text.find(open)?;
        let last = text.rfind(close)?;
        (last > first).then_some((first, last))
    };
    let (first, last) = span('{', '}').or_else(|| span('[', ']'))?;
    Some(&text[first..=last])
}
241
242fn extract_last_assistant_text(events: &[Value]) -> Option<String> {
243    let mut text = None;
244    for event in events {
245        if let Some(candidate) = extract_assistant_text(event) {
246            text = Some(candidate);
247        }
248    }
249    text
250}
251
252fn extract_assistant_text(value: &Value) -> Option<String> {
253    match value {
254        Value::Array(items) => items.iter().rev().find_map(extract_assistant_text),
255        Value::Object(record) => {
256            let text = extract_text(
257                record
258                    .get("text")
259                    .or_else(|| record.get("output"))
260                    .or_else(|| record.get("message"))
261                    .or_else(|| record.get("content"))?,
262            );
263            if (matches!(
264                record.get("role").and_then(Value::as_str),
265                Some("assistant")
266            ) || matches!(
267                record.get("type").and_then(Value::as_str),
268                Some("assistant_message" | "message")
269            )) && text.is_some()
270            {
271                return text;
272            }
273            for key in [
274                "message",
275                "content",
276                "output",
277                "text",
278                "delta",
279                "part",
280                "parts",
281                "item",
282                "event",
283                "data",
284                "properties",
285            ] {
286                if let Some(candidate) = record.get(key).and_then(extract_assistant_text) {
287                    return Some(candidate);
288                }
289            }
290            None
291        }
292        _ => None,
293    }
294}
295
296fn extract_text(value: &Value) -> Option<String> {
297    match value {
298        Value::String(text) => Some(text.clone()),
299        Value::Array(items) => {
300            let text = items
301                .iter()
302                .map(|item| extract_text(item).unwrap_or_default())
303                .collect::<Vec<_>>()
304                .join("");
305            (!text.is_empty()).then_some(text)
306        }
307        Value::Object(record) => record
308            .get("text")
309            .or_else(|| record.get("content"))
310            .or_else(|| record.get("message"))
311            .or_else(|| record.get("output"))
312            .and_then(extract_text),
313        _ => None,
314    }
315}
316
317fn extract_session_id(events: &[Value]) -> Option<String> {
318    for event in events {
319        if event.get("type").and_then(Value::as_str) == Some("session_meta") {
320            if let Some(id) = get_path(event, &["payload", "id"]).and_then(Value::as_str) {
321                return Some(id.to_string());
322            }
323        }
324        if event.get("type").and_then(Value::as_str) == Some("thread.started") {
325            if let Some(id) = event.get("thread_id").and_then(Value::as_str) {
326                return Some(id.to_string());
327            }
328        }
329        if let Some(id) = event
330            .get("session_id")
331            .or_else(|| event.get("sessionId"))
332            .or_else(|| event.get("sessionID"))
333            .and_then(Value::as_str)
334        {
335            return Some(id.to_string());
336        }
337    }
338    None
339}
340
341fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
342    let mut usage = None;
343    for event in events {
344        if let Some(candidate) = find_first_usage_object(event) {
345            usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
346        }
347    }
348    usage
349}