Skip to main content

steer_auth_openai/
lib.rs

1use async_trait::async_trait;
2use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime, UNIX_EPOCH};
9use tracing::info;
10
11use steer_auth_plugin::AuthPlugin;
12use steer_auth_plugin::{
13    AuthDirective, AuthError, AuthErrorAction, AuthErrorContext, AuthHeaderContext,
14    AuthHeaderProvider, AuthMethod, AuthProgress, AuthSource, AuthStorage, AuthTokens,
15    AuthenticationFlow, Credential, CredentialType, DynAuthenticationFlow, HeaderPair,
16    InstructionPolicy, ModelId, ModelVisibilityPolicy, OpenAiResponsesAuth, ProviderId, Result,
17};
18use steer_tools::tools::{
19    AST_GREP_TOOL_NAME, BASH_TOOL_NAME, DISPATCH_AGENT_TOOL_NAME, EDIT_TOOL_NAME, FETCH_TOOL_NAME,
20    GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, MULTI_EDIT_TOOL_NAME, REPLACE_TOOL_NAME,
21    TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, VIEW_TOOL_NAME,
22};
23
24mod callback_server;
25use callback_server::{CallbackResponse, CallbackServerHandle, spawn_callback_server};
26
/// Provider key used for every credential read/write against `AuthStorage`.
const PROVIDER_ID: &str = "openai";
/// OAuth authorization endpoint; `build_auth_url` appends the query string to it.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OAuth token endpoint, POSTed to for both the code exchange and the refresh grant.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth client id sent in the authorize, token-exchange and refresh requests.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Redirect URI sent to the authorization server; its port and path match
/// `CALLBACK_PORT` and `CALLBACK_PATH`.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Value of the `scope` parameter in the authorization URL.
const SCOPES: &str = "openid profile email offline_access";
/// Sent as the `originator` authorize-URL parameter and as the `originator` request header.
const ORIGINATOR: &str = "codex_cli_rs";
/// Path handed to the local callback server.
const CALLBACK_PATH: &str = "/auth/callback";
/// Port on 127.0.0.1 that the local callback server is asked to bind.
const CALLBACK_PORT: u16 = 1455;

// NOTE(review): not referenced in the visible part of this file — presumably the
// request URL for Codex responses; confirm against the consumer.
const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
/// Value of the `openai-beta` request header.
const OPENAI_BETA: &str = "responses=experimental";
// NOTE(review): the two model ids below are not used in the visible part of this
// file; presumably consumed by the model-visibility / instruction policy further down.
const GPT_5_2_CODEX_MODEL_ID: &str = "gpt-5.2-codex";
const GPT_5_3_CODEX_MODEL_ID: &str = "gpt-5.3-codex";
/// Base Codex system prompt. `codex_instructions` emits this text first,
/// followed by a blank line and the Steer bridge prompt. The text is runtime
/// data sent to the model: do not reflow or "correct" it.
const CODEX_SYSTEM_PROMPT: &str = r#"You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.

## General

- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)

## Editing constraints

- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
    * NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
    * If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
    * If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
    * If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.

## Plan tool

When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.

## Special user requests

- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.

## Frontend tasks
When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile

Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.

## Presenting your work and final message

You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.

- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow final‑answer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No "save/copy this file" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
  * Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
  * If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
  * When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.

### Final answer structure and style guidelines

- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
  * Use inline code to make file paths clickable.
  * Each reference should have a stand alone path. Even if it's the same file.
  * Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
  * Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
  * Do not use URIs like file://, vscode://, or https://.
  * Do not provide range of lines
  * Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
