Skip to main content

sgr_agent/
oxide_chat_client.rs

//! OxideChatClient — LlmClient via Chat Completions API (not Responses).
//!
//! For OpenAI-compatible endpoints that don't support /responses:
//! Cloudflare AI Gateway compat, OpenRouter, local models, Workers AI.

use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
use openai_oxide::config::ClientConfig;
use openai_oxide::types::chat::*;
use openai_oxide::OpenAI;
use serde_json::Value;

/// LlmClient backed by openai-oxide Chat Completions API.
pub struct OxideChatClient {
    // Configured openai-oxide client (API key + optional custom base URL).
    client: OpenAI,
    // Model id sent with every request, e.g. "gpt-4o" or a gateway-prefixed id.
    pub(crate) model: String,
    // Sampling temperature; always Some(config.temp) when built via from_config.
    pub(crate) temperature: Option<f64>,
    // Per-request completion token cap; build_request maps it to either
    // max_tokens or max_completion_tokens depending on the model family.
    pub(crate) max_tokens: Option<u32>,
}
21
22impl OxideChatClient {
23    /// Create from LlmConfig.
24    pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
25        let api_key = config
26            .api_key
27            .clone()
28            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
29            .unwrap_or_else(|| {
30                if config.base_url.is_some() {
31                    "dummy_key".into()
32                } else {
33                    "".into()
34                }
35            });
36
37        if api_key.is_empty() {
38            return Err(SgrError::Schema("No API key for oxide chat client".into()));
39        }
40
41        let mut client_config = ClientConfig::new(&api_key);
42        if let Some(ref url) = config.base_url {
43            client_config = client_config.base_url(url.clone());
44        }
45
46        Ok(Self {
47            client: OpenAI::with_config(client_config),
48            model: config.model.clone(),
49            temperature: Some(config.temp),
50            max_tokens: config.max_tokens,
51        })
52    }
53
54    fn build_messages(&self, messages: &[Message]) -> Vec<ChatCompletionMessageParam> {
55        messages
56            .iter()
57            .map(|m| match m.role {
58                Role::System => ChatCompletionMessageParam::System {
59                    content: m.content.clone(),
60                    name: None,
61                },
62                Role::User => ChatCompletionMessageParam::User {
63                    content: UserContent::Text(m.content.clone()),
64                    name: None,
65                },
66                Role::Assistant => {
67                    let tc = if m.tool_calls.is_empty() {
68                        None
69                    } else {
70                        Some(
71                            m.tool_calls
72                                .iter()
73                                .map(|tc| openai_oxide::types::chat::ToolCall {
74                                    id: tc.id.clone(),
75                                    type_: "function".into(),
76                                    function: openai_oxide::types::chat::FunctionCall {
77                                        name: tc.name.clone(),
78                                        arguments: tc.arguments.to_string(),
79                                    },
80                                })
81                                .collect(),
82                        )
83                    };
84                    ChatCompletionMessageParam::Assistant {
85                        content: if m.content.is_empty() {
86                            None
87                        } else {
88                            Some(m.content.clone())
89                        },
90                        name: None,
91                        tool_calls: tc,
92                        refusal: None,
93                    }
94                }
95                Role::Tool => ChatCompletionMessageParam::Tool {
96                    content: m.content.clone(),
97                    tool_call_id: m.tool_call_id.clone().unwrap_or_default(),
98                },
99            })
100            .collect()
101    }
102
103    fn build_request(&self, messages: &[Message]) -> ChatCompletionRequest {
104        let mut req = ChatCompletionRequest::new(&self.model, self.build_messages(messages));
105        if let Some(temp) = self.temperature {
106            req.temperature = Some(temp);
107        }
108        if let Some(max) = self.max_tokens {
109            // Use max_completion_tokens for newer models (gpt-5.x+), max_tokens for legacy
110            if self.model.starts_with("gpt-5") || self.model.starts_with("o") {
111                req = req.max_completion_tokens(max as i64);
112            } else {
113                req.max_tokens = Some(max as i64);
114            }
115        }
116        req
117    }
118
119    fn extract_tool_calls(response: &ChatCompletionResponse) -> Vec<ToolCall> {
120        let Some(choice) = response.choices.first() else {
121            return Vec::new();
122        };
123        let Some(ref calls) = choice.message.tool_calls else {
124            return Vec::new();
125        };
126        calls
127            .iter()
128            .map(|tc| ToolCall {
129                id: tc.id.clone(),
130                name: tc.function.name.clone(),
131                arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
132            })
133            .collect()
134    }
135}
136
137#[async_trait::async_trait]
138impl LlmClient for OxideChatClient {
139    async fn structured_call(
140        &self,
141        messages: &[Message],
142        schema: &Value,
143    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
144        // Skip ensure_strict for pre-strict schemas (e.g., from build_action_schema)
145        let strict_schema =
146            if schema.get("additionalProperties").and_then(|v| v.as_bool()) == Some(false) {
147                schema.clone()
148            } else {
149                let mut s = schema.clone();
150                openai_oxide::parsing::ensure_strict(&mut s);
151                s
152            };
153
154        let mut req = self.build_request(messages);
155        req.response_format = Some(ResponseFormat::JsonSchema {
156            json_schema: JsonSchema {
157                name: "response".into(),
158                description: None,
159                schema: Some(strict_schema),
160                strict: Some(true),
161            },
162        });
163
164        let response = self
165            .client
166            .chat()
167            .completions()
168            .create(req)
169            .await
170            .map_err(|e| SgrError::Api {
171                status: 0,
172                body: e.to_string(),
173            })?;
174
175        let raw_text = response
176            .choices
177            .first()
178            .and_then(|c| c.message.content.clone())
179            .unwrap_or_default();
180        let tool_calls = Self::extract_tool_calls(&response);
181        let parsed = serde_json::from_str::<Value>(&raw_text).ok();
182
183        tracing::info!(
184            model = %response.model,
185            "oxide_chat.structured_call"
186        );
187
188        Ok((parsed, tool_calls, raw_text))
189    }
190
191    async fn tools_call(
192        &self,
193        messages: &[Message],
194        tools: &[ToolDef],
195    ) -> Result<Vec<ToolCall>, SgrError> {
196        let mut req = self.build_request(messages);
197
198        let chat_tools: Vec<Tool> = tools
199            .iter()
200            .map(|t| {
201                Tool::function(
202                    &t.name,
203                    if t.description.is_empty() {
204                        "No description"
205                    } else {
206                        &t.description
207                    },
208                    t.parameters.clone(),
209                )
210            })
211            .collect();
212        req.tools = Some(chat_tools);
213        req.tool_choice = Some(openai_oxide::types::chat::ToolChoice::Mode(
214            "required".into(),
215        ));
216
217        let response = self
218            .client
219            .chat()
220            .completions()
221            .create(req)
222            .await
223            .map_err(|e| SgrError::Api {
224                status: 0,
225                body: e.to_string(),
226            })?;
227
228        tracing::info!(model = %response.model, "oxide_chat.tools_call");
229
230        let calls = Self::extract_tool_calls(&response);
231        // Don't synthesize finish — empty tool_calls signals completion to ToolCallingAgent.
232        Ok(calls)
233    }
234
235    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
236        let req = self.build_request(messages);
237
238        let response = self
239            .client
240            .chat()
241            .completions()
242            .create(req)
243            .await
244            .map_err(|e| SgrError::Api {
245                status: 0,
246                body: e.to_string(),
247            })?;
248
249        tracing::info!(model = %response.model, "oxide_chat.complete");
250
251        Ok(response
252            .choices
253            .first()
254            .and_then(|c| c.message.content.clone())
255            .unwrap_or_default())
256    }
257}