// sgr_agent/oxide_chat_client.rs
use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
use openai_oxide::OpenAI;
use openai_oxide::config::ClientConfig;
use openai_oxide::types::chat::*;
use serde_json::Value;

/// Chat-completions LLM client backed by the `openai_oxide` SDK.
///
/// Holds the configured model name plus optional sampling/length limits that
/// are applied to every request built by this client.
pub struct OxideChatClient {
    // Underlying OpenAI-compatible HTTP client.
    client: OpenAI,
    // Model identifier sent with each request (e.g. "gpt-4o").
    pub(crate) model: String,
    // Sampling temperature; `None` leaves the provider default.
    pub(crate) temperature: Option<f64>,
    // Output token cap; `None` leaves the provider default.
    pub(crate) max_tokens: Option<u32>,
}
21
22impl OxideChatClient {
23 pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
25 let api_key = config
26 .api_key
27 .clone()
28 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
29 .unwrap_or_else(|| {
30 if config.base_url.is_some() {
31 "dummy_key".into()
32 } else {
33 "".into()
34 }
35 });
36
37 if api_key.is_empty() {
38 return Err(SgrError::Schema("No API key for oxide chat client".into()));
39 }
40
41 let mut client_config = ClientConfig::new(&api_key);
42 if let Some(ref url) = config.base_url {
43 client_config = client_config.base_url(url.clone());
44 }
45
46 Ok(Self {
47 client: OpenAI::with_config(client_config),
48 model: config.model.clone(),
49 temperature: Some(config.temp),
50 max_tokens: config.max_tokens,
51 })
52 }
53
54 fn build_messages(&self, messages: &[Message]) -> Vec<ChatCompletionMessageParam> {
55 messages
56 .iter()
57 .map(|m| match m.role {
58 Role::System => ChatCompletionMessageParam::System {
59 content: m.content.clone(),
60 name: None,
61 },
62 Role::User => ChatCompletionMessageParam::User {
63 content: UserContent::Text(m.content.clone()),
64 name: None,
65 },
66 Role::Assistant => {
67 let tc = if m.tool_calls.is_empty() {
68 None
69 } else {
70 Some(
71 m.tool_calls
72 .iter()
73 .map(|tc| openai_oxide::types::chat::ToolCall {
74 id: tc.id.clone(),
75 type_: "function".into(),
76 function: openai_oxide::types::chat::FunctionCall {
77 name: tc.name.clone(),
78 arguments: tc.arguments.to_string(),
79 },
80 })
81 .collect(),
82 )
83 };
84 ChatCompletionMessageParam::Assistant {
85 content: if m.content.is_empty() {
86 None
87 } else {
88 Some(m.content.clone())
89 },
90 name: None,
91 tool_calls: tc,
92 refusal: None,
93 }
94 }
95 Role::Tool => ChatCompletionMessageParam::Tool {
96 content: m.content.clone(),
97 tool_call_id: m.tool_call_id.clone().unwrap_or_default(),
98 },
99 })
100 .collect()
101 }
102
103 fn build_request(&self, messages: &[Message]) -> ChatCompletionRequest {
104 let mut req = ChatCompletionRequest::new(&self.model, self.build_messages(messages));
105 if let Some(temp) = self.temperature {
106 req.temperature = Some(temp);
107 }
108 if let Some(max) = self.max_tokens {
109 if self.model.starts_with("gpt-5") || self.model.starts_with("o") {
111 req = req.max_completion_tokens(max as i64);
112 } else {
113 req.max_tokens = Some(max as i64);
114 }
115 }
116 req
117 }
118
119 fn extract_tool_calls(response: &ChatCompletionResponse) -> Vec<ToolCall> {
120 let Some(choice) = response.choices.first() else {
121 return Vec::new();
122 };
123 let Some(ref calls) = choice.message.tool_calls else {
124 return Vec::new();
125 };
126 calls
127 .iter()
128 .map(|tc| ToolCall {
129 id: tc.id.clone(),
130 name: tc.function.name.clone(),
131 arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
132 })
133 .collect()
134 }
135}
136
137#[async_trait::async_trait]
138impl LlmClient for OxideChatClient {
139 async fn structured_call(
140 &self,
141 messages: &[Message],
142 schema: &Value,
143 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
144 let strict_schema =
146 if schema.get("additionalProperties").and_then(|v| v.as_bool()) == Some(false) {
147 schema.clone()
148 } else {
149 let mut s = schema.clone();
150 openai_oxide::parsing::ensure_strict(&mut s);
151 s
152 };
153
154 let mut req = self.build_request(messages);
155 req.response_format = Some(ResponseFormat::JsonSchema {
156 json_schema: JsonSchema {
157 name: "response".into(),
158 description: None,
159 schema: Some(strict_schema),
160 strict: Some(true),
161 },
162 });
163
164 let response = self
165 .client
166 .chat()
167 .completions()
168 .create(req)
169 .await
170 .map_err(|e| SgrError::Api {
171 status: 0,
172 body: e.to_string(),
173 })?;
174
175 let raw_text = response
176 .choices
177 .first()
178 .and_then(|c| c.message.content.clone())
179 .unwrap_or_default();
180 let tool_calls = Self::extract_tool_calls(&response);
181 let parsed = serde_json::from_str::<Value>(&raw_text).ok();
182
183 tracing::info!(
184 model = %response.model,
185 "oxide_chat.structured_call"
186 );
187
188 Ok((parsed, tool_calls, raw_text))
189 }
190
191 async fn tools_call(
192 &self,
193 messages: &[Message],
194 tools: &[ToolDef],
195 ) -> Result<Vec<ToolCall>, SgrError> {
196 let mut req = self.build_request(messages);
197
198 let chat_tools: Vec<Tool> = tools
199 .iter()
200 .map(|t| {
201 Tool::function(
202 &t.name,
203 if t.description.is_empty() {
204 "No description"
205 } else {
206 &t.description
207 },
208 t.parameters.clone(),
209 )
210 })
211 .collect();
212 req.tools = Some(chat_tools);
213 req.tool_choice = Some(openai_oxide::types::chat::ToolChoice::Mode(
214 "required".into(),
215 ));
216
217 let response = self
218 .client
219 .chat()
220 .completions()
221 .create(req)
222 .await
223 .map_err(|e| SgrError::Api {
224 status: 0,
225 body: e.to_string(),
226 })?;
227
228 tracing::info!(model = %response.model, "oxide_chat.tools_call");
229
230 let calls = Self::extract_tool_calls(&response);
231 Ok(calls)
233 }
234
235 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
236 let req = self.build_request(messages);
237
238 let response = self
239 .client
240 .chat()
241 .completions()
242 .create(req)
243 .await
244 .map_err(|e| SgrError::Api {
245 status: 0,
246 body: e.to_string(),
247 })?;
248
249 tracing::info!(model = %response.model, "oxide_chat.complete");
250
251 Ok(response
252 .choices
253 .first()
254 .and_then(|c| c.message.content.clone())
255 .unwrap_or_default())
256 }
257}