Skip to main content

stakpak_shared/models/
openai_runtime.rs

1use crate::models::auth::ProviderAuth;
2use crate::models::integrations::openai::OpenAIConfig as InputOpenAIConfig;
3use crate::models::llm::ProviderConfig;
4pub use stakai::providers::openai::runtime::{
5    CodexBackendProfile, CompatibleBackendProfile, OfficialBackendProfile, OpenAIBackendProfile,
6};
7use stakai::types::{CompletionsConfig, OpenAIApiConfig, OpenAIOptions, ResponsesConfig};
8use thiserror::Error;
9
/// Base URL of the official OpenAI API.
///
/// Used both as the default endpoint for API-key credentials and as the value
/// an endpoint is compared against to decide whether it is the official
/// backend.
const OFFICIAL_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
11
/// Credential selected for authenticating against an OpenAI-style backend.
///
/// `Debug` is implemented by hand instead of derived so that formatting this
/// value (directly, or through a struct that contains it) never prints the
/// secret material it holds.
#[derive(Clone, PartialEq, Eq)]
pub enum OpenAIResolvedAuth {
    /// Static API key credential.
    ApiKey {
        /// The secret API key.
        key: String,
    },
    /// OAuth bearer credential.
    OAuthBearer {
        /// Secret access token; this is what
        /// [`OpenAIResolvedAuth::authorization_token`] returns.
        access_token: String,
        /// Secret refresh token, when one is available.
        refresh_token: Option<String>,
        /// Expiry value carried over from the source credentials. The unit is
        /// not established in this file.
        expires_at: Option<i64>,
    },
}

impl std::fmt::Debug for OpenAIResolvedAuth {
    /// Formats the variant and field names but replaces every secret with a
    /// placeholder. `expires_at` is not a secret and is printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        match self {
            Self::ApiKey { .. } => f.debug_struct("ApiKey").field("key", &REDACTED).finish(),
            Self::OAuthBearer {
                refresh_token,
                expires_at,
                ..
            } => f
                .debug_struct("OAuthBearer")
                .field("access_token", &REDACTED)
                // Keep the Some/None shape visible without exposing the token.
                .field("refresh_token", &refresh_token.as_ref().map(|_| REDACTED))
                .field("expires_at", expires_at)
                .finish(),
        }
    }
}

