Skip to main content

zeph_llm/openai/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `OpenAI` API backend.
5//!
6//! [`OpenAiProvider`] targets the `OpenAI` Chat Completions and Embeddings APIs.
7//! It also serves as the foundation for [`crate::compatible::CompatibleProvider`],
8//! which points the same implementation at any `OpenAI`-compatible endpoint
9//! (Together AI, Fireworks, local vLLM, etc.).
10//!
11//! # Supported capabilities
12//!
13//! - Chat completion (non-streaming and SSE streaming)
14//! - Native tool use (function calling)
15//! - Embeddings (`text-embedding-*` family)
16//! - Reasoning effort for `o*` models (`low` / `medium` / `high`)
17//! - Vision via base64-encoded images in message content
18//!
19//! # Configuration
20//!
21//! ```toml
22//! [[llm.providers]]
23//! name = "openai"
24//! type = "openai"
25//! model = "gpt-4o"
26//! max_tokens = 4096
27//! embedding_model = "text-embedding-3-small"
28//! api_key_vault = "ZEPH_OPENAI_API_KEY"
29//! ```
30
31use std::fmt;
32
33use crate::error::LlmError;
34use crate::tool_desc::build_tool_description;
35use base64::{Engine, engine::general_purpose::STANDARD};
36use serde::{Deserialize, Serialize};
37
/// Configuration for [`OpenAiProvider`].
///
/// Pass to [`OpenAiProvider::new`] instead of individual positional arguments to avoid
/// silent parameter transposition.
///
/// The [`Debug`](fmt::Debug) representation redacts `api_key`, so a config value can be
/// logged or inspected without leaking the secret.
///
/// # Examples
///
/// ```
/// use zeph_llm::openai::{OpenAiConfig, OpenAiProvider};
///
/// let cfg = OpenAiConfig {
///     api_key: "sk-...".into(),
///     base_url: "https://api.openai.com/v1".into(),
///     model: "gpt-4o".into(),
///     max_tokens: 4096,
///     embedding_model: Some("text-embedding-3-small".into()),
///     reasoning_effort: None,
/// };
/// let provider = OpenAiProvider::new(cfg);
/// ```
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret API key sent in the `Authorization: Bearer` header.
    pub api_key: String,
    /// Base URL of the endpoint, e.g. `"https://api.openai.com/v1"`.
    /// Trailing slashes are stripped automatically.
    pub base_url: String,
    /// Chat model identifier, e.g. `"gpt-4o"`.
    pub model: String,
    /// Upper bound on completion tokens returned by the model.
    pub max_tokens: u32,
    /// Embedding model identifier. Set to `None` when the endpoint does not support embeddings.
    pub embedding_model: Option<String>,
    /// Reasoning effort level for `o*` models (`"low"`, `"medium"`, or `"high"`).
    /// Leave `None` for standard chat models.
    pub reasoning_effort: Option<String>,
}