"#;
122
/// Build the Steer-specific addendum to the Codex system prompt.
///
/// The text tells the model which Codex tools do not exist here and which
/// Steer tools to use instead. Every `{..._TOOL_NAME}` placeholder is an
/// inline format capture of the corresponding `steer_tools::tools` constant,
/// so the prompt always carries the tool names the crate actually exports.
fn steer_codex_bridge_prompt() -> String {
    format!(
        r"## Codex Running in Steer

You are running Codex inside Steer, an open-source terminal coding assistant.

### CRITICAL tool replacements
- apply_patch does NOT exist. Use `{EDIT_TOOL_NAME}` instead.
- update_plan does NOT exist. Use `{TODO_WRITE_TOOL_NAME}` instead.
- read_plan does NOT exist. Use `{TODO_READ_TOOL_NAME}` instead.

### Steer tool names
- File: `{VIEW_TOOL_NAME}`, `{REPLACE_TOOL_NAME}`, `{EDIT_TOOL_NAME}`, `{MULTI_EDIT_TOOL_NAME}`
- Search: `{GREP_TOOL_NAME}` (text), `{AST_GREP_TOOL_NAME}` (syntax), `{GLOB_TOOL_NAME}` (paths), `{LS_TOOL_NAME}` (list directories)
- Exec: `{BASH_TOOL_NAME}`
- Web: `{FETCH_TOOL_NAME}`
- Agents: `{DISPATCH_AGENT_TOOL_NAME}`
- Todos: `{TODO_READ_TOOL_NAME}`, `{TODO_WRITE_TOOL_NAME}`

Tool names are case-sensitive; use exact casing.

### File path rules
- `{VIEW_TOOL_NAME}`, `{REPLACE_TOOL_NAME}`, `{EDIT_TOOL_NAME}`, `{MULTI_EDIT_TOOL_NAME}`, and `{LS_TOOL_NAME}` require absolute paths.

### Edit semantics
- `{EDIT_TOOL_NAME}` uses exact string replacement (empty `old_string` creates a file).
- `{MULTI_EDIT_TOOL_NAME}` applies multiple exact replacements in a single file.
- `{REPLACE_TOOL_NAME}` overwrites the entire file contents.

### Search guidance
- Prefer `{GREP_TOOL_NAME}`/`{AST_GREP_TOOL_NAME}`/`{GLOB_TOOL_NAME}`/`{LS_TOOL_NAME}` over shelling out to `rg` via `{BASH_TOOL_NAME}`.

### Todo guidance
- Use `{TODO_READ_TOOL_NAME}`/`{TODO_WRITE_TOOL_NAME}` for complex or multi-step tasks; skip them for simple, single-step work unless the user asks.
",
    )
}
160
161fn codex_instructions() -> String {
162    format!("{CODEX_SYSTEM_PROMPT}\n\n{}", steer_codex_bridge_prompt())
163}
164
/// Key of the object claim in the id-token payload that is checked first for
/// a nested `chatgpt_account_id` field.
const CHATGPT_ACCOUNT_ID_NESTED_CLAIM: &str = "https://api.openai.com/auth";
166
/// PKCE verifier/challenge pair as produced by `OpenAIOAuth::generate_pkce`.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Random verifier string; sent as `code_verifier` in the token exchange.
    pub verifier: String,
    /// Unpadded base64url encoding of the SHA-256 digest of `verifier`;
    /// sent as `code_challenge` in the authorization URL.
    pub challenge: String,
}
172
/// Newtype around the ChatGPT account id string extracted from an id token.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChatGptAccountId(pub String);
175
/// OAuth client for the OpenAI auth endpoints: builds the authorization URL,
/// exchanges authorization codes for tokens, and refreshes tokens.
#[derive(Clone)]
pub struct OpenAIOAuth {
    /// OAuth client id sent in authorize, token-exchange and refresh requests.
    client_id: String,
    /// Redirect URI sent in the authorize and token-exchange requests.
    redirect_uri: String,
    /// HTTP client used for the token endpoint calls.
    http_client: reqwest::Client,
}
182
impl Default for OpenAIOAuth {
    /// Equivalent to [`OpenAIOAuth::new`].
    fn default() -> Self {
        Self::new()
    }
}
188
189impl OpenAIOAuth {
190    pub fn new() -> Self {
191        Self {
192            client_id: CLIENT_ID.to_string(),
193            redirect_uri: REDIRECT_URI.to_string(),
194            http_client: reqwest::Client::new(),
195        }
196    }
197
198    pub fn generate_pkce() -> PkceChallenge {
199        let verifier = generate_random_string(128);
200        let challenge = base64_url_encode(&sha256(&verifier));
201        PkceChallenge {
202            verifier,
203            challenge,
204        }
205    }
206
207    pub fn generate_state() -> String {
208        generate_random_string(32)
209    }
210
211    pub fn build_auth_url(&self, pkce: &PkceChallenge, state: &str) -> String {
212        let params = [
213            ("response_type", "code"),
214            ("client_id", &self.client_id),
215            ("redirect_uri", &self.redirect_uri),
216            ("scope", SCOPES),
217            ("code_challenge", &pkce.challenge),
218            ("code_challenge_method", "S256"),
219            ("state", state),
220            ("id_token_add_organizations", "true"),
221            ("codex_cli_simplified_flow", "true"),
222            ("originator", ORIGINATOR),
223        ];
224
225        let query = serde_urlencoded::to_string(params).unwrap_or_default();
226        format!("{AUTHORIZE_URL}?{query}")
227    }
228
229    pub async fn exchange_code_for_tokens(
230        &self,
231        code: &str,
232        pkce_verifier: &str,
233    ) -> Result<AuthTokens> {
234        #[derive(Serialize)]
235        struct TokenRequest {
236            grant_type: String,
237            client_id: String,
238            code: String,
239            redirect_uri: String,
240            code_verifier: String,
241        }
242
243        #[derive(Deserialize)]
244        struct TokenResponse {
245            id_token: Option<String>,
246            access_token: String,
247            refresh_token: Option<String>,
248            expires_in: Option<u64>,
249        }
250
251        let request = TokenRequest {
252            grant_type: "authorization_code".to_string(),
253            client_id: self.client_id.clone(),
254            code: code.to_string(),
255            redirect_uri: self.redirect_uri.clone(),
256            code_verifier: pkce_verifier.to_string(),
257        };
258
259        let response = self
260            .http_client
261            .post(TOKEN_URL)
262            .form(&request)
263            .send()
264            .await?;
265
266        if !response.status().is_success() {
267            let status = response.status();
268            let error_text = response
269                .text()
270                .await
271                .unwrap_or_else(|_| "Unknown error".to_string());
272            return Err(AuthError::InvalidResponse(format!(
273                "Token exchange failed with status {status}: {error_text}"
274            )));
275        }
276
277        let token_response: TokenResponse = response.json().await.map_err(|e| {
278            AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
279        })?;
280
281        if token_response.access_token.trim().is_empty() {
282            return Err(AuthError::InvalidResponse(
283                "Empty access_token in token response".to_string(),
284            ));
285        }
286
287        let id_token = token_response.id_token.ok_or_else(|| {
288            AuthError::InvalidResponse("Missing id_token in token response".to_string())
289        })?;
290
291        let refresh_token = token_response.refresh_token.ok_or_else(|| {
292            AuthError::InvalidResponse("Missing refresh_token in token response".to_string())
293        })?;
294
295        let expires_at =
296            resolve_expires_at(token_response.expires_in, &token_response.access_token)?;
297
298        Ok(AuthTokens {
299            access_token: token_response.access_token,
300            refresh_token,
301            expires_at,
302            id_token: Some(id_token),
303        })
304    }
305
306    pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
307        #[derive(Serialize)]
308        struct RefreshRequest {
309            grant_type: String,
310            refresh_token: String,
311            client_id: String,
312        }
313
314        #[derive(Deserialize)]
315        struct TokenResponse {
316            id_token: Option<String>,
317            access_token: String,
318            refresh_token: Option<String>,
319            expires_in: Option<u64>,
320        }
321
322        let request = RefreshRequest {
323            grant_type: "refresh_token".to_string(),
324            refresh_token: refresh_token.to_string(),
325            client_id: self.client_id.clone(),
326        };
327
328        let response = self
329            .http_client
330            .post(TOKEN_URL)
331            .form(&request)
332            .send()
333            .await?;
334
335        if !response.status().is_success() {
336            if matches!(
337                response.status(),
338                reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::BAD_REQUEST
339            ) {
340                return Err(AuthError::ReauthRequired);
341            }
342
343            let status = response.status();
344            let error_text = response
345                .text()
346                .await
347                .unwrap_or_else(|_| "Unknown error".to_string());
348            return Err(AuthError::InvalidResponse(format!(
349                "Token refresh failed with status {status}: {error_text}"
350            )));
351        }
352
353        let token_response: TokenResponse = response.json().await.map_err(|e| {
354            AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
355        })?;
356
357        if token_response.access_token.trim().is_empty() {
358            return Err(AuthError::InvalidResponse(
359                "Empty access_token in refresh response".to_string(),
360            ));
361        }
362
363        let expires_at =
364            resolve_expires_at(token_response.expires_in, &token_response.access_token)?;
365
366        let refresh_token = token_response
367            .refresh_token
368            .unwrap_or_else(|| refresh_token.to_string());
369
370        Ok(AuthTokens {
371            access_token: token_response.access_token,
372            refresh_token,
373            expires_at,
374            id_token: token_response.id_token,
375        })
376    }
377}
378
379/// Check if tokens need refresh (within 5 minutes of expiry).
380pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
381    match tokens.expires_at.duration_since(SystemTime::now()) {
382        Ok(duration) => duration.as_secs() <= 300,
383        Err(_) => true,
384    }
385}
386
387/// Refresh tokens if needed, updating storage when refreshed.
388pub async fn refresh_if_needed(
389    storage: &Arc<dyn AuthStorage>,
390    oauth_client: &OpenAIOAuth,
391) -> Result<AuthTokens> {
392    let credential = storage
393        .get_credential(PROVIDER_ID, CredentialType::OAuth2)
394        .await?
395        .ok_or(AuthError::ReauthRequired)?;
396
397    let mut tokens = match credential {
398        Credential::OAuth2(tokens) => tokens,
399        Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
400    };
401
402    if tokens.id_token.is_none() || tokens_need_refresh(&tokens) {
403        match oauth_client.refresh_tokens(&tokens.refresh_token).await {
404            Ok(new_tokens) => {
405                let merged_tokens = AuthTokens {
406                    id_token: new_tokens.id_token.or(tokens.id_token),
407                    ..new_tokens
408                };
409                storage
410                    .set_credential(PROVIDER_ID, Credential::OAuth2(merged_tokens.clone()))
411                    .await?;
412                tokens = merged_tokens;
413            }
414            Err(AuthError::ReauthRequired) => {
415                storage
416                    .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
417                    .await?;
418                return Err(AuthError::ReauthRequired);
419            }
420            Err(e) => return Err(e),
421        }
422    }
423
424    if tokens.id_token.is_none() {
425        return Err(AuthError::ReauthRequired);
426    }
427
428    Ok(tokens)
429}
430
431async fn force_refresh(
432    storage: &Arc<dyn AuthStorage>,
433    oauth_client: &OpenAIOAuth,
434) -> Result<AuthTokens> {
435    let credential = storage
436        .get_credential(PROVIDER_ID, CredentialType::OAuth2)
437        .await?
438        .ok_or(AuthError::ReauthRequired)?;
439
440    let tokens = match credential {
441        Credential::OAuth2(tokens) => tokens,
442        Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
443    };
444
445    match oauth_client.refresh_tokens(&tokens.refresh_token).await {
446        Ok(new_tokens) => {
447            let merged_tokens = AuthTokens {
448                id_token: new_tokens.id_token.or(tokens.id_token),
449                ..new_tokens
450            };
451            storage
452                .set_credential(PROVIDER_ID, Credential::OAuth2(merged_tokens.clone()))
453                .await?;
454            Ok(merged_tokens)
455        }
456        Err(AuthError::ReauthRequired) => {
457            storage
458                .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
459                .await?;
460            Err(AuthError::ReauthRequired)
461        }
462        Err(e) => Err(e),
463    }
464}
465
/// Extract the ChatGPT account id from an id token.
///
/// Public entry point; delegates to `extract_chatgpt_account_id_from_id_token`
/// and returns its result unchanged.
pub fn extract_chatgpt_account_id(id_token: &str) -> Result<ChatGptAccountId> {
    extract_chatgpt_account_id_from_id_token(id_token)
}
469
470fn resolve_expires_at(expires_in: Option<u64>, access_token: &str) -> Result<SystemTime> {
471    if let Some(expires_in) = expires_in {
472        return Ok(SystemTime::now() + Duration::from_secs(expires_in));
473    }
474
475    let payload = decode_jwt_payload(access_token)?;
476    let exp = payload
477        .get("exp")
478        .and_then(|v| v.as_u64().or_else(|| v.as_i64().map(|v| v as u64)))
479        .ok_or_else(|| {
480            AuthError::InvalidResponse("Missing exp claim in access token".to_string())
481        })?;
482
483    Ok(UNIX_EPOCH + Duration::from_secs(exp))
484}
485
486fn decode_jwt_payload(access_token: &str) -> Result<serde_json::Value> {
487    let parts: Vec<&str> = access_token.split('.').collect();
488    if parts.len() < 2 {
489        return Err(AuthError::InvalidResponse(
490            "Invalid access token format".to_string(),
491        ));
492    }
493
494    let payload_bytes = URL_SAFE_NO_PAD
495        .decode(parts[1])
496        .map_err(|e| AuthError::InvalidResponse(format!("Invalid token payload: {e}")))?;
497
498    serde_json::from_slice(&payload_bytes)
499        .map_err(|e| AuthError::InvalidResponse(format!("Invalid token payload JSON: {e}")))
500}
501
502fn extract_chatgpt_account_id_from_id_token(id_token: &str) -> Result<ChatGptAccountId> {
503    let payload = decode_jwt_payload(id_token)?;
504
505    if let Some(account_id) = payload
506        .get(CHATGPT_ACCOUNT_ID_NESTED_CLAIM)
507        .and_then(|v| v.get("chatgpt_account_id"))
508        .and_then(|v| v.as_str())
509        .filter(|s| !s.is_empty())
510    {
511        return Ok(ChatGptAccountId(account_id.to_string()));
512    }
513
514    if let Some(account_id) = payload
515        .get("chatgpt_account_id")
516        .and_then(|v| v.as_str())
517        .filter(|s| !s.is_empty())
518    {
519        return Ok(ChatGptAccountId(account_id.to_string()));
520    }
521
522    Err(AuthError::InvalidResponse(
523        "Missing chatgpt account id in token".to_string(),
524    ))
525}
526
527fn parse_callback_input(input: &str) -> Result<CallbackResponse> {
528    let trimmed = input.trim();
529
530    if trimmed.contains("code=") && trimmed.contains("state=") {
531        let query = if trimmed.contains("://") {
532            let url = url::Url::parse(trimmed)
533                .map_err(|_| AuthError::InvalidCredential("Invalid redirect URL".to_string()))?;
534            url.query().unwrap_or("").to_string()
535        } else {
536            trimmed.to_string()
537        };
538
539        let params: std::collections::HashMap<String, String> =
540            url::form_urlencoded::parse(query.as_bytes())
541                .into_owned()
542                .collect();
543
544        let code = params
545            .get("code")
546            .ok_or_else(|| AuthError::MissingInput("code parameter".to_string()))?;
547        let state = params
548            .get("state")
549            .ok_or_else(|| AuthError::MissingInput("state parameter".to_string()))?;
550
551        return Ok(CallbackResponse {
552            code: code.clone(),
553            state: state.clone(),
554        });
555    }
556
557    if let Some((code, state)) = trimmed.split_once('#') {
558        if code.is_empty() || state.is_empty() {
559            return Err(AuthError::InvalidResponse(
560                "Invalid callback code format".to_string(),
561            ));
562        }
563        return Ok(CallbackResponse {
564            code: code.to_string(),
565            state: state.to_string(),
566        });
567    }
568
569    let parts: Vec<&str> = trimmed.split_whitespace().collect();
570    if parts.len() == 2 {
571        return Ok(CallbackResponse {
572            code: parts[0].to_string(),
573            state: parts[1].to_string(),
574        });
575    }
576
577    Err(AuthError::InvalidResponse(
578        "Invalid callback input. Paste the full redirect URL or code#state.".to_string(),
579    ))
580}
581
582fn generate_random_string(length: usize) -> String {
583    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
584    let mut rng = rand::thread_rng();
585
586    (0..length)
587        .map(|_| {
588            let idx = rng.gen_range(0..CHARSET.len());
589            CHARSET[idx] as char
590        })
591        .collect()
592}
593
594fn sha256(data: &str) -> Vec<u8> {
595    let mut hasher = Sha256::new();
596    hasher.update(data.as_bytes());
597    hasher.finalize().to_vec()
598}
599
/// Encode `data` as base64 using the URL-safe alphabet without padding.
fn base64_url_encode(data: &[u8]) -> String {
    URL_SAFE_NO_PAD.encode(data)
}
603
/// State carried between the steps of an in-progress authentication flow.
#[derive(Debug)]
pub struct OpenAIAuthState {
    /// Which flow is running and the data it needs.
    pub kind: OpenAIAuthStateKind,
}
608
/// The flow-specific payload of [`OpenAIAuthState`].
#[derive(Debug)]
pub enum OpenAIAuthStateKind {
    /// An OAuth login has been started and is waiting for the callback.
    OAuthStarted {
        /// PKCE verifier to send with the code exchange.
        verifier: String,
        /// `state` value the callback must echo back.
        state: String,
        /// Authorization URL the user opens in a browser.
        auth_url: String,
        /// Handle to the local callback server; `None` when it could not be
        /// started, in which case the user pastes the redirect manually.
        callback_server: Option<CallbackServerHandle>,
    },
}
618
/// `AuthenticationFlow` implementation that logs in via OAuth and stores the
/// resulting tokens.
pub struct OpenAIOAuthFlow {
    /// Where the obtained credential is saved and looked up.
    storage: Arc<dyn AuthStorage>,
    /// Client used to build the auth URL and exchange the code.
    oauth_client: OpenAIOAuth,
}
623
624impl OpenAIOAuthFlow {
625    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
626        Self {
627            storage,
628            oauth_client: OpenAIOAuth::new(),
629        }
630    }
631}
632
633#[async_trait]
634impl AuthenticationFlow for OpenAIOAuthFlow {
635    type State = OpenAIAuthState;
636
637    fn available_methods(&self) -> Vec<AuthMethod> {
638        vec![AuthMethod::OAuth]
639    }
640
641    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
642        match method {
643            AuthMethod::OAuth => {
644                let pkce = OpenAIOAuth::generate_pkce();
645                let state = OpenAIOAuth::generate_state();
646                let auth_url = self.oauth_client.build_auth_url(&pkce, &state);
647
648                let callback_server = match spawn_callback_server(
649                    state.clone(),
650                    SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT)),
651                    CALLBACK_PATH,
652                )
653                .await
654                {
655                    Ok(handle) => Some(handle),
656                    Err(err) => {
657                        info!(
658                            "OpenAI OAuth callback server unavailable, falling back to manual paste: {}",
659                            err
660                        );
661                        None
662                    }
663                };
664
665                Ok(OpenAIAuthState {
666                    kind: OpenAIAuthStateKind::OAuthStarted {
667                        verifier: pkce.verifier,
668                        state,
669                        auth_url,
670                        callback_server,
671                    },
672                })
673            }
674            AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
675                method: format!("{method:?}"),
676                provider: PROVIDER_ID.to_string(),
677            }),
678        }
679    }
680
681    async fn get_initial_progress(
682        &self,
683        state: &Self::State,
684        method: AuthMethod,
685    ) -> Result<AuthProgress> {
686        match method {
687            AuthMethod::OAuth => {
688                let OpenAIAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind;
689                Ok(AuthProgress::OAuthStarted {
690                    auth_url: auth_url.clone(),
691                })
692            }
693            AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
694                method: format!("{method:?}"),
695                provider: PROVIDER_ID.to_string(),
696            }),
697        }
698    }
699
700    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
701        match &mut state.kind {
702            OpenAIAuthStateKind::OAuthStarted {
703                verifier,
704                state: expected_state,
705                callback_server,
706                ..
707            } => {
708                let callback = if input.trim().is_empty() {
709                    if let Some(server) = callback_server {
710                        if let Some(result) = server.try_recv() {
711                            result?
712                        } else {
713                            return Ok(AuthProgress::InProgress(
714                                "Waiting for OAuth callback...".to_string(),
715                            ));
716                        }
717                    } else {
718                        return Ok(AuthProgress::NeedInput(
719                            "Paste the redirect URL from your browser".to_string(),
720                        ));
721                    }
722                } else {
723                    parse_callback_input(input)?
724                };
725
726                if callback.state != *expected_state {
727                    return Err(AuthError::StateMismatch);
728                }
729
730                let tokens = self
731                    .oauth_client
732                    .exchange_code_for_tokens(&callback.code, verifier)
733                    .await?;
734
735                self.storage
736                    .set_credential(PROVIDER_ID, Credential::OAuth2(tokens))
737                    .await?;
738
739                if let Some(server) = callback_server.take() {
740                    drop(server);
741                }
742
743                Ok(AuthProgress::Complete)
744            }
745        }
746    }
747
748    async fn is_authenticated(&self) -> Result<bool> {
749        if let Some(Credential::OAuth2(tokens)) = self
750            .storage
751            .get_credential(PROVIDER_ID, CredentialType::OAuth2)
752            .await?
753        {
754            return Ok(tokens.id_token.is_some() && !tokens_need_refresh(&tokens));
755        }
756
757        Ok(false)
758    }
759
760    fn provider_name(&self) -> String {
761        PROVIDER_ID.to_string()
762    }
763}
764
765#[derive(Clone)]
766struct OpenAiHeaderProvider {
767    storage: Arc<dyn AuthStorage>,
768    oauth: OpenAIOAuth,
769}
770
771impl OpenAiHeaderProvider {
772    fn new(storage: Arc<dyn AuthStorage>) -> Self {
773        Self {
774            storage,
775            oauth: OpenAIOAuth::new(),
776        }
777    }
778
779    async fn header_pairs(&self, _ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
780        let tokens = refresh_if_needed(&self.storage, &self.oauth).await?;
781        let id_token = tokens
782            .id_token
783            .as_deref()
784            .ok_or(AuthError::ReauthRequired)?;
785        let account_id = extract_chatgpt_account_id(id_token)?;
786
787        Ok(vec![
788            HeaderPair {
789                name: "authorization".to_string(),
790                value: format!("Bearer {}", tokens.access_token),
791            },
792            HeaderPair {
793                name: "chatgpt-account-id".to_string(),
794                value: account_id.0,
795            },
796            HeaderPair {
797                name: "openai-beta".to_string(),
798                value: OPENAI_BETA.to_string(),
799            },
800            HeaderPair {
801                name: "originator".to_string(),
802                value: ORIGINATOR.to_string(),
803            },
804        ])
805    }
806}
807
808#[async_trait]
809impl AuthHeaderProvider for OpenAiHeaderProvider {
810    async fn headers(&self, ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
811        self.header_pairs(ctx).await
812    }
813
814    async fn on_auth_error(&self, _ctx: AuthErrorContext) -> Result<AuthErrorAction> {
815        match force_refresh(&self.storage, &self.oauth).await {
816            Ok(_) => Ok(AuthErrorAction::RetryOnce),
817            Err(AuthError::ReauthRequired) => Ok(AuthErrorAction::ReauthRequired),
818            Err(err) => Err(err),
819        }
820    }
821}
822
823struct OpenAiModelVisibility;
824
825impl ModelVisibilityPolicy for OpenAiModelVisibility {
826    fn allow_model(&self, model_id: &ModelId, auth_source: &AuthSource) -> bool {
827        if model_id.provider_id.0 != PROVIDER_ID {
828            return true;
829        }
830
831        if matches!(
832            model_id.model_id.as_str(),
833            GPT_5_2_CODEX_MODEL_ID | GPT_5_3_CODEX_MODEL_ID
834        ) {
835            return matches!(auth_source, AuthSource::Plugin { .. });
836        }
837
838        true
839    }
840}
841
/// Stateless auth plugin handle for the OpenAI provider.
///
/// The type carries no data, so `new`, `default`, and `clone` all yield
/// equivalent values.
#[derive(Clone, Default)]
pub struct OpenAiAuthPlugin;