impl OpenAIResolvedAuth {
    /// Returns the secret used as the authorization credential: the API key
    /// for [`OpenAIResolvedAuth::ApiKey`], or the access token for
    /// [`OpenAIResolvedAuth::OAuthBearer`].
    pub fn authorization_token(&self) -> &str {
        match self {
            Self::ApiKey { key } => key,
            Self::OAuthBearer { access_token, .. } => access_token,
        }
    }
}
32
33#[derive(Debug, Clone)]
34pub struct OpenAIResolvedConfig {
35    pub auth: OpenAIResolvedAuth,
36    pub backend: OpenAIBackendProfile,
37    pub default_api_mode: OpenAIApiConfig,
38}
39
40impl OpenAIResolvedConfig {
41    pub fn to_stakai_config(&self) -> stakai::providers::openai::OpenAIConfig {
42        let mut config = stakai::providers::openai::OpenAIConfig::new(
43            self.auth.authorization_token().to_string(),
44        );
45
46        match &self.backend {
47            OpenAIBackendProfile::Official(profile) => {
48                if profile.base_url != OFFICIAL_OPENAI_BASE_URL {
49                    config = config.with_base_url(profile.base_url.clone());
50                }
51            }
52            OpenAIBackendProfile::Compatible(profile) => {
53                config = config.with_base_url(profile.base_url.clone());
54            }
55            OpenAIBackendProfile::Codex(profile) => {
56                config = config
57                    .with_base_url(profile.base_url.clone())
58                    .with_custom_header("originator", profile.originator.clone())
59                    .with_custom_header("ChatGPT-Account-Id", profile.chatgpt_account_id.clone());
60            }
61        }
62
63        match self.default_api_mode {
64            OpenAIApiConfig::Responses(_) => {
65                config.with_default_openai_options(OpenAIOptions::responses())
66            }
67            OpenAIApiConfig::Completions(_) => config,
68        }
69    }
70}
71
72#[derive(Debug, Clone)]
73pub struct OpenAIBackendResolutionInput {
74    provider_config: Option<ProviderConfig>,
75    auth: Option<ProviderAuth>,
76}
77
78impl OpenAIBackendResolutionInput {
79    pub fn new(provider_config: Option<ProviderConfig>, auth: Option<ProviderAuth>) -> Self {
80        Self {
81            provider_config,
82            auth,
83        }
84    }
85
86    fn provider_fields(&self) -> Result<OpenAIProviderFields, OpenAIResolutionError> {
87        match self.provider_config.as_ref() {
88            None => Ok(OpenAIProviderFields::default()),
89            Some(ProviderConfig::OpenAI { api_endpoint, .. }) => Ok(OpenAIProviderFields {
90                api_endpoint: api_endpoint.clone(),
91            }),
92            Some(other) => Err(OpenAIResolutionError::UnsupportedProviderConfig(
93                other.provider_type().to_string(),
94            )),
95        }
96    }
97}
98
/// Subset of the OpenAI provider configuration that runtime resolution reads.
#[derive(Clone, Debug, Default)]
struct OpenAIProviderFields {
    /// Endpoint override copied from the provider config. `None` when no
    /// provider config was supplied or it carried no endpoint.
    api_endpoint: Option<String>,
}
103
104#[derive(Debug, Error)]
105pub enum OpenAIResolutionError {
106    #[error("OpenAI runtime resolution only supports openai provider config, got {0}")]
107    UnsupportedProviderConfig(String),
108    #[error("ChatGPT Plus/Pro OAuth credentials are missing required chatgpt_account_id claim")]
109    MissingCodexAccountId,
110}
111
112pub fn resolve_openai_runtime(
113    input: OpenAIBackendResolutionInput,
114) -> Result<Option<OpenAIResolvedConfig>, OpenAIResolutionError> {
115    let provider_fields = input.provider_fields()?;
116    let Some(auth) = input.auth else {
117        return Ok(None);
118    };
119
120    match auth {
121        ProviderAuth::Api { key } => {
122            let base_url = provider_fields
123                .api_endpoint
124                .unwrap_or_else(|| OFFICIAL_OPENAI_BASE_URL.to_string());
125            let (backend, default_api_mode) = if base_url == OFFICIAL_OPENAI_BASE_URL {
126                (
127                    OpenAIBackendProfile::Official(OfficialBackendProfile { base_url }),
128                    OpenAIApiConfig::Responses(ResponsesConfig::default()),
129                )
130            } else {
131                (
132                    OpenAIBackendProfile::Compatible(CompatibleBackendProfile { base_url }),
133                    OpenAIApiConfig::Completions(CompletionsConfig::default()),
134                )
135            };
136
137            Ok(Some(OpenAIResolvedConfig {
138                auth: OpenAIResolvedAuth::ApiKey { key },
139                backend,
140                default_api_mode,
141            }))
142        }
143        ProviderAuth::OAuth {
144            access,
145            refresh,
146            expires,
147            ..
148        } => {
149            let Some(chatgpt_account_id) = InputOpenAIConfig::extract_chatgpt_account_id(&access)
150            else {
151                return Err(OpenAIResolutionError::MissingCodexAccountId);
152            };
153
154            Ok(Some(OpenAIResolvedConfig {
155                auth: OpenAIResolvedAuth::OAuthBearer {
156                    access_token: access,
157                    refresh_token: if refresh.is_empty() {
158                        None
159                    } else {
160                        Some(refresh)
161                    },
162                    expires_at: Some(expires),
163                },
164                backend: OpenAIBackendProfile::Codex(CodexBackendProfile {
165                    base_url: provider_fields
166                        .api_endpoint
167                        .unwrap_or_else(|| InputOpenAIConfig::OPENAI_CODEX_BASE_URL.to_string()),
168                    originator: "stakpak".to_string(),
169                    chatgpt_account_id,
170                }),
171                default_api_mode: OpenAIApiConfig::Responses(ResponsesConfig::default()),
172            }))
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// A Codex/OAuth resolved config must carry the access token, the Codex
    /// base URL, both Codex headers, and Responses mode into the client config.
    #[test]
    fn test_to_stakai_config_for_codex_oauth() {
        let resolved = OpenAIResolvedConfig {
            auth: OpenAIResolvedAuth::OAuthBearer {
                access_token: "access-token".to_string(),
                refresh_token: Some("refresh-token".to_string()),
                expires_at: Some(123),
            },
            backend: OpenAIBackendProfile::Codex(CodexBackendProfile {
                base_url: InputOpenAIConfig::OPENAI_CODEX_BASE_URL.to_string(),
                originator: "stakpak".to_string(),
                chatgpt_account_id: "acct_test_123".to_string(),
            }),
            default_api_mode: OpenAIApiConfig::Responses(ResponsesConfig::default()),
        };

        let config = resolved.to_stakai_config();

        assert_eq!(config.api_key, "access-token");
        assert_eq!(config.base_url, InputOpenAIConfig::OPENAI_CODEX_BASE_URL);
        assert_eq!(
            config.custom_headers.get("ChatGPT-Account-Id"),
            Some(&"acct_test_123".to_string())
        );
        assert_eq!(
            config.custom_headers.get("originator"),
            Some(&"stakpak".to_string())
        );
        assert!(matches!(
            config.default_openai_options,
            Some(OpenAIOptions {
                api_config: Some(OpenAIApiConfig::Responses(_)),
                ..
            })
        ));
    }

    /// OAuth credentials whose access token carries a ChatGPT account id must
    /// resolve to the Codex backend in Responses mode. Checking only `is_ok()`
    /// would also accept `Ok(None)` or a wrong backend.
    #[test]
    fn test_resolve_openai_runtime_for_oauth_codex() {
        let payload = serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_test_789"
            }
        });
        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        let access_token = format!("header.{}.signature", encoded_payload);

        let auth = ProviderAuth::oauth_with_name(
            access_token,
            "refresh-token",
            i64::MAX,
            "ChatGPT Plus/Pro",
        );
        let resolved = resolve_openai_runtime(OpenAIBackendResolutionInput::new(
            Some(ProviderConfig::openai_with_auth(auth.clone())),
            Some(auth),
        ));

        assert!(matches!(
            &resolved,
            Ok(Some(OpenAIResolvedConfig {
                auth: OpenAIResolvedAuth::OAuthBearer { .. },
                backend: OpenAIBackendProfile::Codex(profile),
                default_api_mode: OpenAIApiConfig::Responses(_),
            })) if profile.chatgpt_account_id == "acct_test_789"
                && profile.originator == "stakpak"
        ));
    }

    /// An API key with no provider config must resolve to the official backend
    /// in Responses mode.
    #[test]
    fn test_resolve_openai_runtime_for_api_key_defaults_to_official() {
        let resolved = resolve_openai_runtime(OpenAIBackendResolutionInput::new(
            None,
            Some(ProviderAuth::Api {
                key: "sk-test".to_string(),
            }),
        ));

        assert!(matches!(
            &resolved,
            Ok(Some(OpenAIResolvedConfig {
                auth: OpenAIResolvedAuth::ApiKey { key },
                backend: OpenAIBackendProfile::Official(profile),
                default_api_mode: OpenAIApiConfig::Responses(_),
            })) if key == "sk-test" && profile.base_url == OFFICIAL_OPENAI_BASE_URL
        ));
    }

    /// Without credentials there is nothing to resolve.
    #[test]
    fn test_resolve_openai_runtime_without_auth_is_none() {
        let resolved = resolve_openai_runtime(OpenAIBackendResolutionInput::new(None, None));

        assert!(matches!(resolved, Ok(None)));
    }
}