//! OpenAI chat-completions provider (`synaptic_models/openai.rs`).

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError,
7    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Configuration for [`OpenAiChatModel`].
///
/// Only `api_key` and `model` are required; every other field is an
/// optional generation parameter that is omitted from the request body
/// when `None` (see `build_request`).
#[derive(Debug, Clone)]
pub struct OpenAiConfig {
    // Bearer token sent in the `Authorization` header.
    pub api_key: String,
    // Model identifier placed in the request's `model` field.
    pub model: String,
    // API root; `new` defaults this to `https://api.openai.com/v1`.
    pub base_url: String,
    // Maximum tokens to generate (`max_tokens` in the request body).
    pub max_tokens: Option<u32>,
    // Sampling temperature.
    pub temperature: Option<f64>,
    // Nucleus-sampling probability mass (`top_p`).
    pub top_p: Option<f64>,
    // Stop sequences.
    pub stop: Option<Vec<String>>,
    // Seed forwarded to the API for best-effort reproducibility.
    pub seed: Option<u64>,
}
23
24impl OpenAiConfig {
25    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
26        Self {
27            api_key: api_key.into(),
28            model: model.into(),
29            base_url: "https://api.openai.com/v1".to_string(),
30            max_tokens: None,
31            temperature: None,
32            top_p: None,
33            stop: None,
34            seed: None,
35        }
36    }
37
38    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
39        self.base_url = url.into();
40        self
41    }
42
43    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
44        self.max_tokens = Some(max_tokens);
45        self
46    }
47
48    pub fn with_temperature(mut self, temperature: f64) -> Self {
49        self.temperature = Some(temperature);
50        self
51    }
52
53    pub fn with_top_p(mut self, top_p: f64) -> Self {
54        self.top_p = Some(top_p);
55        self
56    }
57
58    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
59        self.stop = Some(stop);
60        self
61    }
62
63    pub fn with_seed(mut self, seed: u64) -> Self {
64        self.seed = Some(seed);
65        self
66    }
67}
68
/// [`ChatModel`] implementation that talks to an OpenAI-compatible
/// chat-completions endpoint through a pluggable [`ProviderBackend`].
pub struct OpenAiChatModel {
    // Request parameters: credentials, model name, sampling options.
    config: OpenAiConfig,
    // Transport used to send requests; `Arc` so it can be shared.
    backend: Arc<dyn ProviderBackend>,
}
73
74impl OpenAiChatModel {
75    pub fn new(config: OpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
76        Self { config, backend }
77    }
78
79    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
80        let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
81
82        let mut body = json!({
83            "model": self.config.model,
84            "messages": messages,
85            "stream": stream,
86        });
87
88        if let Some(max_tokens) = self.config.max_tokens {
89            body["max_tokens"] = json!(max_tokens);
90        }
91        if let Some(temp) = self.config.temperature {
92            body["temperature"] = json!(temp);
93        }
94        if let Some(top_p) = self.config.top_p {
95            body["top_p"] = json!(top_p);
96        }
97        if let Some(ref stop) = self.config.stop {
98            body["stop"] = json!(stop);
99        }
100        if let Some(seed) = self.config.seed {
101            body["seed"] = json!(seed);
102        }
103        if !request.tools.is_empty() {
104            body["tools"] = json!(request
105                .tools
106                .iter()
107                .map(tool_def_to_openai)
108                .collect::<Vec<_>>());
109        }
110        if let Some(ref choice) = request.tool_choice {
111            body["tool_choice"] = match choice {
112                ToolChoice::Auto => json!("auto"),
113                ToolChoice::Required => json!("required"),
114                ToolChoice::None => json!("none"),
115                ToolChoice::Specific(name) => json!({
116                    "type": "function",
117                    "function": {"name": name}
118                }),
119            };
120        }
121
122        ProviderRequest {
123            url: format!("{}/chat/completions", self.config.base_url),
124            headers: vec![
125                (
126                    "Authorization".to_string(),
127                    format!("Bearer {}", self.config.api_key),
128                ),
129                ("Content-Type".to_string(), "application/json".to_string()),
130            ],
131            body,
132        }
133    }
134}
135
136fn message_to_openai(msg: &Message) -> Value {
137    match msg {
138        Message::System { content, .. } => json!({
139            "role": "system",
140            "content": content,
141        }),
142        Message::Human { content, .. } => json!({
143            "role": "user",
144            "content": content,
145        }),
146        Message::AI {
147            content,
148            tool_calls,
149            ..
150        } => {
151            let mut obj = json!({
152                "role": "assistant",
153                "content": content,
154            });
155            if !tool_calls.is_empty() {
156                obj["tool_calls"] = json!(tool_calls
157                    .iter()
158                    .map(|tc| json!({
159                        "id": tc.id,
160                        "type": "function",
161                        "function": {
162                            "name": tc.name,
163                            "arguments": tc.arguments.to_string(),
164                        }
165                    }))
166                    .collect::<Vec<_>>());
167            }
168            obj
169        }
170        Message::Tool {
171            content,
172            tool_call_id,
173            ..
174        } => json!({
175            "role": "tool",
176            "content": content,
177            "tool_call_id": tool_call_id,
178        }),
179        Message::Chat {
180            custom_role,
181            content,
182            ..
183        } => json!({
184            "role": custom_role,
185            "content": content,
186        }),
187        Message::Remove { .. } => json!(null), // Remove messages are skipped
188    }
189}
190
191fn tool_def_to_openai(def: &ToolDefinition) -> Value {
192    json!({
193        "type": "function",
194        "function": {
195            "name": def.name,
196            "description": def.description,
197            "parameters": def.parameters,
198        }
199    })
200}
201
202fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapseError> {
203    check_error_status(resp)?;
204
205    let choice = &resp.body["choices"][0]["message"];
206    let content = choice["content"].as_str().unwrap_or("").to_string();
207    let tool_calls = parse_tool_calls(choice);
208
209    let usage = parse_usage(&resp.body["usage"]);
210
211    let message = if tool_calls.is_empty() {
212        Message::ai(content)
213    } else {
214        Message::ai_with_tool_calls(content, tool_calls)
215    };
216
217    Ok(ChatResponse { message, usage })
218}
219
220fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapseError> {
221    if resp.status == 429 {
222        let msg = resp.body["error"]["message"]
223            .as_str()
224            .unwrap_or("rate limited")
225            .to_string();
226        return Err(SynapseError::RateLimit(msg));
227    }
228    if resp.status >= 400 {
229        let msg = resp.body["error"]["message"]
230            .as_str()
231            .unwrap_or("unknown API error")
232            .to_string();
233        return Err(SynapseError::Model(format!(
234            "OpenAI API error ({}): {}",
235            resp.status, msg
236        )));
237    }
238    Ok(())
239}
240
241fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
242    message["tool_calls"]
243        .as_array()
244        .map(|arr| {
245            arr.iter()
246                .filter_map(|tc| {
247                    let id = tc["id"].as_str()?.to_string();
248                    let name = tc["function"]["name"].as_str()?.to_string();
249                    let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
250                    let arguments =
251                        serde_json::from_str(args_str).unwrap_or(Value::Object(Default::default()));
252                    Some(ToolCall {
253                        id,
254                        name,
255                        arguments,
256                    })
257                })
258                .collect()
259        })
260        .unwrap_or_default()
261}
262
263fn parse_usage(usage: &Value) -> Option<TokenUsage> {
264    if usage.is_null() {
265        return None;
266    }
267    Some(TokenUsage {
268        input_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
269        output_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
270        total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
271        input_details: None,
272        output_details: None,
273    })
274}
275
276fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
277    let v: Value = serde_json::from_str(data).ok()?;
278    let delta = &v["choices"][0]["delta"];
279
280    let content = delta["content"].as_str().unwrap_or("").to_string();
281    let tool_calls = parse_tool_calls(delta);
282    let usage = parse_usage(&v["usage"]);
283
284    Some(AIMessageChunk {
285        content,
286        tool_calls,
287        usage,
288        ..Default::default()
289    })
290}
291
#[async_trait]
impl ChatModel for OpenAiChatModel {
    /// Sends a non-streaming chat-completions request and parses the
    /// first choice into a [`ChatResponse`].
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        parse_response(&resp)
    }

    /// Streams chat-completion deltas as [`AIMessageChunk`]s.
    ///
    /// The backend's byte stream is decoded as server-sent events; each
    /// `data:` payload goes through `parse_stream_chunk`, and OpenAI's
    /// `[DONE]` sentinel ends the stream. A backend or SSE decode error
    /// is yielded once, then the stream terminates.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // Surface connection-level failures as a single error item.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // The SSE decoder wants `std::io::Error`; wrap backend errors.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // OpenAI signals end-of-stream with a literal "[DONE]".
                        if ev.data == "[DONE]" {
                            break;
                        }
                        // Payloads that fail JSON parsing are silently skipped.
                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        yield Err(SynapseError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}