1#![deny(missing_docs)]
2mod auth;
20mod convert;
21mod types;
22
23use convert::{messages_to_input, tools_to_codex};
24use futures_util::StreamExt;
25use layer0::content::{Content, ContentBlock};
26use skg_turn::infer::{InferRequest, InferResponse, ToolCall};
27use skg_turn::provider::{Provider, ProviderError};
28use skg_turn::stream::{StreamEvent, StreamProvider, StreamRequest};
29use skg_turn::types::*;
30use rust_decimal::Decimal;
31use tracing::Instrument;
32use types::*;
33
/// Base URL requests are sent to unless [`CodexProvider::with_base_url`]
/// overrides it.
const DEFAULT_BASE_URL: &str = "https://chatgpt.com/backend-api";

/// Path of the Codex responses endpoint, appended to the base URL.
const CODEX_RESPONSES_PATH: &str = "/codex/responses";
39
40#[derive(Clone)]
42pub struct CodexProvider {
43 access_token: String,
44 account_id: String,
45 client: reqwest::Client,
46 base_url: String,
47}
48
49impl CodexProvider {
50 pub fn new(access_token: impl Into<String>) -> Result<Self, ProviderError> {
55 let token = access_token.into();
56 let account_id = auth::extract_account_id(&token).ok_or_else(|| {
57 ProviderError::AuthFailed("failed to extract account ID from JWT".into())
58 })?;
59 Ok(Self {
60 access_token: token,
61 account_id,
62 client: reqwest::Client::new(),
63 base_url: DEFAULT_BASE_URL.into(),
64 })
65 }
66
67 pub fn with_account_id(access_token: impl Into<String>, account_id: impl Into<String>) -> Self {
72 Self {
73 access_token: access_token.into(),
74 account_id: account_id.into(),
75 client: reqwest::Client::new(),
76 base_url: DEFAULT_BASE_URL.into(),
77 }
78 }
79
80 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
84 self.base_url = url.into();
85 self
86 }
87
88 fn endpoint_url(&self) -> String {
90 let base = self.base_url.trim_end_matches('/');
91 format!("{base}{CODEX_RESPONSES_PATH}")
92 }
93
94 fn build_headers(&self) -> reqwest::header::HeaderMap {
96 let mut headers = reqwest::header::HeaderMap::new();
97 headers.insert(
98 "authorization",
99 format!("Bearer {}", self.access_token)
100 .parse()
101 .expect("valid header"),
102 );
103 headers.insert(
104 "chatgpt-account-id",
105 self.account_id.parse().expect("valid header"),
106 );
107 headers.insert(
108 "openai-beta",
109 "responses=experimental".parse().expect("valid header"),
110 );
111 headers.insert("originator", "pi".parse().expect("valid header"));
112 headers.insert(
113 "content-type",
114 "application/json".parse().expect("valid header"),
115 );
116 headers
117 }
118
119 fn build_codex_request(&self, request: &InferRequest) -> CodexRequest {
121 let model = request.model.clone().unwrap_or_else(|| "gpt-5".into());
122
123 let input = messages_to_input(&request.messages);
124 let tools = tools_to_codex(&request.tools);
125
126 CodexRequest {
127 model,
128 input,
129 stream: true,
130 instructions: request.system.clone(),
131 tools,
132 tool_choice: None,
133 temperature: request.temperature,
134 max_output_tokens: request.max_tokens,
135 reasoning: None,
136 prompt_cache_key: None,
137 store: Some(false),
138 }
139 }
140
141 fn build_codex_stream_request(&self, request: &StreamRequest) -> CodexRequest {
143 let infer = InferRequest {
144 model: request.model.clone(),
145 messages: request.messages.clone(),
146 tools: request.tools.clone(),
147 max_tokens: request.max_tokens,
148 temperature: request.temperature,
149 system: request.system.clone(),
150 extra: request.extra.clone(),
151 };
152 self.build_codex_request(&infer)
153 }
154
155 async fn stream_sse(
157 &self,
158 codex_request: CodexRequest,
159 on_event: &(dyn Fn(StreamEvent) + Send + Sync),
160 ) -> Result<InferResponse, ProviderError> {
161 let url = self.endpoint_url();
162 let headers = self.build_headers();
163
164 let http_response = self
165 .client
166 .post(&url)
167 .headers(headers)
168 .json(&codex_request)
169 .send()
170 .await
171 .map_err(|e| ProviderError::TransientError {
172 message: e.to_string(),
173 status: None,
174 })?;
175
176 let status = http_response.status();
177 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
178 return Err(ProviderError::RateLimited);
179 }
180 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
181 let body = http_response.text().await.unwrap_or_default();
182 return Err(ProviderError::AuthFailed(body));
183 }
184 if !status.is_success() {
185 let body = http_response.text().await.unwrap_or_default();
186 return Err(map_error_response(status, &body));
187 }
188
189 let mut stream = http_response.bytes_stream();
191 let mut buf = String::new();
192
193 let mut model_name = codex_request.model.clone();
195 let mut usage = ResponseUsage::default();
196 let mut stop_reason = StopReason::EndTurn;
197 let mut text_blocks: Vec<String> = Vec::new();
198 let mut tool_calls: Vec<ToolCall> = Vec::new();
199
200 let mut current_text = String::new();
202 let mut current_tool_call_id = String::new();
203 let mut current_tool_item_id = String::new();
204 let mut current_tool_name = String::new();
205 let mut current_tool_args = String::new();
206 let mut tool_call_index: usize = 0;
207
208 while let Some(chunk) = stream.next().await {
209 let bytes = chunk.map_err(|e| ProviderError::TransientError {
210 message: format!("stream read error: {e}"),
211 status: None,
212 })?;
213 buf.push_str(&String::from_utf8_lossy(&bytes));
214
215 while let Some(frame_end) = buf.find("\n\n") {
217 let frame = buf[..frame_end].to_string();
218 buf = buf[frame_end + 2..].to_string();
219
220 let mut data = String::new();
222 for line in frame.lines() {
223 if let Some(rest) = line.strip_prefix("data: ") {
224 if !data.is_empty() {
225 data.push('\n');
226 }
227 data.push_str(rest);
228 }
229 }
230
231 if data.is_empty() {
232 continue;
233 }
234
235 let event: SseEvent = match serde_json::from_str(&data) {
236 Ok(ev) => ev,
237 Err(e) => {
238 tracing::warn!(error = %e, "failed to parse Codex SSE event");
239 continue;
240 }
241 };
242
243 match event.event_type.as_str() {
244 "response.output_item.added" => {
245 if let Some(item) = event.data.get("item") {
246 let item_type = item.get("type").and_then(|v| v.as_str()).unwrap_or("");
247 match item_type {
248 "message" => {
249 current_text = String::new();
250 }
251 "function_call" => {
252 let call_id = item
253 .get("call_id")
254 .and_then(|v| v.as_str())
255 .unwrap_or("")
256 .to_string();
257 let item_id = item
258 .get("id")
259 .and_then(|v| v.as_str())
260 .unwrap_or("")
261 .to_string();
262 let name = item
263 .get("name")
264 .and_then(|v| v.as_str())
265 .unwrap_or("")
266 .to_string();
267
268 current_tool_call_id = call_id.clone();
269 current_tool_item_id = item_id;
270 current_tool_name = name.clone();
271 current_tool_args = String::new();
272
273 let skg_id = format!("{call_id}|{}", current_tool_item_id);
275 on_event(StreamEvent::ToolCallStart {
276 index: tool_call_index,
277 id: skg_id,
278 name,
279 });
280 }
281 _ => {}
282 }
283 }
284 }
285 "response.output_text.delta" => {
286 if let Some(delta) = event.data.get("delta").and_then(|v| v.as_str()) {
287 current_text.push_str(delta);
288 on_event(StreamEvent::TextDelta(delta.to_string()));
289 }
290 }
291 "response.function_call_arguments.delta" => {
292 if let Some(delta) = event.data.get("delta").and_then(|v| v.as_str()) {
293 current_tool_args.push_str(delta);
294 on_event(StreamEvent::ToolCallDelta {
295 index: tool_call_index,
296 json_delta: delta.to_string(),
297 });
298 }
299 }
300 "response.output_item.done" => {
301 if let Some(item) = event.data.get("item") {
302 let item_type = item.get("type").and_then(|v| v.as_str()).unwrap_or("");
303
304 match item_type {
305 "message" => {
306 let final_text = extract_output_text(item);
308 if !final_text.is_empty() {
309 current_text = final_text;
310 }
311 if !current_text.is_empty() {
312 text_blocks.push(current_text.clone());
313 }
314 current_text = String::new();
315 }
316 "function_call" => {
317 let final_args = item
319 .get("arguments")
320 .and_then(|v| v.as_str())
321 .unwrap_or(¤t_tool_args);
322 let input: serde_json::Value = serde_json::from_str(final_args)
323 .unwrap_or(serde_json::Value::Object(
324 serde_json::Map::new(),
325 ));
326 tool_calls.push(ToolCall {
327 id: format!(
328 "{}|{}",
329 current_tool_call_id, current_tool_item_id
330 ),
331 name: current_tool_name.clone(),
332 input,
333 });
334 tool_call_index += 1;
335 current_tool_args = String::new();
336 }
337 _ => {}
338 }
339 }
340 }
341 "response.completed" | "response.done" => {
342 if let Some(response) = event.data.get("response") {
343 if let Some(u) = response.get("usage") {
344 usage = ResponseUsage::from_value(u);
345 on_event(StreamEvent::Usage(TokenUsage {
346 input_tokens: usage.input_tokens,
347 output_tokens: usage.output_tokens,
348 cache_read_tokens: if usage.cached_tokens > 0 {
349 Some(usage.cached_tokens)
350 } else {
351 None
352 },
353 cache_creation_tokens: None,
354 }));
355 }
356 if let Some(status) = response.get("status").and_then(|v| v.as_str()) {
357 stop_reason = match status {
358 "completed" => StopReason::EndTurn,
359 "incomplete" => StopReason::MaxTokens,
360 "failed" | "cancelled" => StopReason::EndTurn,
361 _ => StopReason::EndTurn,
362 };
363 }
364 if let Some(m) = response.get("model").and_then(|v| v.as_str()) {
365 model_name = m.to_string();
366 }
367 }
368 }
369 "error" | "response.failed" => {
370 let msg = event
371 .data
372 .get("message")
373 .and_then(|v| v.as_str())
374 .or_else(|| {
375 event
376 .data
377 .get("error")
378 .and_then(|e| e.get("message"))
379 .and_then(|v| v.as_str())
380 })
381 .unwrap_or("Codex stream error");
382 return Err(ProviderError::TransientError {
383 message: msg.to_string(),
384 status: None,
385 });
386 }
387 _ => {
388 }
390 }
391 }
392 }
393
394 if !tool_calls.is_empty() && stop_reason == StopReason::EndTurn {
396 stop_reason = StopReason::ToolUse;
397 }
398
399 let content = if text_blocks.len() == 1 {
401 Content::Text(text_blocks.into_iter().next().unwrap())
402 } else if text_blocks.is_empty() {
403 Content::text("")
404 } else {
405 Content::Blocks(
406 text_blocks
407 .into_iter()
408 .map(|t| ContentBlock::Text { text: t })
409 .collect(),
410 )
411 };
412
413 let token_usage = TokenUsage {
415 input_tokens: usage.input_tokens,
416 output_tokens: usage.output_tokens,
417 cache_read_tokens: if usage.cached_tokens > 0 {
418 Some(usage.cached_tokens)
419 } else {
420 None
421 },
422 cache_creation_tokens: None,
423 };
424
425 let response = InferResponse {
426 content,
427 tool_calls,
428 stop_reason,
429 usage: token_usage,
430 model: model_name,
431 cost: Some(Decimal::ZERO),
432 truncated: None,
433 };
434
435 on_event(StreamEvent::Done(response.clone()));
436
437 tracing::info!(
438 input_tokens = usage.input_tokens,
439 output_tokens = usage.output_tokens,
440 "codex streaming inference finished"
441 );
442
443 Ok(response)
444 }
445}
446
447impl Provider for CodexProvider {
448 fn infer(
449 &self,
450 request: InferRequest,
451 ) -> impl std::future::Future<Output = Result<InferResponse, ProviderError>> + Send {
452 let codex_request = self.build_codex_request(&request);
453 let this = self.clone();
454 let model = request.model.as_deref().unwrap_or("unknown");
455 let span = tracing::info_span!("provider.infer", provider = "codex", model);
456
457 async move {
458 this.stream_sse(codex_request, &|_| {}).await
460 }
461 .instrument(span)
462 }
463}
464
465impl StreamProvider for CodexProvider {
466 fn infer_stream(
467 &self,
468 request: StreamRequest,
469 on_event: impl Fn(StreamEvent) + Send + Sync + 'static,
470 ) -> impl std::future::Future<Output = Result<InferResponse, ProviderError>> + Send {
471 let codex_request = self.build_codex_stream_request(&request);
472 let this = self.clone();
473 let model = request.model.as_deref().unwrap_or("unknown");
474 let span = tracing::info_span!("provider.infer_stream", provider = "codex", model);
475
476 async move { this.stream_sse(codex_request, &on_event).await }.instrument(span)
477 }
478}
479
480fn extract_output_text(item: &serde_json::Value) -> String {
482 item.get("content")
483 .and_then(|c| c.as_array())
484 .map(|parts| {
485 parts
486 .iter()
487 .filter_map(|p| {
488 let ptype = p.get("type").and_then(|v| v.as_str()).unwrap_or("");
489 match ptype {
490 "output_text" => p.get("text").and_then(|v| v.as_str()),
491 "refusal" => p.get("refusal").and_then(|v| v.as_str()),
492 _ => None,
493 }
494 })
495 .collect::<Vec<_>>()
496 .join("")
497 })
498 .unwrap_or_default()
499}
500
501fn map_error_response(status: reqwest::StatusCode, body: &str) -> ProviderError {
503 let status_u16 = status.as_u16();
504
505 if body.contains("usage_limit_reached")
507 || body.contains("usage_not_included")
508 || body.contains("rate_limit_exceeded")
509 {
510 return ProviderError::RateLimited;
511 }
512
513 if body.contains("content_filter") || body.contains("content policy") {
514 return ProviderError::ContentBlocked {
515 message: body.to_string(),
516 };
517 }
518
519 ProviderError::TransientError {
520 message: format!("HTTP {status}: {body}"),
521 status: Some(status_u16),
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider with placeholder credentials; no network access is involved.
    fn provider() -> CodexProvider {
        CodexProvider::with_account_id("tok", "acct")
    }

    /// Without an override the endpoint lives under the default base URL.
    #[test]
    fn endpoint_url_default() {
        let expected = "https://chatgpt.com/backend-api/codex/responses";
        assert_eq!(provider().endpoint_url(), expected);
    }

    /// A trailing slash on a custom base URL does not produce `//`.
    #[test]
    fn endpoint_url_custom() {
        let custom = provider().with_base_url("http://localhost:8080/");
        assert_eq!(
            custom.endpoint_url(),
            "http://localhost:8080/codex/responses"
        );
    }

    /// A rate-limit marker in the body wins regardless of the HTTP status.
    #[test]
    fn error_mapping_rate_limit() {
        let body = r#"{"error":{"code":"rate_limit_exceeded"}}"#;
        let mapped = map_error_response(reqwest::StatusCode::BAD_REQUEST, body);
        assert!(matches!(mapped, ProviderError::RateLimited));
    }

    /// A content-filter marker maps to a blocked-content error.
    #[test]
    fn error_mapping_content_filter() {
        let mapped =
            map_error_response(reqwest::StatusCode::BAD_REQUEST, "content_filter triggered");
        assert!(matches!(mapped, ProviderError::ContentBlocked { .. }));
    }

    /// Consecutive `output_text` parts are concatenated in order.
    #[test]
    fn extract_output_text_basic() {
        let message = serde_json::json!({
            "type": "message",
            "content": [
                {"type": "output_text", "text": "Hello "},
                {"type": "output_text", "text": "world"}
            ]
        });
        assert_eq!(extract_output_text(&message), "Hello world");
    }

    /// Requests always stream and set `store` to `Some(false)`.
    #[test]
    fn build_request_sets_stream_true() {
        let empty = InferRequest::new(vec![]);
        let built = provider().build_codex_request(&empty);
        assert!(built.stream);
        assert_eq!(built.store, Some(false));
    }
}