impl fmt::Debug for OpenAiConfig {
    /// Formats every field except `api_key`, which is always printed as `<redacted>`.
    ///
    /// A derived `Debug` would print the secret verbatim into logs and panic messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAiConfig")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("embedding_model", &self.embedding_model)
            .field("reasoning_effort", &self.reasoning_effort)
            .finish()
    }
}
75
76use crate::provider::{
77    ChatExtras, ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, MessagePart,
78    Role, StatusTx, ToolDefinition, ToolUseRequest,
79};
80use crate::retry::send_with_retry;
81use crate::sse::openai_sse_to_stream;
82use crate::usage::UsageTracker;
83
84const MAX_RETRIES: u32 = 3;
85
86/// [`LlmProvider`] backend for the `OpenAI` API (and compatible endpoints).
87///
88/// For `OpenAI`-compatible third-party services, prefer [`crate::compatible::CompatibleProvider`]
89/// which wraps this type with a named provider for logging.
90///
91/// Construct with [`OpenAiProvider::new`] and chain optional builder methods:
92/// - [`with_generation_overrides`](Self::with_generation_overrides)
93/// - [`with_status_tx`](Self::with_status_tx)
94pub struct OpenAiProvider {
95    client: reqwest::Client,
96    api_key: String,
97    base_url: String,
98    model: String,
99    max_tokens: u32,
100    embedding_model: Option<String>,
101    /// Reasoning effort level for `o*` models (`"low"`, `"medium"`, or `"high"`).
102    reasoning_effort: Option<String>,
103    pub(crate) status_tx: Option<StatusTx>,
104    usage: UsageTracker,
105    generation_overrides: Option<GenerationOverrides>,
106    /// When `true`, append a compact JSON hint of the tool's output schema to its description.
107    forward_output_schema: bool,
108    /// Maximum bytes of the compact JSON appended as the output schema hint.
109    output_schema_hint_bytes: usize,
110    /// Maximum bytes of the combined description (base + hint). `usize::MAX` means no cap.
111    max_tool_description_bytes: usize,
112}
113
114impl fmt::Debug for OpenAiProvider {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("OpenAiProvider")
117            .field("client", &"<reqwest::Client>")
118            .field("api_key", &"<redacted>")
119            .field("base_url", &self.base_url)
120            .field("model", &self.model)
121            .field("max_tokens", &self.max_tokens)
122            .field("embedding_model", &self.embedding_model)
123            .field("reasoning_effort", &self.reasoning_effort)
124            .field("status_tx", &self.status_tx.is_some())
125            .field("usage", &self.usage)
126            .field("generation_overrides", &self.generation_overrides)
127            .field("forward_output_schema", &self.forward_output_schema)
128            .field("output_schema_hint_bytes", &self.output_schema_hint_bytes)
129            .field(
130                "max_tool_description_bytes",
131                &self.max_tool_description_bytes,
132            )
133            .finish()
134    }
135}
136
137impl Clone for OpenAiProvider {
138    fn clone(&self) -> Self {
139        Self {
140            client: self.client.clone(),
141            api_key: self.api_key.clone(),
142            base_url: self.base_url.clone(),
143            model: self.model.clone(),
144            max_tokens: self.max_tokens,
145            embedding_model: self.embedding_model.clone(),
146            reasoning_effort: self.reasoning_effort.clone(),
147            status_tx: self.status_tx.clone(),
148            usage: UsageTracker::default(),
149            generation_overrides: self.generation_overrides.clone(),
150            forward_output_schema: self.forward_output_schema,
151            output_schema_hint_bytes: self.output_schema_hint_bytes,
152            max_tool_description_bytes: self.max_tool_description_bytes,
153        }
154    }
155}
156
157impl OpenAiProvider {
158    /// Create a new provider from an [`OpenAiConfig`].
159    #[must_use]
160    pub fn new(cfg: OpenAiConfig) -> Self {
161        let mut base_url = cfg.base_url;
162        while base_url.ends_with('/') {
163            base_url.pop();
164        }
165        Self {
166            client: crate::http::llm_client(600),
167            api_key: cfg.api_key,
168            base_url,
169            model: cfg.model,
170            max_tokens: cfg.max_tokens,
171            embedding_model: cfg.embedding_model,
172            reasoning_effort: cfg.reasoning_effort,
173            status_tx: None,
174            usage: UsageTracker::default(),
175            generation_overrides: None,
176            forward_output_schema: false,
177            output_schema_hint_bytes: 1024,
178            max_tool_description_bytes: usize::MAX,
179        }
180    }
181
182    /// Override generation parameters (temperature, top-p, frequency/presence penalty).
183    #[must_use]
184    pub fn with_generation_overrides(mut self, overrides: GenerationOverrides) -> Self {
185        self.generation_overrides = Some(overrides);
186        self
187    }
188
189    /// Enable forwarding of MCP tool output schemas as a description hint.
190    ///
191    /// `max_description_bytes` caps the combined `base + hint` string. Pass `usize::MAX` for no cap.
192    #[must_use]
193    pub fn with_output_schema_forwarding(
194        mut self,
195        enabled: bool,
196        hint_bytes: usize,
197        max_description_bytes: usize,
198    ) -> Self {
199        self.forward_output_schema = enabled;
200        self.output_schema_hint_bytes = hint_bytes;
201        self.max_tool_description_bytes = max_description_bytes;
202        self
203    }
204
205    /// Replace the underlying HTTP client. Mainly used in tests to inject a mock transport.
206    #[must_use]
207    pub fn with_client(mut self, client: reqwest::Client) -> Self {
208        self.client = client;
209        self
210    }
211
212    /// Attach a status event sender so the UI receives retry and fallback notifications.
213    #[must_use]
214    pub fn with_status_tx(mut self, tx: StatusTx) -> Self {
215        self.status_tx = Some(tx);
216        self
217    }
218
219    /// Derive a filesystem-safe cache slug from the provider's base URL hostname.
220    ///
221    /// Only ASCII alphanumeric characters and underscores are kept to prevent
222    /// path traversal via unusual base URLs.
223    #[must_use]
224    pub fn cache_slug(&self) -> String {
225        let host = self
226            .base_url
227            .trim_start_matches("https://")
228            .trim_start_matches("http://")
229            .split('/')
230            .next()
231            .unwrap_or("openai")
232            .split(':')
233            .next()
234            .unwrap_or("openai");
235        let slug: String = host
236            .chars()
237            .map(|c| if c == '.' || c == '-' { '_' } else { c })
238            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
239            .collect();
240        if slug.is_empty() {
241            "openai".to_string()
242        } else {
243            slug
244        }
245    }
246
247    /// Fetch the list of available models from GET `{base_url}/models` and cache them.
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if the API request fails.
252    pub async fn list_models_remote(
253        &self,
254    ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
255        let url = format!("{}/models", self.base_url);
256        let resp = self
257            .client
258            .get(&url)
259            .bearer_auth(&self.api_key)
260            .send()
261            .await?;
262
263        if !resp.status().is_success() {
264            let status = resp.status();
265            let body = resp.text().await.unwrap_or_default();
266            tracing::debug!(status = %status, body = %body, "OpenAI list_models_remote error body");
267            return Err(LlmError::ApiError {
268                provider: "openai".into(),
269                status: status.as_u16(),
270            });
271        }
272
273        let page: serde_json::Value = resp.json().await?;
274        let models: Vec<crate::model_cache::RemoteModelInfo> = page
275            .get("data")
276            .and_then(|v| v.as_array())
277            .map(|arr| {
278                arr.iter()
279                    .filter_map(|item| {
280                        let id = item.get("id")?.as_str()?.to_string();
281                        let created_at = item.get("created").and_then(serde_json::Value::as_i64);
282                        Some(crate::model_cache::RemoteModelInfo {
283                            display_name: id.clone(),
284                            id,
285                            context_window: None,
286                            created_at,
287                        })
288                    })
289                    .collect()
290            })
291            .unwrap_or_default();
292
293        let slug = self.cache_slug();
294        let cache = crate::model_cache::ModelCache::for_slug(&slug);
295        cache.save(&models)?;
296        Ok(models)
297    }
298
299    fn store_cache_usage(&self, usage: &OpenAiUsage) {
300        self.usage
301            .record_usage(usage.prompt_tokens, usage.completion_tokens);
302        let cached = usage
303            .prompt_tokens_details
304            .as_ref()
305            .map_or(0, |d| d.cached_tokens);
306        if cached > 0 {
307            self.usage.record_cache(0, cached);
308        }
309        let reasoning = usage
310            .completion_tokens_details
311            .as_ref()
312            .map_or(0, |d| d.reasoning_tokens);
313        if reasoning > 0 {
314            self.usage.record_reasoning(reasoning);
315        }
316        tracing::debug!(
317            prompt_tokens = usage.prompt_tokens,
318            cached_tokens = cached,
319            completion_tokens = usage.completion_tokens,
320            reasoning_tokens = reasoning,
321            "OpenAI API usage"
322        );
323    }
324
325    async fn send_request(&self, messages: &[Message]) -> Result<String, LlmError> {
326        let reasoning = self
327            .reasoning_effort
328            .as_deref()
329            .map(|effort| Reasoning { effort });
330
331        let (temperature, top_p, frequency_penalty, presence_penalty) =
332            if let Some(ref ov) = self.generation_overrides {
333                (
334                    ov.temperature,
335                    ov.top_p,
336                    ov.frequency_penalty,
337                    ov.presence_penalty,
338                )
339            } else {
340                (None, None, None, None)
341            };
342
343        let response = if has_image_parts(messages) {
344            let vision_messages = convert_messages_vision(messages);
345            let body = VisionChatRequest {
346                model: &self.model,
347                messages: vision_messages,
348                completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
349                stream: false,
350                reasoning,
351                temperature,
352                top_p,
353                frequency_penalty,
354                presence_penalty,
355            };
356            send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
357                self.openai_post(format!("{}/chat/completions", self.base_url))
358                    .json(&body)
359                    .send()
360            })
361            .await?
362        } else {
363            let api_messages = convert_messages(messages);
364            let body = ChatRequest {
365                model: &self.model,
366                messages: &api_messages,
367                completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
368                stream: false,
369                reasoning,
370                temperature,
371                top_p,
372                frequency_penalty,
373                presence_penalty,
374            };
375            send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
376                self.openai_post(format!("{}/chat/completions", self.base_url))
377                    .json(&body)
378                    .send()
379            })
380            .await?
381        };
382
383        let status = response.status();
384        let text = response.text().await.map_err(LlmError::Http)?;
385
386        if !status.is_success() {
387            tracing::error!("OpenAI API error {status}: {text}");
388            return Err(crate::http::map_error_response(status, &text, "openai"));
389        }
390
391        let resp: OpenAiChatResponse = serde_json::from_str(&text)?;
392
393        if let Some(ref usage) = resp.usage {
394            self.store_cache_usage(usage);
395        }
396
397        resp.choices
398            .first()
399            .map(|c| c.message.content.clone())
400            .ok_or(LlmError::EmptyResponse {
401                provider: "openai".into(),
402            })
403    }
404
405    async fn send_stream_request(
406        &self,
407        messages: &[Message],
408    ) -> Result<reqwest::Response, LlmError> {
409        let api_messages = convert_messages(messages);
410        let reasoning = self
411            .reasoning_effort
412            .as_deref()
413            .map(|effort| Reasoning { effort });
414
415        let (temperature, top_p, frequency_penalty, presence_penalty) =
416            if let Some(ref ov) = self.generation_overrides {
417                (
418                    ov.temperature,
419                    ov.top_p,
420                    ov.frequency_penalty,
421                    ov.presence_penalty,
422                )
423            } else {
424                (None, None, None, None)
425            };
426
427        let body = ChatRequest {
428            model: &self.model,
429            messages: &api_messages,
430            completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
431            stream: true,
432            reasoning,
433            temperature,
434            top_p,
435            frequency_penalty,
436            presence_penalty,
437        };
438
439        let response = send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
440            self.openai_post(format!("{}/chat/completions", self.base_url))
441                .json(&body)
442                .send()
443        })
444        .await?;
445
446        let status = response.status();
447
448        if !status.is_success() {
449            let text = response.text().await.map_err(LlmError::Http)?;
450            tracing::error!("OpenAI API streaming request error {status}: {text}");
451            return Err(crate::http::map_error_response(status, &text, "openai"));
452        }
453
454        Ok(response)
455    }
456
457    /// Builds an authenticated POST request with `Content-Type: application/json`.
458    fn openai_post(&self, url: String) -> reqwest::RequestBuilder {
459        self.client
460            .post(url)
461            .bearer_auth(&self.api_key)
462            .header("Content-Type", "application/json")
463    }
464}
465
466impl LlmProvider for OpenAiProvider {
467    fn context_window(&self) -> Option<usize> {
468        if self.model.starts_with("gpt-4o") || self.model.starts_with("gpt-4") {
469            Some(128_000)
470        } else if self.model.starts_with("gpt-3.5") {
471            Some(16_385)
472        } else if self.model.starts_with("gpt-5") {
473            Some(1_000_000)
474        } else if starts_with_o_digit(&self.model) {
475            Some(200_000)
476        } else {
477            None
478        }
479    }
480
481    #[cfg_attr(
482        feature = "profiling",
483        tracing::instrument(
484            name = "llm.chat",
485            skip_all,
486            fields(provider = self.name(), model = self.model_identifier())
487        )
488    )]
489    async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
490        self.send_request(messages).await
491    }
492
493    async fn chat_with_extras(
494        &self,
495        messages: &[Message],
496    ) -> Result<(String, ChatExtras), LlmError> {
497        Ok((self.send_request(messages).await?, ChatExtras::default()))
498    }
499
500    #[cfg_attr(
501        feature = "profiling",
502        tracing::instrument(
503            name = "llm.chat_stream",
504            skip_all,
505            fields(provider = self.name(), model = self.model_identifier())
506        )
507    )]
508    async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
509        let response = self.send_stream_request(messages).await?;
510        Ok(openai_sse_to_stream(response))
511    }
512
513    fn supports_streaming(&self) -> bool {
514        true
515    }
516
517    #[cfg_attr(
518        feature = "profiling",
519        tracing::instrument(
520            name = "llm.embed",
521            skip_all,
522            fields(provider = self.name(), model = self.model_identifier())
523        )
524    )]
525    async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
526        use crate::embed::truncate_for_embed;
527
528        let model = self
529            .embedding_model
530            .as_deref()
531            .ok_or(LlmError::EmbedUnsupported {
532                provider: "openai".into(),
533            })?;
534
535        let text = truncate_for_embed(text);
536        let body = EmbeddingRequest {
537            input: &text,
538            model,
539        };
540
541        let response = self
542            .openai_post(format!("{}/embeddings", self.base_url))
543            .json(&body)
544            .send()
545            .await?;
546
547        let status = response.status();
548        let body_text = response.text().await.map_err(LlmError::Http)?;
549
550        if !status.is_success() {
551            tracing::error!("OpenAI embedding API error {status}: {body_text}");
552            if status == reqwest::StatusCode::BAD_REQUEST {
553                return Err(LlmError::InvalidInput {
554                    provider: "openai".into(),
555                    message: body_text,
556                });
557            }
558            return Err(LlmError::ApiError {
559                provider: "openai".into(),
560                status: status.as_u16(),
561            });
562        }
563
564        let resp: EmbeddingResponse = serde_json::from_str(&body_text)?;
565
566        resp.data
567            .first()
568            .map(|d| d.embedding.clone())
569            .ok_or(LlmError::EmptyResponse {
570                provider: "openai".into(),
571            })
572    }
573
574    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, LlmError> {
575        use crate::embed::truncate_for_embed;
576
577        if texts.is_empty() {
578            return Ok(Vec::new());
579        }
580
581        let model = self
582            .embedding_model
583            .as_deref()
584            .ok_or(LlmError::EmbedUnsupported {
585                provider: "openai".into(),
586            })?;
587
588        let truncated: Vec<std::borrow::Cow<'_, str>> =
589            texts.iter().map(|t| truncate_for_embed(t)).collect();
590        let refs: Vec<&str> = truncated.iter().map(std::convert::AsRef::as_ref).collect();
591
592        let body = EmbeddingBatchRequest { model, input: refs };
593
594        let response = self
595            .openai_post(format!("{}/embeddings", self.base_url))
596            .json(&body)
597            .send()
598            .await?;
599
600        let status = response.status();
601        let body_text = response.text().await.map_err(LlmError::Http)?;
602
603        if !status.is_success() {
604            tracing::error!("OpenAI batch embedding API error {status}: {body_text}");
605            if status == reqwest::StatusCode::BAD_REQUEST {
606                return Err(LlmError::InvalidInput {
607                    provider: "openai".into(),
608                    message: body_text,
609                });
610            }
611            return Err(LlmError::ApiError {
612                provider: "openai".into(),
613                status: status.as_u16(),
614            });
615        }
616
617        let resp: EmbeddingResponse = serde_json::from_str(&body_text)?;
618
619        if resp.data.len() != texts.len() {
620            return Err(LlmError::Other(format!(
621                "OpenAI returned {} embeddings for {} inputs",
622                resp.data.len(),
623                texts.len()
624            )));
625        }
626
627        // Sort by index to guarantee order even if the API ever returns out of order.
628        let mut data = resp.data;
629        data.sort_unstable_by_key(|d| d.index);
630
631        Ok(data.into_iter().map(|d| d.embedding).collect())
632    }
633
634    fn supports_embeddings(&self) -> bool {
635        self.embedding_model.is_some()
636    }
637
638    #[allow(clippy::unnecessary_literal_bound)]
639    fn name(&self) -> &str {
640        "openai"
641    }
642
643    fn model_identifier(&self) -> &str {
644        &self.model
645    }
646
647    fn list_models(&self) -> Vec<String> {
648        vec![self.model.clone()]
649    }
650
651    fn last_cache_usage(&self) -> Option<(u64, u64)> {
652        self.usage.last_cache_usage()
653    }
654
655    fn last_usage(&self) -> Option<(u64, u64)> {
656        self.usage.last_usage()
657    }
658
659    fn last_reasoning_tokens(&self) -> Option<u64> {
660        self.usage.last_reasoning()
661    }
662
663    fn debug_request_json(
664        &self,
665        messages: &[Message],
666        tools: &[ToolDefinition],
667        stream: bool,
668    ) -> serde_json::Value {
669        let reasoning = self
670            .reasoning_effort
671            .as_deref()
672            .map(|effort| Reasoning { effort });
673        let (temperature, top_p, frequency_penalty, presence_penalty) = self
674            .generation_overrides
675            .as_ref()
676            .map(|ov| {
677                (
678                    ov.temperature,
679                    ov.top_p,
680                    ov.frequency_penalty,
681                    ov.presence_penalty,
682                )
683            })
684            .unwrap_or_default();
685
686        if !tools.is_empty() {
687            let api_messages = convert_messages_structured(messages);
688            let descriptions: Vec<String> = tools
689                .iter()
690                .map(|t| {
691                    build_tool_description(
692                        &t.description,
693                        t.output_schema.as_ref(),
694                        self.forward_output_schema,
695                        self.output_schema_hint_bytes,
696                        self.max_tool_description_bytes,
697                        t.name.as_str(),
698                    )
699                })
700                .collect();
701            let api_tools: Vec<OpenAiTool<'_>> = tools
702                .iter()
703                .zip(descriptions.iter())
704                .map(|(t, desc)| OpenAiTool {
705                    r#type: "function",
706                    function: OpenAiFunction {
707                        name: t.name.as_str(),
708                        description: desc.as_str(),
709                        parameters: prepare_tool_params(&t.parameters),
710                    },
711                })
712                .collect();
713            let body = ToolChatRequest {
714                model: &self.model,
715                messages: &api_messages,
716                completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
717                tools: &api_tools,
718                reasoning,
719                temperature,
720                top_p,
721                frequency_penalty,
722                presence_penalty,
723            };
724            return serde_json::to_value(&body)
725                .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }));
726        }
727
728        if has_image_parts(messages) {
729            let vision_messages = convert_messages_vision(messages);
730            let body = VisionChatRequest {
731                model: &self.model,
732                messages: vision_messages,
733                completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
734                stream,
735                reasoning,
736                temperature,
737                top_p,
738                frequency_penalty,
739                presence_penalty,
740            };
741            return serde_json::to_value(&body)
742                .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }));
743        }
744
745        let api_messages = convert_messages(messages);
746        let body = ChatRequest {
747            model: &self.model,
748            messages: &api_messages,
749            completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
750            stream,
751            reasoning,
752            temperature,
753            top_p,
754            frequency_penalty,
755            presence_penalty,
756        };
757        serde_json::to_value(&body)
758            .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }))
759    }
760
761    fn supports_structured_output(&self) -> bool {
762        true
763    }
764
765    async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
766    where
767        T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
768        Self: Sized,
769    {
770        let (raw_schema, _) = crate::provider::cached_schema::<T>()?;
771        let mut schema_value = raw_schema;
772        inline_refs_openai(&mut schema_value, 8);
773        normalize_for_openai_strict(&mut schema_value, 16);
774        let type_name = crate::provider::short_type_name::<T>();
775
776        let api_messages = convert_messages(messages);
777        let body = TypedChatRequest {
778            model: &self.model,
779            messages: &api_messages,
780            completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
781            response_format: ResponseFormat {
782                r#type: "json_schema",
783                json_schema: JsonSchemaFormat {
784                    name: type_name,
785                    schema: schema_value,
786                    strict: true,
787                },
788            },
789        };
790
791        let response = self
792            .openai_post(format!("{}/chat/completions", self.base_url))
793            .json(&body)
794            .send()
795            .await?;
796
797        let status = response.status();
798        let text = response.text().await.map_err(LlmError::Http)?;
799
800        if !status.is_success() {
801            return Err(crate::http::map_error_response(status, &text, "openai"));
802        }
803
804        let resp: OpenAiChatResponse = serde_json::from_str(&text)?;
805
806        if let Some(ref usage) = resp.usage {
807            self.store_cache_usage(usage);
808        }
809
810        let content = resp
811            .choices
812            .first()
813            .map(|c| c.message.content.as_str())
814            .ok_or(LlmError::EmptyResponse {
815                provider: "openai".into(),
816            })?;
817
818        serde_json::from_str::<T>(content).map_err(|e| LlmError::StructuredParse(e.to_string()))
819    }
820
821    fn supports_vision(&self) -> bool {
822        true
823    }
824
825    #[cfg_attr(
826        feature = "profiling",
827        tracing::instrument(
828            name = "llm.chat_with_tools",
829            skip_all,
830            fields(provider = self.name(), model = self.model_identifier(), tool_count = tools.len())
831        )
832    )]
833    async fn chat_with_tools(
834        &self,
835        messages: &[Message],
836        tools: &[ToolDefinition],
837    ) -> Result<ChatResponse, LlmError> {
838        let api_messages = convert_messages_structured(messages);
839        let reasoning = self
840            .reasoning_effort
841            .as_deref()
842            .map(|effort| Reasoning { effort });
843
844        let descriptions: Vec<String> = tools
845            .iter()
846            .map(|t| {
847                build_tool_description(
848                    &t.description,
849                    t.output_schema.as_ref(),
850                    self.forward_output_schema,
851                    self.output_schema_hint_bytes,
852                    self.max_tool_description_bytes,
853                    t.name.as_str(),
854                )
855            })
856            .collect();
857        let api_tools: Vec<OpenAiTool> = tools
858            .iter()
859            .zip(descriptions.iter())
860            .map(|(t, desc)| OpenAiTool {
861                r#type: "function",
862                function: OpenAiFunction {
863                    name: t.name.as_str(),
864                    description: desc.as_str(),
865                    parameters: prepare_tool_params(&t.parameters),
866                },
867            })
868            .collect();
869
870        let (temperature, top_p, frequency_penalty, presence_penalty) = self
871            .generation_overrides
872            .as_ref()
873            .map(|ov| {
874                (
875                    ov.temperature,
876                    ov.top_p,
877                    ov.frequency_penalty,
878                    ov.presence_penalty,
879                )
880            })
881            .unwrap_or_default();
882        let body = ToolChatRequest {
883            model: &self.model,
884            messages: &api_messages,
885            completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
886            tools: &api_tools,
887            reasoning,
888            temperature,
889            top_p,
890            frequency_penalty,
891            presence_penalty,
892        };
893
894        let response = self
895            .openai_post(format!("{}/chat/completions", self.base_url))
896            .json(&body)
897            .send()
898            .await?;
899
900        let status = response.status();
901        let text = response.text().await.map_err(LlmError::Http)?;
902
903        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
904            return Err(LlmError::RateLimited);
905        }
906
907        if status == reqwest::StatusCode::BAD_REQUEST {
908            tracing::warn!("OpenAI tool chat 400 bad request: {text}");
909            if crate::error::body_is_context_length_error(&text) {
910                return Err(LlmError::ContextLengthExceeded);
911            }
912            return Err(LlmError::InvalidInput {
913                provider: self.name().to_owned(),
914                message: text,
915            });
916        }
917
918        if !status.is_success() {
919            tracing::error!("OpenAI API error {status}: {text}");
920            return Err(LlmError::ApiError {
921                provider: "openai".into(),
922                status: status.as_u16(),
923            });
924        }
925
926        self.decode_tool_chat_response(&text, "openai")
927    }
928}
929
930impl OpenAiProvider {
931    /// Decode a raw tool-chat JSON response body into a [`ChatResponse`].
932    ///
933    /// Records usage via `store_cache_usage`.  Pass `provider_name` so that
934    /// `EmptyResponse` errors carry the correct provider label.
935    pub(crate) fn decode_tool_chat_response(
936        &self,
937        text: &str,
938        provider_name: &str,
939    ) -> Result<ChatResponse, LlmError> {
940        let resp: ToolChatResponse = serde_json::from_str(text)?;
941
942        if let Some(ref usage) = resp.usage {
943            self.store_cache_usage(usage);
944        }
945
946        let choice = resp
947            .choices
948            .into_iter()
949            .next()
950            .ok_or(LlmError::EmptyResponse {
951                provider: provider_name.into(),
952            })?;
953
954        if let Some(tool_calls) = choice.message.tool_calls
955            && !tool_calls.is_empty()
956        {
957            let text = if choice.message.content.is_empty() {
958                None
959            } else {
960                Some(choice.message.content)
961            };
962            let calls = tool_calls
963                .into_iter()
964                .map(|tc| {
965                    let input = serde_json::from_str(&tc.function.arguments)
966                        .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
967                    ToolUseRequest {
968                        id: tc.id,
969                        name: tc.function.name.into(),
970                        input,
971                    }
972                })
973                .collect();
974            return Ok(ChatResponse::ToolUse {
975                text,
976                tool_calls: calls,
977                thinking_blocks: vec![],
978            });
979        }
980
981        // Inject truncation marker when finish_reason is "length" so the agent loop
982        // can detect MaxTokens stop reason without touching ChatResponse structure.
983        let content = if choice.finish_reason.as_deref() == Some("length") {
984            let truncation_marker = crate::provider::MAX_TOKENS_TRUNCATION_MARKER;
985            if choice.message.content.is_empty() {
986                format!(
987                    "[Response truncated: {truncation_marker}. Please reduce the request scope.]"
988                )
989            } else {
990                format!(
991                    "{}\n[Response truncated: {truncation_marker}.]",
992                    choice.message.content
993                )
994            }
995        } else {
996            choice.message.content
997        };
998        Ok(ChatResponse::Text(content))
999    }
1000
1001    /// Build a serialized `TypedChatRequest` body for `chat_typed`.
1002    ///
1003    /// Extracts and normalises the JSON Schema for `T`, wraps it in
1004    /// `response_format: json_schema`, and returns the raw bytes ready for an
1005    /// HTTP POST body.
1006    ///
1007    /// # Errors
1008    ///
1009    /// Returns [`LlmError::StructuredParse`] if schema extraction or serialisation fails.
1010    #[cfg(any(feature = "gonka", feature = "cocoon"))]
1011    pub(crate) fn build_typed_chat_body<T>(&self, messages: &[Message]) -> Result<Vec<u8>, LlmError>
1012    where
1013        T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
1014    {
1015        let (raw_schema, _) = crate::provider::cached_schema::<T>()?;
1016        let mut schema_value = raw_schema;
1017        inline_refs_openai(&mut schema_value, 8);
1018        normalize_for_openai_strict(&mut schema_value, 16);
1019        let type_name = crate::provider::short_type_name::<T>();
1020
1021        let api_messages = convert_messages(messages);
1022        let body = TypedChatRequest {
1023            model: &self.model,
1024            messages: &api_messages,
1025            completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
1026            response_format: ResponseFormat {
1027                r#type: "json_schema",
1028                json_schema: JsonSchemaFormat {
1029                    name: type_name,
1030                    schema: schema_value,
1031                    strict: true,
1032                },
1033            },
1034        };
1035
1036        serde_json::to_vec(&body).map_err(|e| LlmError::StructuredParse(e.to_string()))
1037    }
1038}
1039
1040#[derive(Serialize)]
1041#[serde(tag = "type", rename_all = "snake_case")]
1042enum OpenAiContentPart {
1043    Text { text: String },
1044    ImageUrl { image_url: ImageUrlDetail },
1045}
1046
1047#[derive(Serialize)]
1048struct ImageUrlDetail {
1049    url: String,
1050}
1051
1052#[derive(Serialize)]
1053struct VisionApiMessage {
1054    role: String,
1055    content: Vec<OpenAiContentPart>,
1056}
1057
1058#[derive(Serialize)]
1059struct VisionChatRequest<'a> {
1060    model: &'a str,
1061    messages: Vec<VisionApiMessage>,
1062    #[serde(flatten)]
1063    completion_tokens: CompletionTokens,
1064    #[serde(skip_serializing_if = "std::ops::Not::not")]
1065    stream: bool,
1066    #[serde(skip_serializing_if = "Option::is_none")]
1067    reasoning: Option<Reasoning<'a>>,
1068    #[serde(skip_serializing_if = "Option::is_none")]
1069    temperature: Option<f64>,
1070    #[serde(skip_serializing_if = "Option::is_none")]
1071    top_p: Option<f64>,
1072    #[serde(skip_serializing_if = "Option::is_none")]
1073    frequency_penalty: Option<f64>,
1074    #[serde(skip_serializing_if = "Option::is_none")]
1075    presence_penalty: Option<f64>,
1076}
1077
1078fn has_image_parts(messages: &[Message]) -> bool {
1079    messages
1080        .iter()
1081        .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_))))
1082}
1083
1084fn convert_messages_vision(messages: &[Message]) -> Vec<VisionApiMessage> {
1085    messages
1086        .iter()
1087        .map(|msg| {
1088            let role = match msg.role {
1089                Role::System => "system",
1090                Role::User => "user",
1091                Role::Assistant => "assistant",
1092            };
1093            let has_images = msg.parts.iter().any(|p| matches!(p, MessagePart::Image(_)));
1094            if has_images {
1095                let mut parts = Vec::new();
1096                let text_str: String = msg
1097                    .parts
1098                    .iter()
1099                    .filter_map(MessagePart::as_plain_text)
1100                    .collect::<Vec<_>>()
1101                    .join("");
1102                if !text_str.is_empty() {
1103                    parts.push(OpenAiContentPart::Text { text: text_str });
1104                }
1105                for part in &msg.parts {
1106                    if let Some(img) = part.as_image() {
1107                        let b64 = STANDARD.encode(&img.data);
1108                        parts.push(OpenAiContentPart::ImageUrl {
1109                            image_url: ImageUrlDetail {
1110                                url: format!("data:{};base64,{b64}", img.mime_type),
1111                            },
1112                        });
1113                    }
1114                }
1115                if parts.is_empty() {
1116                    parts.push(OpenAiContentPart::Text {
1117                        text: msg.to_llm_content().to_owned(),
1118                    });
1119                }
1120                VisionApiMessage {
1121                    role: role.to_owned(),
1122                    content: parts,
1123                }
1124            } else {
1125                VisionApiMessage {
1126                    role: role.to_owned(),
1127                    content: vec![OpenAiContentPart::Text {
1128                        text: msg.to_llm_content().to_owned(),
1129                    }],
1130                }
1131            }
1132        })
1133        .collect()
1134}
1135
1136fn convert_messages(messages: &[Message]) -> Vec<ApiMessage<'_>> {
1137    messages
1138        .iter()
1139        .map(|msg| {
1140            let role = match msg.role {
1141                Role::System => "system",
1142                Role::User => "user",
1143                Role::Assistant => "assistant",
1144            };
1145            ApiMessage {
1146                role,
1147                content: msg.to_llm_content(),
1148            }
1149        })
1150        .collect()
1151}
1152
1153#[derive(Serialize)]
1154struct ChatRequest<'a> {
1155    model: &'a str,
1156    messages: &'a [ApiMessage<'a>],
1157    #[serde(flatten)]
1158    completion_tokens: CompletionTokens,
1159    #[serde(skip_serializing_if = "std::ops::Not::not")]
1160    stream: bool,
1161    #[serde(skip_serializing_if = "Option::is_none")]
1162    reasoning: Option<Reasoning<'a>>,
1163    #[serde(skip_serializing_if = "Option::is_none")]
1164    temperature: Option<f64>,
1165    #[serde(skip_serializing_if = "Option::is_none")]
1166    top_p: Option<f64>,
1167    #[serde(skip_serializing_if = "Option::is_none")]
1168    frequency_penalty: Option<f64>,
1169    #[serde(skip_serializing_if = "Option::is_none")]
1170    presence_penalty: Option<f64>,
1171}
1172
1173#[derive(Serialize)]
1174struct Reasoning<'a> {
1175    effort: &'a str,
1176}
1177
1178#[derive(Serialize)]
1179struct ApiMessage<'a> {
1180    role: &'a str,
1181    content: &'a str,
1182}
1183
1184#[derive(Deserialize)]
1185struct OpenAiChatResponse {
1186    choices: Vec<ChatChoice>,
1187    #[serde(default)]
1188    usage: Option<OpenAiUsage>,
1189}
1190
1191#[derive(Deserialize)]
1192struct OpenAiUsage {
1193    #[serde(default)]
1194    prompt_tokens: u64,
1195    #[serde(default)]
1196    completion_tokens: u64,
1197    #[serde(default)]
1198    prompt_tokens_details: Option<PromptTokensDetails>,
1199    #[serde(default)]
1200    completion_tokens_details: Option<CompletionTokensDetails>,
1201}
1202
1203#[derive(Deserialize)]
1204struct PromptTokensDetails {
1205    #[serde(default)]
1206    cached_tokens: u64,
1207}
1208
1209#[derive(Deserialize)]
1210struct CompletionTokensDetails {
1211    /// Reasoning tokens are a subset of `completion_tokens`; do not add to cost.
1212    #[serde(default)]
1213    reasoning_tokens: u64,
1214}
1215
1216#[derive(Deserialize)]
1217struct ChatChoice {
1218    message: ChatMessage,
1219}
1220
1221#[derive(Deserialize)]
1222struct ChatMessage {
1223    content: String,
1224}
1225
1226#[derive(Serialize)]
1227struct OpenAiTool<'a> {
1228    r#type: &'a str,
1229    function: OpenAiFunction<'a>,
1230}
1231
1232#[derive(Serialize)]
1233struct OpenAiFunction<'a> {
1234    name: &'a str,
1235    description: &'a str,
1236    #[serde(skip_serializing_if = "Option::is_none")]
1237    parameters: Option<serde_json::Value>,
1238}
1239
1240#[derive(Serialize)]
1241struct ToolChatRequest<'a> {
1242    model: &'a str,
1243    messages: &'a [StructuredApiMessage],
1244    #[serde(flatten)]
1245    completion_tokens: CompletionTokens,
1246    tools: &'a [OpenAiTool<'a>],
1247    #[serde(skip_serializing_if = "Option::is_none")]
1248    reasoning: Option<Reasoning<'a>>,
1249    #[serde(skip_serializing_if = "Option::is_none")]
1250    temperature: Option<f64>,
1251    #[serde(skip_serializing_if = "Option::is_none")]
1252    top_p: Option<f64>,
1253    #[serde(skip_serializing_if = "Option::is_none")]
1254    frequency_penalty: Option<f64>,
1255    #[serde(skip_serializing_if = "Option::is_none")]
1256    presence_penalty: Option<f64>,
1257}
1258
1259#[derive(Serialize)]
1260struct StructuredApiMessage {
1261    role: String,
1262    #[serde(skip_serializing_if = "Option::is_none")]
1263    content: Option<String>,
1264    #[serde(skip_serializing_if = "Option::is_none")]
1265    tool_calls: Option<Vec<OpenAiToolCallOut>>,
1266    #[serde(skip_serializing_if = "Option::is_none")]
1267    tool_call_id: Option<String>,
1268}
1269
1270#[derive(Serialize)]
1271struct OpenAiToolCallOut {
1272    id: String,
1273    r#type: String,
1274    function: OpenAiFunctionCall,
1275}
1276
1277#[derive(Serialize)]
1278struct OpenAiFunctionCall {
1279    name: String,
1280    arguments: String,
1281}
1282
1283#[derive(Deserialize)]
1284struct ToolChatResponse {
1285    choices: Vec<ToolChatChoice>,
1286    #[serde(default)]
1287    usage: Option<OpenAiUsage>,
1288}
1289
1290#[derive(Deserialize)]
1291struct ToolChatChoice {
1292    message: ToolChatMessage,
1293    #[serde(default)]
1294    finish_reason: Option<String>,
1295}
1296
1297#[derive(Deserialize)]
1298struct ToolChatMessage {
1299    #[serde(default, deserialize_with = "deserialize_null_string_as_default")]
1300    content: String,
1301    #[serde(default)]
1302    tool_calls: Option<Vec<OpenAiToolCall>>,
1303}
1304
1305#[derive(Deserialize)]
1306struct OpenAiToolCall {
1307    id: String,
1308    function: OpenAiToolCallFunction,
1309}
1310
1311#[derive(Deserialize)]
1312struct OpenAiToolCallFunction {
1313    name: String,
1314    arguments: String,
1315}
1316
1317fn deserialize_null_string_as_default<'de, D>(deserializer: D) -> Result<String, D::Error>
1318where
1319    D: serde::Deserializer<'de>,
1320{
1321    Ok(Option::<String>::deserialize(deserializer)?.unwrap_or_default())
1322}
1323
1324fn convert_messages_structured(messages: &[Message]) -> Vec<StructuredApiMessage> {
1325    let mut result = Vec::new();
1326
1327    for msg in messages {
1328        let has_tool_parts = msg.parts.iter().any(|p| {
1329            matches!(
1330                p,
1331                MessagePart::ToolUse { .. } | MessagePart::ToolResult { .. }
1332            )
1333        });
1334
1335        if has_tool_parts {
1336            // Assistant messages with ToolUse parts → tool_calls field
1337            if msg.role == Role::Assistant {
1338                let text_content: String = msg
1339                    .parts
1340                    .iter()
1341                    .filter_map(|p| p.as_plain_text())
1342                    .collect::<Vec<_>>()
1343                    .join("");
1344
1345                let tool_calls: Vec<OpenAiToolCallOut> = msg
1346                    .parts
1347                    .iter()
1348                    .filter_map(|p| match p {
1349                        MessagePart::ToolUse { id, name, input } => Some(OpenAiToolCallOut {
1350                            id: id.clone(),
1351                            r#type: "function".to_owned(),
1352                            function: OpenAiFunctionCall {
1353                                name: name.clone(),
1354                                arguments: serde_json::to_string(input)
1355                                    .unwrap_or_else(|_| "{}".to_owned()),
1356                            },
1357                        }),
1358                        _ => None,
1359                    })
1360                    .collect();
1361
1362                result.push(StructuredApiMessage {
1363                    role: "assistant".to_owned(),
1364                    content: if text_content.is_empty() {
1365                        None
1366                    } else {
1367                        Some(text_content)
1368                    },
1369                    tool_calls: if tool_calls.is_empty() {
1370                        None
1371                    } else {
1372                        Some(tool_calls)
1373                    },
1374                    tool_call_id: None,
1375                });
1376            } else {
1377                // User messages with ToolResult parts → role: "tool" messages
1378                for part in &msg.parts {
1379                    match part {
1380                        MessagePart::ToolResult {
1381                            tool_use_id,
1382                            content,
1383                            ..
1384                        } => {
1385                            result.push(StructuredApiMessage {
1386                                role: "tool".to_owned(),
1387                                content: Some(content.clone()),
1388                                tool_calls: None,
1389                                tool_call_id: Some(tool_use_id.clone()),
1390                            });
1391                        }
1392                        other => {
1393                            if let Some(text) = other.as_plain_text().filter(|t| !t.is_empty()) {
1394                                result.push(StructuredApiMessage {
1395                                    role: "user".to_owned(),
1396                                    content: Some(text.to_owned()),
1397                                    tool_calls: None,
1398                                    tool_call_id: None,
1399                                });
1400                            }
1401                        }
1402                    }
1403                }
1404            }
1405        } else {
1406            let role = match msg.role {
1407                Role::System => "system",
1408                Role::User => "user",
1409                Role::Assistant => "assistant",
1410            };
1411            result.push(StructuredApiMessage {
1412                role: role.to_owned(),
1413                content: Some(msg.to_llm_content().to_owned()),
1414                tool_calls: None,
1415                tool_call_id: None,
1416            });
1417        }
1418    }
1419
1420    result
1421}
1422
1423#[derive(Serialize)]
1424struct EmbeddingRequest<'a> {
1425    input: &'a str,
1426    model: &'a str,
1427}
1428
1429#[derive(Deserialize)]
1430struct EmbeddingResponse {
1431    data: Vec<EmbeddingData>,
1432}
1433
1434#[derive(Deserialize)]
1435struct EmbeddingData {
1436    #[serde(default)]
1437    index: usize,
1438    embedding: Vec<f32>,
1439}
1440
1441#[derive(Serialize)]
1442struct EmbeddingBatchRequest<'a> {
1443    model: &'a str,
1444    input: Vec<&'a str>,
1445}
1446
1447#[derive(Serialize)]
1448struct TypedChatRequest<'a> {
1449    model: &'a str,
1450    messages: &'a [ApiMessage<'a>],
1451    #[serde(flatten)]
1452    completion_tokens: CompletionTokens,
1453    response_format: ResponseFormat<'a>,
1454}
1455
1456#[derive(Serialize)]
1457#[serde(untagged)]
1458enum CompletionTokens {
1459    MaxTokens { max_tokens: u32 },
1460    MaxCompletionTokens { max_completion_tokens: u32 },
1461}
1462
1463impl CompletionTokens {
1464    fn for_model(model: &str, max_tokens: u32) -> Self {
1465        // o-series models (o1, o2, o3, o4-mini, …) and gpt-5 require max_completion_tokens;
1466        // all other models use the legacy max_tokens field.
1467        if model.starts_with("gpt-5") || starts_with_o_digit(model) {
1468            Self::MaxCompletionTokens {
1469                max_completion_tokens: max_tokens,
1470            }
1471        } else {
1472            Self::MaxTokens { max_tokens }
1473        }
1474    }
1475}
1476
/// Returns `true` when `model` starts with a lowercase `o` immediately followed by an
/// ASCII digit (`o1`, `o3-mini`, `o4-mini`, …).
fn starts_with_o_digit(model: &str) -> bool {
    // A non-ASCII second character starts with a byte >= 0x80, which is never a digit,
    // so a byte-level check matches the char-level definition.
    matches!(model.as_bytes(), [b'o', second, ..] if second.is_ascii_digit())
}
1481
1482#[derive(Serialize)]
1483struct ResponseFormat<'a> {
1484    r#type: &'a str,
1485    json_schema: JsonSchemaFormat<'a>,
1486}
1487
1488#[derive(Serialize)]
1489struct JsonSchemaFormat<'a> {
1490    name: &'a str,
1491    schema: serde_json::Value,
1492    strict: bool,
1493}
1494
1495/// Inline all `$ref` references from `$defs` into the schema tree.
1496fn inline_refs_openai(schema: &mut serde_json::Value, depth: u8) {
1497    if depth == 0 {
1498        return;
1499    }
1500    let defs = if let Some(obj) = schema.as_object() {
1501        obj.get("$defs")
1502            .or_else(|| obj.get("definitions"))
1503            .cloned()
1504            .unwrap_or(serde_json::Value::Object(serde_json::Map::default()))
1505    } else {
1506        serde_json::Value::Object(serde_json::Map::default())
1507    };
1508    inline_refs_openai_inner(schema, &defs, depth);
1509    if let Some(obj) = schema.as_object_mut() {
1510        obj.remove("$defs");
1511        obj.remove("definitions");
1512    }
1513}
1514
1515fn inline_refs_openai_inner(schema: &mut serde_json::Value, defs: &serde_json::Value, depth: u8) {
1516    if depth == 0 {
1517        return;
1518    }
1519    if let Some(obj) = schema.as_object()
1520        && let Some(ref_val) = obj.get("$ref").and_then(|v| v.as_str())
1521    {
1522        let name = ref_val
1523            .trim_start_matches("#/$defs/")
1524            .trim_start_matches("#/definitions/");
1525        if let Some(resolved) = defs.get(name) {
1526            let mut resolved = resolved.clone();
1527            inline_refs_openai_inner(&mut resolved, defs, depth - 1);
1528            *schema = resolved;
1529            return;
1530        }
1531        *schema = serde_json::json!({"type": "object"});
1532        return;
1533    }
1534    if let Some(obj) = schema.as_object_mut() {
1535        for v in obj.values_mut() {
1536            inline_refs_openai_inner(v, defs, depth - 1);
1537        }
1538    } else if let Some(arr) = schema.as_array_mut() {
1539        for v in arr.iter_mut() {
1540            inline_refs_openai_inner(v, defs, depth - 1);
1541        }
1542    }
1543}
1544
1545/// Returns `true` when the schema represents an object with no parameters.
1546///
1547/// Matches `{"type": "object"}` with absent or empty `properties`.
1548fn is_empty_params_schema(schema: &serde_json::Value) -> bool {
1549    schema.get("type").and_then(|t| t.as_str()) == Some("object")
1550        && schema
1551            .get("properties")
1552            .and_then(|p| p.as_object())
1553            .is_none_or(serde_json::Map::is_empty)
1554}
1555
1556/// Prepare tool parameters schema for the `OpenAI` API.
1557///
1558/// Returns `None` for empty-parameter tools so the `parameters` field is
1559/// omitted entirely, avoiding strict-mode 400 errors.  For non-empty schemas,
1560/// inlines `$ref` definitions and normalizes for strict mode.
1561fn prepare_tool_params(params: &serde_json::Value) -> Option<serde_json::Value> {
1562    if is_empty_params_schema(params) {
1563        return None;
1564    }
1565    let mut schema = params.clone();
1566    inline_refs_openai(&mut schema, 8);
1567    normalize_for_openai_strict(&mut schema, 16);
1568    Some(schema)
1569}
1570
1571struct OpenAiStrictVisitor;
1572
1573impl crate::schema::SchemaVisitor for OpenAiStrictVisitor {
1574    fn visit(&mut self, schema: &mut serde_json::Value) -> bool {
1575        let Some(obj) = schema.as_object_mut() else {
1576            return false;
1577        };
1578        let remove_keys: &[&str] = &["$schema", "title", "format", "default", "examples", "$id"];
1579        for key in remove_keys {
1580            obj.remove(*key);
1581        }
1582        let is_object = obj.get("type").and_then(|t| t.as_str()) == Some("object");
1583        if is_object {
1584            obj.insert(
1585                "additionalProperties".to_owned(),
1586                serde_json::Value::Bool(false),
1587            );
1588            let prop_keys: Vec<String> = obj
1589                .get("properties")
1590                .and_then(|p| p.as_object())
1591                .map(|p| p.keys().cloned().collect())
1592                .unwrap_or_default();
1593            if !prop_keys.is_empty() {
1594                obj.insert(
1595                    "required".to_owned(),
1596                    serde_json::Value::Array(
1597                        prop_keys
1598                            .into_iter()
1599                            .map(serde_json::Value::String)
1600                            .collect(),
1601                    ),
1602                );
1603            }
1604        }
1605        true
1606    }
1607}
1608
1609/// Normalize a JSON Schema for `OpenAI` structured output strict mode.
1610///
1611/// Requirements:
1612/// - `additionalProperties: false` on every object
1613/// - All properties listed in `required`
1614/// - No `$schema`, `title`, or other non-strict keys at top level
1615fn normalize_for_openai_strict(schema: &mut serde_json::Value, depth: u8) {
1616    crate::schema::walk_schema(schema, &mut OpenAiStrictVisitor, depth);
1617}
1618
1619#[cfg(test)]
1620mod tests;