// synaptic_openai/chat_model.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};
10
/// Configuration for an OpenAI-compatible chat completions endpoint.
///
/// Optional fields are omitted from the request body when `None`.
#[derive(Debug, Clone)]
pub struct OpenAiConfig {
    // Sent as a `Bearer` token in the `Authorization` header.
    pub api_key: String,
    // Model identifier placed in the request body's `model` field.
    pub model: String,
    // API root; defaults to `https://api.openai.com/v1`.
    pub base_url: String,
    // Serialized as `max_tokens` when set.
    pub max_tokens: Option<u32>,
    // Serialized as `temperature` when set.
    pub temperature: Option<f64>,
    // Serialized as `top_p` when set.
    pub top_p: Option<f64>,
    // Serialized as `stop` when set.
    pub stop: Option<Vec<String>>,
    // Serialized as `seed` when set.
    pub seed: Option<u64>,
}
22
23impl OpenAiConfig {
24    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
25        Self {
26            api_key: api_key.into(),
27            model: model.into(),
28            base_url: "https://api.openai.com/v1".to_string(),
29            max_tokens: None,
30            temperature: None,
31            top_p: None,
32            stop: None,
33            seed: None,
34        }
35    }
36
37    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
38        self.base_url = url.into();
39        self
40    }
41
42    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
43        self.max_tokens = Some(max_tokens);
44        self
45    }
46
47    pub fn with_temperature(mut self, temperature: f64) -> Self {
48        self.temperature = Some(temperature);
49        self
50    }
51
52    pub fn with_top_p(mut self, top_p: f64) -> Self {
53        self.top_p = Some(top_p);
54        self
55    }
56
57    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
58        self.stop = Some(stop);
59        self
60    }
61
62    pub fn with_seed(mut self, seed: u64) -> Self {
63        self.seed = Some(seed);
64        self
65    }
66}
67
/// Chat model backed by the OpenAI chat completions API.
///
/// HTTP transport is delegated to an injected [`ProviderBackend`], keeping
/// this type independent of any concrete HTTP client.
pub struct OpenAiChatModel {
    // Endpoint, credentials, and sampling options applied to every request.
    config: OpenAiConfig,
    // Shared, object-safe transport used to send the assembled requests.
    backend: Arc<dyn ProviderBackend>,
}
72
73impl OpenAiChatModel {
74    pub fn new(config: OpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
75        Self { config, backend }
76    }
77
78    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
79        let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
80
81        let mut body = json!({
82            "model": self.config.model,
83            "messages": messages,
84            "stream": stream,
85        });
86
87        if let Some(max_tokens) = self.config.max_tokens {
88            body["max_tokens"] = json!(max_tokens);
89        }
90        if let Some(temp) = self.config.temperature {
91            body["temperature"] = json!(temp);
92        }
93        if let Some(top_p) = self.config.top_p {
94            body["top_p"] = json!(top_p);
95        }
96        if let Some(ref stop) = self.config.stop {
97            body["stop"] = json!(stop);
98        }
99        if let Some(seed) = self.config.seed {
100            body["seed"] = json!(seed);
101        }
102        if !request.tools.is_empty() {
103            body["tools"] = json!(request
104                .tools
105                .iter()
106                .map(tool_def_to_openai)
107                .collect::<Vec<_>>());
108        }
109        if let Some(ref choice) = request.tool_choice {
110            body["tool_choice"] = match choice {
111                ToolChoice::Auto => json!("auto"),
112                ToolChoice::Required => json!("required"),
113                ToolChoice::None => json!("none"),
114                ToolChoice::Specific(name) => json!({
115                    "type": "function",
116                    "function": {"name": name}
117                }),
118            };
119        }
120
121        ProviderRequest {
122            url: format!("{}/chat/completions", self.config.base_url),
123            headers: vec![
124                (
125                    "Authorization".to_string(),
126                    format!("Bearer {}", self.config.api_key),
127                ),
128                ("Content-Type".to_string(), "application/json".to_string()),
129            ],
130            body,
131        }
132    }
133}
134
135pub(crate) fn message_to_openai(msg: &Message) -> Value {
136    match msg {
137        Message::System { content, .. } => json!({
138            "role": "system",
139            "content": content,
140        }),
141        Message::Human { content, .. } => json!({
142            "role": "user",
143            "content": content,
144        }),
145        Message::AI {
146            content,
147            tool_calls,
148            ..
149        } => {
150            let mut obj = json!({
151                "role": "assistant",
152                "content": content,
153            });
154            if !tool_calls.is_empty() {
155                obj["tool_calls"] = json!(tool_calls
156                    .iter()
157                    .map(|tc| json!({
158                        "id": tc.id,
159                        "type": "function",
160                        "function": {
161                            "name": tc.name,
162                            "arguments": tc.arguments.to_string(),
163                        }
164                    }))
165                    .collect::<Vec<_>>());
166            }
167            obj
168        }
169        Message::Tool {
170            content,
171            tool_call_id,
172            ..
173        } => json!({
174            "role": "tool",
175            "content": content,
176            "tool_call_id": tool_call_id,
177        }),
178        Message::Chat {
179            custom_role,
180            content,
181            ..
182        } => json!({
183            "role": custom_role,
184            "content": content,
185        }),
186        Message::Remove { .. } => json!(null),
187    }
188}
189
190pub(crate) fn tool_def_to_openai(def: &ToolDefinition) -> Value {
191    json!({
192        "type": "function",
193        "function": {
194            "name": def.name,
195            "description": def.description,
196            "parameters": def.parameters,
197        }
198    })
199}
200
201pub(crate) fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
202    check_error_status(resp)?;
203
204    let choice = &resp.body["choices"][0]["message"];
205    let content = choice["content"].as_str().unwrap_or("").to_string();
206    let tool_calls = parse_tool_calls(choice);
207
208    let usage = parse_usage(&resp.body["usage"]);
209
210    let message = if tool_calls.is_empty() {
211        Message::ai(content)
212    } else {
213        Message::ai_with_tool_calls(content, tool_calls)
214    };
215
216    Ok(ChatResponse { message, usage })
217}
218
219pub(crate) fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
220    if resp.status == 429 {
221        let msg = resp.body["error"]["message"]
222            .as_str()
223            .unwrap_or("rate limited")
224            .to_string();
225        return Err(SynapticError::RateLimit(msg));
226    }
227    if resp.status >= 400 {
228        let msg = resp.body["error"]["message"]
229            .as_str()
230            .unwrap_or("unknown API error")
231            .to_string();
232        return Err(SynapticError::Model(format!(
233            "OpenAI API error ({}): {}",
234            resp.status, msg
235        )));
236    }
237    Ok(())
238}
239
240pub(crate) fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
241    message["tool_calls"]
242        .as_array()
243        .map(|arr| {
244            arr.iter()
245                .filter_map(|tc| {
246                    let id = tc["id"].as_str()?.to_string();
247                    let name = tc["function"]["name"].as_str()?.to_string();
248                    let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
249                    let arguments =
250                        serde_json::from_str(args_str).unwrap_or(Value::Object(Default::default()));
251                    Some(ToolCall {
252                        id,
253                        name,
254                        arguments,
255                    })
256                })
257                .collect()
258        })
259        .unwrap_or_default()
260}
261
262pub(crate) fn parse_usage(usage: &Value) -> Option<TokenUsage> {
263    if usage.is_null() {
264        return None;
265    }
266    Some(TokenUsage {
267        input_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
268        output_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
269        total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
270        input_details: None,
271        output_details: None,
272    })
273}
274
275pub(crate) fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
276    let v: Value = serde_json::from_str(data).ok()?;
277    let delta = &v["choices"][0]["delta"];
278
279    let content = delta["content"].as_str().unwrap_or("").to_string();
280    let tool_calls = parse_tool_calls(delta);
281    let usage = parse_usage(&v["usage"]);
282
283    Some(AIMessageChunk {
284        content,
285        tool_calls,
286        usage,
287        ..Default::default()
288    })
289}
290
#[async_trait]
impl ChatModel for OpenAiChatModel {
    /// Sends a single (non-streaming) chat completion request and parses
    /// the response into a [`ChatResponse`].
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        parse_response(&resp)
    }

    /// Streams a chat completion as server-sent events, yielding one
    /// [`AIMessageChunk`] per SSE `data:` payload until the `[DONE]` sentinel.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // A transport-level failure surfaces as a single error item and
            // ends the stream immediately.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // Map backend errors into io::Error so the byte stream satisfies
            // the error bound expected by `eventsource()` framing.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // OpenAI terminates the stream with a literal "[DONE]" payload.
                        if ev.data == "[DONE]" {
                            break;
                        }
                        // Payloads that fail to parse as chunk JSON are skipped silently.
                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        // Framing errors are fatal: report once, then stop.
                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}