Skip to main content

sgr_agent/
oxide_chat_client.rs

1//! OxideChatClient — LlmClient via Chat Completions API (not Responses).
2//!
3//! For OpenAI-compatible endpoints that don't support /responses:
4//! Cloudflare AI Gateway compat, OpenRouter, local models, Workers AI.
5
6use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
9use openai_oxide::OpenAI;
10use openai_oxide::config::ClientConfig;
11use openai_oxide::types::chat::*;
12use serde_json::Value;
13
14/// LlmClient backed by openai-oxide Chat Completions API.
15pub struct OxideChatClient {
16    client: OpenAI,
17    pub(crate) model: String,
18    pub(crate) temperature: Option<f64>,
19    pub(crate) max_tokens: Option<u32>,
20}
21
22impl OxideChatClient {
23    /// Create from LlmConfig.
24    pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
25        let api_key = config
26            .api_key
27            .clone()
28            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
29            .unwrap_or_else(|| {
30                if config.base_url.is_some() {
31                    "dummy_key".into()
32                } else {
33                    "".into()
34                }
35            });
36
37        if api_key.is_empty() {
38            return Err(SgrError::Schema("No API key for oxide chat client".into()));
39        }
40
41        let mut client_config = ClientConfig::new(&api_key);
42        if let Some(ref url) = config.base_url {
43            client_config = client_config.base_url(url.clone());
44        }
45
46        Ok(Self {
47            client: OpenAI::with_config(client_config),
48            model: config.model.clone(),
49            temperature: Some(config.temp),
50            max_tokens: config.max_tokens,
51        })
52    }
53
54    fn build_messages(&self, messages: &[Message]) -> Vec<ChatCompletionMessageParam> {
55        messages
56            .iter()
57            .map(|m| match m.role {
58                Role::System => ChatCompletionMessageParam::System {
59                    content: m.content.clone(),
60                    name: None,
61                },
62                Role::User => ChatCompletionMessageParam::User {
63                    content: UserContent::Text(m.content.clone()),
64                    name: None,
65                },
66                Role::Assistant => ChatCompletionMessageParam::Assistant {
67                    content: Some(m.content.clone()),
68                    name: None,
69                    tool_calls: None,
70                    refusal: None,
71                },
72                Role::Tool => ChatCompletionMessageParam::Tool {
73                    content: m.content.clone(),
74                    tool_call_id: m.tool_call_id.clone().unwrap_or_default(),
75                },
76            })
77            .collect()
78    }
79
80    fn build_request(&self, messages: &[Message]) -> ChatCompletionRequest {
81        let mut req = ChatCompletionRequest::new(&self.model, self.build_messages(messages));
82        if let Some(temp) = self.temperature {
83            req.temperature = Some(temp);
84        }
85        if let Some(max) = self.max_tokens {
86            req.max_tokens = Some(max as i64);
87        }
88        req
89    }
90
91    fn extract_tool_calls(response: &ChatCompletionResponse) -> Vec<ToolCall> {
92        let Some(choice) = response.choices.first() else {
93            return Vec::new();
94        };
95        let Some(ref calls) = choice.message.tool_calls else {
96            return Vec::new();
97        };
98        calls
99            .iter()
100            .map(|tc| ToolCall {
101                id: tc.id.clone(),
102                name: tc.function.name.clone(),
103                arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
104            })
105            .collect()
106    }
107}
108
109#[async_trait::async_trait]
110impl LlmClient for OxideChatClient {
111    async fn structured_call(
112        &self,
113        messages: &[Message],
114        schema: &Value,
115    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
116        let mut strict_schema = schema.clone();
117        openai_oxide::parsing::ensure_strict(&mut strict_schema);
118
119        let mut req = self.build_request(messages);
120        req.response_format = Some(ResponseFormat::JsonSchema {
121            json_schema: JsonSchema {
122                name: "response".into(),
123                description: None,
124                schema: Some(strict_schema),
125                strict: Some(true),
126            },
127        });
128
129        let response = self
130            .client
131            .chat()
132            .completions()
133            .create(req)
134            .await
135            .map_err(|e| SgrError::Api {
136                status: 0,
137                body: e.to_string(),
138            })?;
139
140        let raw_text = response
141            .choices
142            .first()
143            .and_then(|c| c.message.content.clone())
144            .unwrap_or_default();
145        let tool_calls = Self::extract_tool_calls(&response);
146        let parsed = serde_json::from_str::<Value>(&raw_text).ok();
147
148        tracing::info!(
149            model = %response.model,
150            "oxide_chat.structured_call"
151        );
152
153        Ok((parsed, tool_calls, raw_text))
154    }
155
156    async fn tools_call(
157        &self,
158        messages: &[Message],
159        tools: &[ToolDef],
160    ) -> Result<Vec<ToolCall>, SgrError> {
161        let mut req = self.build_request(messages);
162
163        let chat_tools: Vec<Tool> = tools
164            .iter()
165            .map(|t| {
166                Tool::function(
167                    &t.name,
168                    if t.description.is_empty() {
169                        "No description"
170                    } else {
171                        &t.description
172                    },
173                    t.parameters.clone(),
174                )
175            })
176            .collect();
177        req.tools = Some(chat_tools);
178        req.tool_choice = Some(openai_oxide::types::chat::ToolChoice::Mode(
179            "required".into(),
180        ));
181
182        let response = self
183            .client
184            .chat()
185            .completions()
186            .create(req)
187            .await
188            .map_err(|e| SgrError::Api {
189                status: 0,
190                body: e.to_string(),
191            })?;
192
193        tracing::info!(model = %response.model, "oxide_chat.tools_call");
194
195        let calls = Self::extract_tool_calls(&response);
196        // Don't synthesize finish — empty tool_calls signals completion to ToolCallingAgent.
197        Ok(calls)
198    }
199
200    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
201        let req = self.build_request(messages);
202
203        let response = self
204            .client
205            .chat()
206            .completions()
207            .create(req)
208            .await
209            .map_err(|e| SgrError::Api {
210                status: 0,
211                body: e.to_string(),
212            })?;
213
214        tracing::info!(model = %response.model, "oxide_chat.complete");
215
216        Ok(response
217            .choices
218            .first()
219            .and_then(|c| c.message.content.clone())
220            .unwrap_or_default())
221    }
222}