1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};
10
/// Connection and sampling settings for an OpenAI-style chat completions endpoint.
///
/// Build one with [`OpenAiConfig::new`] and refine it with the chainable `with_*`
/// methods. Every `Option` field starts as `None`.
///
/// `Debug` is implemented by hand so that `api_key` is never written to logs.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret sent as the `Authorization: Bearer …` header value.
    pub api_key: String,
    /// Model identifier placed in the request's `model` field.
    pub model: String,
    /// Endpoint root; defaults to `https://api.openai.com/v1`.
    pub base_url: String,
    /// Value for the request's `max_tokens` field, when set.
    pub max_tokens: Option<u32>,
    /// Value for the request's `temperature` field, when set.
    pub temperature: Option<f64>,
    /// Value for the request's `top_p` field, when set.
    pub top_p: Option<f64>,
    /// Value for the request's `stop` field, when set.
    pub stop: Option<Vec<String>>,
    /// Value for the request's `seed` field, when set.
    pub seed: Option<u64>,
}

impl std::fmt::Debug for OpenAiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiConfig")
            // The key is a credential: show that it exists, never what it is.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("stop", &self.stop)
            .field("seed", &self.seed)
            .finish()
    }
}

impl OpenAiConfig {
    /// Creates a configuration for `model` authenticated with `api_key`.
    ///
    /// The base URL is `https://api.openai.com/v1` and all optional settings are unset.
    #[must_use]
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: "https://api.openai.com/v1".to_string(),
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
            seed: None,
        }
    }

    /// Replaces the endpoint root.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets `max_tokens`.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `temperature`.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the `stop` sequences.
    #[must_use]
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets `seed`.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}
67
68pub struct OpenAiChatModel {
69 config: OpenAiConfig,
70 backend: Arc<dyn ProviderBackend>,
71}
72
73impl OpenAiChatModel {
74 pub fn new(config: OpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
75 Self { config, backend }
76 }
77
78 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
79 let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
80
81 let mut body = json!({
82 "model": self.config.model,
83 "messages": messages,
84 "stream": stream,
85 });
86
87 if let Some(max_tokens) = self.config.max_tokens {
88 body["max_tokens"] = json!(max_tokens);
89 }
90 if let Some(temp) = self.config.temperature {
91 body["temperature"] = json!(temp);
92 }
93 if let Some(top_p) = self.config.top_p {
94 body["top_p"] = json!(top_p);
95 }
96 if let Some(ref stop) = self.config.stop {
97 body["stop"] = json!(stop);
98 }
99 if let Some(seed) = self.config.seed {
100 body["seed"] = json!(seed);
101 }
102 if !request.tools.is_empty() {
103 body["tools"] = json!(request
104 .tools
105 .iter()
106 .map(tool_def_to_openai)
107 .collect::<Vec<_>>());
108 }
109 if let Some(ref choice) = request.tool_choice {
110 body["tool_choice"] = match choice {
111 ToolChoice::Auto => json!("auto"),
112 ToolChoice::Required => json!("required"),
113 ToolChoice::None => json!("none"),
114 ToolChoice::Specific(name) => json!({
115 "type": "function",
116 "function": {"name": name}
117 }),
118 };
119 }
120
121 ProviderRequest {
122 url: format!("{}/chat/completions", self.config.base_url),
123 headers: vec![
124 (
125 "Authorization".to_string(),
126 format!("Bearer {}", self.config.api_key),
127 ),
128 ("Content-Type".to_string(), "application/json".to_string()),
129 ],
130 body,
131 }
132 }
133}
134
135pub(crate) fn message_to_openai(msg: &Message) -> Value {
136 match msg {
137 Message::System { content, .. } => json!({
138 "role": "system",
139 "content": content,
140 }),
141 Message::Human { content, .. } => json!({
142 "role": "user",
143 "content": content,
144 }),
145 Message::AI {
146 content,
147 tool_calls,
148 ..
149 } => {
150 let mut obj = json!({
151 "role": "assistant",
152 "content": content,
153 });
154 if !tool_calls.is_empty() {
155 obj["tool_calls"] = json!(tool_calls
156 .iter()
157 .map(|tc| json!({
158 "id": tc.id,
159 "type": "function",
160 "function": {
161 "name": tc.name,
162 "arguments": tc.arguments.to_string(),
163 }
164 }))
165 .collect::<Vec<_>>());
166 }
167 obj
168 }
169 Message::Tool {
170 content,
171 tool_call_id,
172 ..
173 } => json!({
174 "role": "tool",
175 "content": content,
176 "tool_call_id": tool_call_id,
177 }),
178 Message::Chat {
179 custom_role,
180 content,
181 ..
182 } => json!({
183 "role": custom_role,
184 "content": content,
185 }),
186 Message::Remove { .. } => json!(null),
187 }
188}
189
190pub(crate) fn tool_def_to_openai(def: &ToolDefinition) -> Value {
191 json!({
192 "type": "function",
193 "function": {
194 "name": def.name,
195 "description": def.description,
196 "parameters": def.parameters,
197 }
198 })
199}
200
201pub(crate) fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
202 check_error_status(resp)?;
203
204 let choice = &resp.body["choices"][0]["message"];
205 let content = choice["content"].as_str().unwrap_or("").to_string();
206 let tool_calls = parse_tool_calls(choice);
207
208 let usage = parse_usage(&resp.body["usage"]);
209
210 let message = if tool_calls.is_empty() {
211 Message::ai(content)
212 } else {
213 Message::ai_with_tool_calls(content, tool_calls)
214 };
215
216 Ok(ChatResponse { message, usage })
217}
218
219pub(crate) fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
220 if resp.status == 429 {
221 let msg = resp.body["error"]["message"]
222 .as_str()
223 .unwrap_or("rate limited")
224 .to_string();
225 return Err(SynapticError::RateLimit(msg));
226 }
227 if resp.status >= 400 {
228 let msg = resp.body["error"]["message"]
229 .as_str()
230 .unwrap_or("unknown API error")
231 .to_string();
232 return Err(SynapticError::Model(format!(
233 "OpenAI API error ({}): {}",
234 resp.status, msg
235 )));
236 }
237 Ok(())
238}
239
240pub(crate) fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
241 message["tool_calls"]
242 .as_array()
243 .map(|arr| {
244 arr.iter()
245 .filter_map(|tc| {
246 let id = tc["id"].as_str()?.to_string();
247 let name = tc["function"]["name"].as_str()?.to_string();
248 let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
249 let arguments =
250 serde_json::from_str(args_str).unwrap_or(Value::Object(Default::default()));
251 Some(ToolCall {
252 id,
253 name,
254 arguments,
255 })
256 })
257 .collect()
258 })
259 .unwrap_or_default()
260}
261
262pub(crate) fn parse_usage(usage: &Value) -> Option<TokenUsage> {
263 if usage.is_null() {
264 return None;
265 }
266 Some(TokenUsage {
267 input_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
268 output_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
269 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
270 input_details: None,
271 output_details: None,
272 })
273}
274
275pub(crate) fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
276 let v: Value = serde_json::from_str(data).ok()?;
277 let delta = &v["choices"][0]["delta"];
278
279 let content = delta["content"].as_str().unwrap_or("").to_string();
280 let tool_calls = parse_tool_calls(delta);
281 let usage = parse_usage(&v["usage"]);
282
283 Some(AIMessageChunk {
284 content,
285 tool_calls,
286 usage,
287 ..Default::default()
288 })
289}
290
291#[async_trait]
292impl ChatModel for OpenAiChatModel {
293 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
294 let provider_req = self.build_request(&request, false);
295 let resp = self.backend.send(provider_req).await?;
296 parse_response(&resp)
297 }
298
299 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
300 Box::pin(async_stream::stream! {
301 let provider_req = self.build_request(&request, true);
302 let byte_stream = self.backend.send_stream(provider_req).await;
303
304 let byte_stream = match byte_stream {
305 Ok(s) => s,
306 Err(e) => {
307 yield Err(e);
308 return;
309 }
310 };
311
312 use eventsource_stream::Eventsource;
313 use futures::StreamExt;
314
315 let mut event_stream = byte_stream
316 .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
317 .eventsource();
318
319 while let Some(event) = event_stream.next().await {
320 match event {
321 Ok(ev) => {
322 if ev.data == "[DONE]" {
323 break;
324 }
325 if let Some(chunk) = parse_stream_chunk(&ev.data) {
326 yield Ok(chunk);
327 }
328 }
329 Err(e) => {
330 yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
331 break;
332 }
333 }
334 }
335 })
336 }
337}