1use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use synaptic_core::{
14 ChatModel, ChatRequest, ChatResponse, ChatStream, Embeddings, SynapticError, ToolChoice,
15};
16use synaptic_models::{ProviderBackend, ProviderRequest};
17
18use crate::chat_model::{
19 message_to_openai, parse_response, parse_stream_chunk, tool_def_to_openai,
20};
21use crate::embeddings::parse_embeddings_response;
22
/// Configuration for an Azure OpenAI chat deployment.
///
/// Construct with [`AzureOpenAiConfig::new`] and refine through the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct AzureOpenAiConfig {
    /// Azure API key sent in the `api-key` request header.
    pub api_key: String,
    /// Azure OpenAI resource name (the `{name}.openai.azure.com` host part).
    pub resource_name: String,
    /// Name of the model deployment within the resource.
    pub deployment_name: String,
    /// Azure REST API version query parameter; defaults to "2024-10-21".
    pub api_version: String,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling probability mass.
    pub top_p: Option<f64>,
    /// Optional stop sequences.
    pub stop: Option<Vec<String>>,
}

impl AzureOpenAiConfig {
    /// Creates a config with the given credentials and defaults:
    /// API version "2024-10-21", no sampling overrides, no stop sequences.
    pub fn new(
        api_key: impl Into<String>,
        resource_name: impl Into<String>,
        deployment_name: impl Into<String>,
    ) -> Self {
        Self {
            api_key: api_key.into(),
            resource_name: resource_name.into(),
            deployment_name: deployment_name.into(),
            api_version: String::from("2024-10-21"),
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
        }
    }

    /// Overrides the Azure REST API version.
    pub fn with_api_version(self, version: impl Into<String>) -> Self {
        Self { api_version: version.into(), ..self }
    }

    /// Caps the number of tokens the model may generate.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens: Some(max_tokens), ..self }
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(self, temperature: f64) -> Self {
        Self { temperature: Some(temperature), ..self }
    }

    /// Sets the nucleus-sampling parameter.
    pub fn with_top_p(self, top_p: f64) -> Self {
        Self { top_p: Some(top_p), ..self }
    }

    /// Sets the stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self { stop: Some(stop), ..self }
    }
}
87
88pub struct AzureOpenAiChatModel {
94 config: AzureOpenAiConfig,
95 backend: Arc<dyn ProviderBackend>,
96}
97
98impl AzureOpenAiChatModel {
99 pub fn new(config: AzureOpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
100 Self { config, backend }
101 }
102
103 pub fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
105 let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
106
107 let mut body = json!({
108 "messages": messages,
109 "stream": stream,
110 });
111
112 if let Some(max_tokens) = self.config.max_tokens {
113 body["max_tokens"] = json!(max_tokens);
114 }
115 if let Some(temp) = self.config.temperature {
116 body["temperature"] = json!(temp);
117 }
118 if let Some(top_p) = self.config.top_p {
119 body["top_p"] = json!(top_p);
120 }
121 if let Some(ref stop) = self.config.stop {
122 body["stop"] = json!(stop);
123 }
124 if !request.tools.is_empty() {
125 body["tools"] = json!(request
126 .tools
127 .iter()
128 .map(tool_def_to_openai)
129 .collect::<Vec<_>>());
130 }
131 if let Some(ref choice) = request.tool_choice {
132 body["tool_choice"] = match choice {
133 ToolChoice::Auto => json!("auto"),
134 ToolChoice::Required => json!("required"),
135 ToolChoice::None => json!("none"),
136 ToolChoice::Specific(name) => json!({
137 "type": "function",
138 "function": {"name": name}
139 }),
140 };
141 }
142
143 let url = format!(
144 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
145 self.config.resource_name, self.config.deployment_name, self.config.api_version,
146 );
147
148 ProviderRequest {
149 url,
150 headers: vec![
151 ("api-key".to_string(), self.config.api_key.clone()),
152 ("Content-Type".to_string(), "application/json".to_string()),
153 ],
154 body,
155 }
156 }
157}
158
#[async_trait]
impl ChatModel for AzureOpenAiChatModel {
    /// Sends a single non-streaming chat-completion request and parses the
    /// provider response.
    ///
    /// # Errors
    /// Propagates transport errors from the backend and whatever errors
    /// `parse_response` produces.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        // NOTE(review): unlike the embeddings impl below, there is no explicit
        // HTTP status check here — presumably `parse_response` inspects the
        // whole response (it receives `&resp`, not just the body); confirm.
        parse_response(&resp)
    }

    /// Streams chat-completion chunks by consuming the provider's SSE stream.
    ///
    /// Backend errors and SSE parse errors are yielded as `Err` items; the
    /// stream terminates after the first error or the `[DONE]` sentinel.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // Connection-level failure: surface it as the single stream item.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // Adapt backend errors to io::Error so the SSE decoder can wrap them.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // OpenAI-style streams end with a literal "[DONE]" event.
                        if ev.data == "[DONE]" {
                            break;
                        }
                        // Non-chunk events (e.g. unparseable payloads) are
                        // silently skipped by parse_stream_chunk returning None.
                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        // Malformed SSE frame: report once, then stop streaming.
                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}
