// smol_workflow_engine/agent_providers/codex.rs

1use super::common::*;
2use super::types::*;
3use crate::environment::EnvironmentPath;
4use anyhow::{bail, Context};
5use serde_json::{json, Map, Value};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
/// Static configuration for launching the Codex CLI.
///
/// These values are combined with the per-run input by `run_codex` when the
/// command line is assembled.
#[derive(Clone, Debug)]
pub struct CodexAgentProviderOptions {
    /// Executable to launch; `run_codex` uses `"codex"` when this is `None`.
    pub command: Option<String>,
    /// Leading arguments placed before everything else on the command line.
    /// An empty list is treated as `["exec"]` by `run_codex`.
    pub subcommand: Vec<String>,
    /// Extra arguments appended directly after `subcommand`.
    pub args: Vec<String>,
    /// Working directory used when the run input does not supply one.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to the command runner.
    pub env: HashMap<String, String>,
    /// Timeout in milliseconds handed to the command runner; `None` passes no timeout.
    pub timeout_ms: Option<u64>,
}
18
19impl Default for CodexAgentProviderOptions {
20    fn default() -> Self {
21        Self {
22            command: None,
23            subcommand: vec!["exec".into()],
24            args: Vec::new(),
25            cwd: None,
26            env: HashMap::new(),
27            timeout_ms: None,
28        }
29    }
30}
31
32#[derive(Debug, Clone, Default)]
33pub struct CodexAgentProvider {
34    options: CodexAgentProviderOptions,
35}
36
37impl CodexAgentProvider {
38    pub fn new(options: CodexAgentProviderOptions) -> Self {
39        Self { options }
40    }
41}
42
43#[async_trait::async_trait]
44impl AgentProvider for CodexAgentProvider {
45    fn name(&self) -> &str {
46        "codex"
47    }
48
49    fn schema_mode(&self) -> AgentProviderSchemaMode {
50        AgentProviderSchemaMode::Builtin
51    }
52
53    fn usage_mode(&self) -> AgentProviderUsageMode {
54        AgentProviderUsageMode::Builtin
55    }
56
57    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
58        run_codex(input, &self.options).await
59    }
60}
61
62async fn run_codex(
63    input: AgentProviderRunInput,
64    options: &CodexAgentProviderOptions,
65) -> anyhow::Result<AgentProviderResult> {
66    let temp = input.environment.create_temp_dir("smol-wf-codex-").await?;
67    let output_path = join_environment_path(&temp, "last-message.txt");
68    let schema_path = join_environment_path(&temp, "schema.json");
69    let command = options.command.as_deref().unwrap_or("codex");
70    let mut args = Vec::new();
71    args.extend(options.subcommand.clone());
72    if args.is_empty() {
73        args.push("exec".into());
74    }
75    args.extend(options.args.clone());
76    if cfg!(feature = "integration-test") && !args.iter().any(|arg| arg == "--skip-git-repo-check")
77    {
78        args.push("--skip-git-repo-check".into());
79    }
80    if let Some(model) = option_str(&input.options, "model") {
81        args.extend(["--model".into(), model]);
82    }
83    args.extend([
84        "--json".into(),
85        "--output-last-message".into(),
86        output_path.0.clone(),
87    ]);
88    let has_schema = option_schema(&input.options).is_some();
89    if let Some(schema) = option_schema(&input.options) {
90        let schema = to_codex_output_schema(schema);
91        input
92            .environment
93            .write_file(
94                &schema_path,
95                serde_json::to_string_pretty(&schema)?.as_bytes(),
96            )
97            .await?;
98        args.extend(["--output-schema".into(), schema_path.0.clone()]);
99    }
100    args.push("-".into());
101
102    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
103    let (stdout, stderr) = run_command(RunCommandRequest {
104        provider: "Codex",
105        command,
106        args: &args,
107        stdin: Some(&input.prompt),
108        cwd,
109        env: &options.env,
110        timeout_ms: options.timeout_ms,
111        environment: input.environment.as_ref(),
112    })
113    .await?;
114    let events = parse_json_lines(&stdout);
115    let session_id = extract_session_id(&events)
116        .context("Codex provider response did not include a session id")?;
117    let final_message =
118        read_final_message(input.environment.as_ref(), &output_path, &events).await?;
119    let output = if has_schema {
120        parse_structured_output(&final_message)?
121    } else {
122        Value::String(final_message.trim_end().to_string())
123    };
124
125    Ok(AgentProviderResult {
126        output,
127        session_id: Some(session_id),
128        model: extract_model(&Value::Array(events.clone()))
129            .or_else(|| option_model(&input.options)),
130        usage: extract_usage(&events),
131        isolation: None,
132        raw: Some(to_json_value(json!({ "events": events, "stderr": stderr }))),
133    })
134}
135
136fn join_environment_path(base: &EnvironmentPath, child: &str) -> EnvironmentPath {
137    EnvironmentPath(format!("{}/{}", base.as_str().trim_end_matches('/'), child))
138}
139
140async fn read_final_message(
141    environment: &dyn crate::environment::AgentExecutionEnvironment,
142    path: &EnvironmentPath,
143    events: &[Value],
144) -> anyhow::Result<String> {
145    match environment.read_file(path).await {
146        Ok(bytes) => {
147            let message = String::from_utf8_lossy(&bytes).into_owned();
148            if !message.trim().is_empty() {
149                return Ok(message);
150            }
151        }
152        Err(error) => {
153            let not_found = error
154                .chain()
155                .find_map(|cause| cause.downcast_ref::<std::io::Error>())
156                .is_some_and(|error| error.kind() == std::io::ErrorKind::NotFound);
157            if !not_found {
158                bail!("Failed to read codex output file: {error}");
159            }
160        }
161    }
162    if let Some(text) = extract_last_assistant_text(events) {
163        Ok(text)
164    } else {
165        bail!("Codex provider did not return a final assistant message")
166    }
167}
168
169fn to_codex_output_schema(schema: &Value) -> Value {
170    match schema {
171        Value::Array(items) => Value::Array(items.iter().map(to_codex_output_schema).collect()),
172        Value::Object(record) => {
173            let mut output = Map::new();
174            for (key, value) in record {
175                output.insert(key.clone(), to_codex_output_schema(value));
176            }
177            if is_object_schema(&output) {
178                let properties = output
179                    .get("properties")
180                    .and_then(Value::as_object)
181                    .cloned()
182                    .unwrap_or_default();
183                output.insert(
184                    "properties".into(),
185                    to_codex_output_schema(&Value::Object(properties)),
186                );
187                output.insert(
188                    "required".into(),
189                    record
190                        .get("required")
191                        .filter(|v| v.is_array())
192                        .cloned()
193                        .unwrap_or_else(|| json!([])),
194                );
195                output.insert("additionalProperties".into(), Value::Bool(false));
196            }
197            Value::Object(output)
198        }
199        _ => schema.clone(),
200    }
201}
202
203fn is_object_schema(schema: &Map<String, Value>) -> bool {
204    schema.get("type") == Some(&Value::String("object".into())) || schema.contains_key("properties")
205}
206
207fn parse_structured_output(text: &str) -> anyhow::Result<Value> {
208    parse_structured_output_seen(text.trim(), &mut Vec::new())
209}
210
211fn parse_structured_output_seen(text: &str, seen: &mut Vec<String>) -> anyhow::Result<Value> {
212    let trimmed = text.trim();
213    if seen.iter().any(|item| item == trimmed) {
214        bail!("Codex provider did not return valid JSON for schema output");
215    }
216    seen.push(trimmed.to_string());
217
218    if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
219        if let Value::String(inner) = parsed {
220            return parse_structured_output_seen(&inner, seen);
221        }
222        return Ok(parsed);
223    }
224
225    if let Some(fenced) = extract_fenced_json(trimmed) {
226        return parse_structured_output_seen(fenced, seen);
227    }
228    if let Some(unescaped) = try_unescape_json_like_text(trimmed) {
229        return parse_structured_output_seen(&unescaped, seen);
230    }
231    if let Some(object_text) = extract_likely_json_text(trimmed) {
232        return parse_structured_output_seen(object_text, seen);
233    }
234    bail!("Codex provider did not return valid JSON for schema output")
235}
236
/// Returns the trimmed contents of the first triple-backtick fenced block in
/// `text`.
///
/// A `json` tag directly after the opening fence is skipped. Returns `None`
/// when there is no opening fence or no closing fence after it.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";

    let (_, after_open) = text.split_once(FENCE)?;
    let body = match after_open.strip_prefix("json") {
        Some(rest) => rest,
        None => after_open,
    }
    .trim_start();
    let (inner, _) = body.split_once(FENCE)?;
    Some(inner.trim())
}
244
245fn try_unescape_json_like_text(text: &str) -> Option<String> {
246    if !text.contains("\\n") && !text.contains("\\\"") {
247        return None;
248    }
249    serde_json::from_str::<String>(&format!("\"{text}\""))
250        .ok()
251        .or_else(|| {
252            Some(
253                text.replace("\\n", "\n")
254                    .replace("\\t", "\t")
255                    .replace("\\\"", "\""),
256            )
257        })
258}
259
/// Returns the slice of `text` most likely to be a JSON document embedded in
/// surrounding prose.
///
/// Two candidate spans are considered: first `{` to last `}`, and first `[`
/// to last `]`. The object span is preferred, except when the array span
/// strictly encloses it. In that case the object is an element of the array,
/// and cutting at the braces would yield either invalid JSON
/// (`{"a":1},{"b":2}`) or only part of the value, so the array span wins.
///
/// Returns `None` when neither span exists. The result is not validated;
/// bracketed prose that wraps an object is returned whole.
fn extract_likely_json_text(text: &str) -> Option<&str> {
    /// Byte offsets of the first `open` and the last `close`, when the latter
    /// comes after the former.
    fn span(text: &str, open: char, close: char) -> Option<(usize, usize)> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (end > start).then_some((start, end))
    }

    let object = span(text, '{', '}');
    let array = span(text, '[', ']');
    let chosen = match (object, array) {
        (Some((obj_start, obj_end)), Some((arr_start, arr_end)))
            if arr_start < obj_start && arr_end > obj_end =>
        {
            array
        }
        _ => object.or(array),
    };
    // Both delimiters are single-byte ASCII, so the inclusive slice is on
    // character boundaries.
    chosen.map(|(start, end)| &text[start..=end])
}
265
266fn extract_last_assistant_text(events: &[Value]) -> Option<String> {
267    let mut text = None;
268    for event in events {
269        if let Some(candidate) = extract_assistant_text(event) {
270            text = Some(candidate);
271        }
272    }
273    text
274}
275
276fn extract_assistant_text(value: &Value) -> Option<String> {
277    match value {
278        Value::Array(items) => items.iter().rev().find_map(extract_assistant_text),
279        Value::Object(record) => {
280            let text = extract_text(
281                record
282                    .get("text")
283                    .or_else(|| record.get("output"))
284                    .or_else(|| record.get("message"))
285                    .or_else(|| record.get("content"))?,
286            );
287            if (matches!(
288                record.get("role").and_then(Value::as_str),
289                Some("assistant")
290            ) || matches!(
291                record.get("type").and_then(Value::as_str),
292                Some("assistant_message" | "message")
293            )) && text.is_some()
294            {
295                return text;
296            }
297            for key in [
298                "message",
299                "content",
300                "output",
301                "text",
302                "delta",
303                "part",
304                "parts",
305                "item",
306                "event",
307                "data",
308                "properties",
309            ] {
310                if let Some(candidate) = record.get(key).and_then(extract_assistant_text) {
311                    return Some(candidate);
312                }
313            }
314            None
315        }
316        _ => None,
317    }
318}
319
320fn extract_text(value: &Value) -> Option<String> {
321    match value {
322        Value::String(text) => Some(text.clone()),
323        Value::Array(items) => {
324            let text = items
325                .iter()
326                .map(|item| extract_text(item).unwrap_or_default())
327                .collect::<Vec<_>>()
328                .join("");
329            (!text.is_empty()).then_some(text)
330        }
331        Value::Object(record) => record
332            .get("text")
333            .or_else(|| record.get("content"))
334            .or_else(|| record.get("message"))
335            .or_else(|| record.get("output"))
336            .and_then(extract_text),
337        _ => None,
338    }
339}
340
341fn extract_session_id(events: &[Value]) -> Option<String> {
342    for event in events {
343        if event.get("type").and_then(Value::as_str) == Some("session_meta") {
344            if let Some(id) = get_path(event, &["payload", "id"]).and_then(Value::as_str) {
345                return Some(id.to_string());
346            }
347        }
348        if event.get("type").and_then(Value::as_str) == Some("thread.started") {
349            if let Some(id) = event.get("thread_id").and_then(Value::as_str) {
350                return Some(id.to_string());
351            }
352        }
353        if let Some(id) = event
354            .get("session_id")
355            .or_else(|| event.get("sessionId"))
356            .or_else(|| event.get("sessionID"))
357            .and_then(Value::as_str)
358        {
359            return Some(id.to_string());
360        }
361    }
362    None
363}
364
365fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
366    let mut usage = None;
367    for event in events {
368        if let Some(candidate) = find_first_usage_object(event) {
369            usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
370        }
371    }
372    usage
373}