// sgr_agent/oxide_chat_client.rs

use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
use openai_oxide::OpenAI;
use openai_oxide::config::ClientConfig;
use openai_oxide::types::chat::*;
use serde_json::Value;

14pub struct OxideChatClient {
16 client: OpenAI,
17 pub(crate) model: String,
18 pub(crate) temperature: Option<f64>,
19 pub(crate) max_tokens: Option<u32>,
20}
21
22impl OxideChatClient {
23 pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
25 let api_key = config
26 .api_key
27 .clone()
28 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
29 .unwrap_or_else(|| {
30 if config.base_url.is_some() {
31 "dummy_key".into()
32 } else {
33 "".into()
34 }
35 });
36
37 if api_key.is_empty() {
38 return Err(SgrError::Schema("No API key for oxide chat client".into()));
39 }
40
41 let mut client_config = ClientConfig::new(&api_key);
42 if let Some(ref url) = config.base_url {
43 client_config = client_config.base_url(url.clone());
44 }
45
46 Ok(Self {
47 client: OpenAI::with_config(client_config),
48 model: config.model.clone(),
49 temperature: Some(config.temp),
50 max_tokens: config.max_tokens,
51 })
52 }
53
54 fn build_messages(&self, messages: &[Message]) -> Vec<ChatCompletionMessageParam> {
55 messages
56 .iter()
57 .map(|m| match m.role {
58 Role::System => ChatCompletionMessageParam::System {
59 content: m.content.clone(),
60 name: None,
61 },
62 Role::User => ChatCompletionMessageParam::User {
63 content: UserContent::Text(m.content.clone()),
64 name: None,
65 },
66 Role::Assistant => ChatCompletionMessageParam::Assistant {
67 content: Some(m.content.clone()),
68 name: None,
69 tool_calls: None,
70 refusal: None,
71 },
72 Role::Tool => ChatCompletionMessageParam::Tool {
73 content: m.content.clone(),
74 tool_call_id: m.tool_call_id.clone().unwrap_or_default(),
75 },
76 })
77 .collect()
78 }
79
80 fn build_request(&self, messages: &[Message]) -> ChatCompletionRequest {
81 let mut req = ChatCompletionRequest::new(&self.model, self.build_messages(messages));
82 if let Some(temp) = self.temperature {
83 req.temperature = Some(temp);
84 }
85 if let Some(max) = self.max_tokens {
86 req.max_tokens = Some(max as i64);
87 }
88 req
89 }
90
91 fn extract_tool_calls(response: &ChatCompletionResponse) -> Vec<ToolCall> {
92 let Some(choice) = response.choices.first() else {
93 return Vec::new();
94 };
95 let Some(ref calls) = choice.message.tool_calls else {
96 return Vec::new();
97 };
98 calls
99 .iter()
100 .map(|tc| ToolCall {
101 id: tc.id.clone(),
102 name: tc.function.name.clone(),
103 arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
104 })
105 .collect()
106 }
107}
108
109#[async_trait::async_trait]
110impl LlmClient for OxideChatClient {
111 async fn structured_call(
112 &self,
113 messages: &[Message],
114 schema: &Value,
115 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
116 let mut strict_schema = schema.clone();
117 openai_oxide::parsing::ensure_strict(&mut strict_schema);
118
119 let mut req = self.build_request(messages);
120 req.response_format = Some(ResponseFormat::JsonSchema {
121 json_schema: JsonSchema {
122 name: "response".into(),
123 description: None,
124 schema: Some(strict_schema),
125 strict: Some(true),
126 },
127 });
128
129 let response = self
130 .client
131 .chat()
132 .completions()
133 .create(req)
134 .await
135 .map_err(|e| SgrError::Api {
136 status: 0,
137 body: e.to_string(),
138 })?;
139
140 let raw_text = response
141 .choices
142 .first()
143 .and_then(|c| c.message.content.clone())
144 .unwrap_or_default();
145 let tool_calls = Self::extract_tool_calls(&response);
146 let parsed = serde_json::from_str::<Value>(&raw_text).ok();
147
148 tracing::info!(
149 model = %response.model,
150 "oxide_chat.structured_call"
151 );
152
153 Ok((parsed, tool_calls, raw_text))
154 }
155
156 async fn tools_call(
157 &self,
158 messages: &[Message],
159 tools: &[ToolDef],
160 ) -> Result<Vec<ToolCall>, SgrError> {
161 let mut req = self.build_request(messages);
162
163 let chat_tools: Vec<Tool> = tools
164 .iter()
165 .map(|t| {
166 Tool::function(
167 &t.name,
168 if t.description.is_empty() {
169 "No description"
170 } else {
171 &t.description
172 },
173 t.parameters.clone(),
174 )
175 })
176 .collect();
177 req.tools = Some(chat_tools);
178 req.tool_choice = Some(openai_oxide::types::chat::ToolChoice::Mode(
179 "required".into(),
180 ));
181
182 let response = self
183 .client
184 .chat()
185 .completions()
186 .create(req)
187 .await
188 .map_err(|e| SgrError::Api {
189 status: 0,
190 body: e.to_string(),
191 })?;
192
193 tracing::info!(model = %response.model, "oxide_chat.tools_call");
194
195 let calls = Self::extract_tool_calls(&response);
196 Ok(calls)
198 }
199
200 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
201 let req = self.build_request(messages);
202
203 let response = self
204 .client
205 .chat()
206 .completions()
207 .create(req)
208 .await
209 .map_err(|e| SgrError::Api {
210 status: 0,
211 body: e.to_string(),
212 })?;
213
214 tracing::info!(model = %response.model, "oxide_chat.complete");
215
216 Ok(response
217 .choices
218 .first()
219 .and_then(|c| c.message.content.clone())
220 .unwrap_or_default())
221 }
222}