206
/// Configuration for an Azure OpenAI embeddings deployment.
#[derive(Debug, Clone)]
pub struct AzureOpenAiEmbeddingsConfig {
    /// Azure API key sent in the `api-key` request header.
    pub api_key: String,
    /// Azure OpenAI resource name (the `{name}.openai.azure.com` host part).
    pub resource_name: String,
    /// Name of the embeddings deployment within the resource.
    pub deployment_name: String,
    /// Azure REST API version query parameter; defaults to "2024-10-21".
    pub api_version: String,
    /// Model identifier placed in the request body; defaults to
    /// "text-embedding-3-small".
    pub model: String,
}

impl AzureOpenAiEmbeddingsConfig {
    /// Creates a config with the given credentials, the default API version
    /// "2024-10-21", and the default model "text-embedding-3-small".
    pub fn new(
        api_key: impl Into<String>,
        resource_name: impl Into<String>,
        deployment_name: impl Into<String>,
    ) -> Self {
        Self {
            api_key: api_key.into(),
            resource_name: resource_name.into(),
            deployment_name: deployment_name.into(),
            api_version: String::from("2024-10-21"),
            model: String::from("text-embedding-3-small"),
        }
    }

    /// Overrides the Azure REST API version.
    pub fn with_api_version(self, version: impl Into<String>) -> Self {
        Self { api_version: version.into(), ..self }
    }

    /// Overrides the embedding model name sent in the request body.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self { model: model.into(), ..self }
    }
}
247
248pub struct AzureOpenAiEmbeddings {
254 config: AzureOpenAiEmbeddingsConfig,
255 backend: Arc<dyn ProviderBackend>,
256}
257
258impl AzureOpenAiEmbeddings {
259 pub fn new(config: AzureOpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
260 Self { config, backend }
261 }
262
263 fn build_request(&self, input: Vec<String>) -> ProviderRequest {
264 let url = format!(
265 "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
266 self.config.resource_name, self.config.deployment_name, self.config.api_version,
267 );
268
269 ProviderRequest {
270 url,
271 headers: vec![
272 ("api-key".to_string(), self.config.api_key.clone()),
273 ("Content-Type".to_string(), "application/json".to_string()),
274 ],
275 body: json!({
276 "model": self.config.model,
277 "input": input,
278 }),
279 }
280 }
281}
282
283#[async_trait]
284impl Embeddings for AzureOpenAiEmbeddings {
285 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
286 let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
287 let request = self.build_request(input);
288 let response = self.backend.send(request).await?;
289
290 if response.status != 200 {
291 return Err(SynapticError::Embedding(format!(
292 "Azure OpenAI API error ({}): {}",
293 response.status, response.body
294 )));
295 }
296
297 parse_embeddings_response(&response.body)
298 }
299
300 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
301 let mut results = self.embed_documents(&[text]).await?;
302 results
303 .pop()
304 .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
305 }
306}