impl OpenAiAuthPlugin {
    /// Creates the plugin.
    pub fn new() -> Self {
        OpenAiAuthPlugin
    }
}
856
857#[async_trait]
858impl AuthPlugin for OpenAiAuthPlugin {
859    fn provider_id(&self) -> ProviderId {
860        ProviderId(PROVIDER_ID.to_string())
861    }
862
863    fn supported_methods(&self) -> Vec<AuthMethod> {
864        vec![AuthMethod::OAuth]
865    }
866
867    fn create_flow(&self, storage: Arc<dyn AuthStorage>) -> Option<Box<dyn DynAuthenticationFlow>> {
868        Some(Box::new(steer_auth_plugin::AuthFlowWrapper::new(
869            OpenAIOAuthFlow::new(storage),
870        )))
871    }
872
873    async fn resolve_auth(&self, storage: Arc<dyn AuthStorage>) -> Result<Option<AuthDirective>> {
874        let is_authenticated = self.is_authenticated(storage.clone()).await?;
875        if !is_authenticated {
876            return Ok(None);
877        }
878
879        let headers = Arc::new(OpenAiHeaderProvider::new(storage));
880        let directive = OpenAiResponsesAuth {
881            headers,
882            base_url_override: Some(CODEX_BASE_URL.to_string()),
883            require_streaming: Some(true),
884            instruction_policy: Some(InstructionPolicy::Override(codex_instructions())),
885            include: Some(vec!["reasoning.encrypted_content".to_string()]),
886        };
887
888        Ok(Some(AuthDirective::OpenAiResponses(directive)))
889    }
890
891    async fn is_authenticated(&self, storage: Arc<dyn AuthStorage>) -> Result<bool> {
892        if let Some(Credential::OAuth2(tokens)) = storage
893            .get_credential(PROVIDER_ID, CredentialType::OAuth2)
894            .await?
895        {
896            return Ok(tokens.id_token.is_some() && !tokens_need_refresh(&tokens));
897        }
898
899        Ok(false)
900    }
901
902    fn model_visibility(&self) -> Option<Box<dyn ModelVisibilityPolicy>> {
903        Some(Box::new(OpenAiModelVisibility))
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use steer_auth_plugin::{AuthMethod, AuthSource};

    /// Value placed in the `exp` claim of the test tokens (seconds since the Unix epoch).
    const TEST_EXP_SECS: u64 = 1_700_000_000;

    /// Builds an unsigned, JWT-shaped string: encoded `{}` header, encoded
    /// payload, and the literal signature segment `sig`.
    fn make_jwt(payload: serde_json::Value) -> String {
        let header = base64_url_encode(b"{}");
        let body = base64_url_encode(payload.to_string().as_bytes());
        format!("{header}.{body}.sig")
    }

    /// Encodes `expected` under the nested account-id claim and asserts that
    /// `extract_chatgpt_account_id` returns it.
    fn assert_nested_account_id_extracted(expected: &str) {
        let token = make_jwt(serde_json::json!({
            CHATGPT_ACCOUNT_ID_NESTED_CLAIM: {
                "chatgpt_account_id": expected
            },
            "exp": TEST_EXP_SECS
        }));
        let account_id = extract_chatgpt_account_id(&token).unwrap();
        assert_eq!(account_id.0, expected);
    }

    #[test]
    fn test_auth_url_building() {
        let oauth = OpenAIOAuth::new();
        let pkce = OpenAIOAuth::generate_pkce();
        let state = OpenAIOAuth::generate_state();

        let url = oauth.build_auth_url(&pkce, &state);

        let expected_fragments = [
            AUTHORIZE_URL.to_string(),
            format!("client_id={CLIENT_ID}"),
            "response_type=code".to_string(),
            "code_challenge=".to_string(),
            "code_challenge_method=S256".to_string(),
            "id_token_add_organizations=true".to_string(),
            "codex_cli_simplified_flow=true".to_string(),
            format!("originator={ORIGINATOR}"),
            "redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback".to_string(),
        ];
        for fragment in &expected_fragments {
            assert!(url.contains(fragment.as_str()));
        }
    }

    #[test]
    fn test_parse_callback_input_url() {
        let parsed =
            parse_callback_input("http://localhost:1455/auth/callback?code=abc123&state=state456")
                .unwrap();
        assert_eq!(parsed.code, "abc123");
        assert_eq!(parsed.state, "state456");
    }

    // NOTE(review): this test and the `_nested_claim` test below build the same
    // nested-claim payload shape; only the account id value differs.
    #[test]
    fn test_extract_chatgpt_account_id() {
        assert_nested_account_id_extracted("acct_123");
    }

    #[test]
    fn test_extract_chatgpt_account_id_nested_claim() {
        assert_nested_account_id_extracted("acct_nested");
    }

    #[test]
    fn test_resolve_expires_at_from_token() {
        let token = make_jwt(serde_json::json!({
            "chatgpt_account_id": "acct_123",
            "exp": TEST_EXP_SECS
        }));

        let expires_at = resolve_expires_at(None, &token).unwrap();
        let expected = UNIX_EPOCH + Duration::from_secs(TEST_EXP_SECS);
        assert_eq!(expires_at, expected);
    }

    #[test]
    fn test_openai_codex_models_require_plugin_auth() {
        let visibility = OpenAiModelVisibility;
        let openai_model = |name: &str| ModelId {
            provider_id: ProviderId(PROVIDER_ID.to_string()),
            model_id: name.to_string(),
        };
        let codex_5_2 = openai_model(GPT_5_2_CODEX_MODEL_ID);
        let codex_5_3 = openai_model(GPT_5_3_CODEX_MODEL_ID);

        // Plugin (OAuth) auth unlocks both Codex models.
        for model in [&codex_5_2, &codex_5_3] {
            let via_plugin = AuthSource::Plugin {
                method: AuthMethod::OAuth,
            };
            assert!(visibility.allow_model(model, &via_plugin));
        }

        // API-key auth hides them, whichever origin the key has.
        let env_key = AuthSource::ApiKey {
            origin: steer_auth_plugin::ApiKeyOrigin::Env,
        };
        let stored_key = AuthSource::ApiKey {
            origin: steer_auth_plugin::ApiKeyOrigin::Stored,
        };
        assert!(!visibility.allow_model(&codex_5_2, &env_key));
        assert!(!visibility.allow_model(&codex_5_3, &stored_key));
    }
}