// synaptic_openai/azure.rs

1//! Azure OpenAI integration.
2//!
3//! Azure OpenAI uses a different URL pattern and authentication scheme
4//! compared to the standard OpenAI API:
5//!
6//! - URL: `https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}`
7//! - Auth: `api-key: {key}` header (not `Authorization: Bearer`)
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use synaptic_core::{
14    ChatModel, ChatRequest, ChatResponse, ChatStream, Embeddings, SynapticError, ToolChoice,
15};
16use synaptic_models::{ProviderBackend, ProviderRequest};
17
18use crate::chat_model::{
19    message_to_openai, parse_response, parse_stream_chunk, tool_def_to_openai,
20};
21use crate::embeddings::parse_embeddings_response;
22
23// ---------------------------------------------------------------------------
24// Config
25// ---------------------------------------------------------------------------
26
/// Configuration for Azure OpenAI chat completions.
#[derive(Debug, Clone)]
pub struct AzureOpenAiConfig {
    /// Azure API key; sent as the `api-key` header (not `Authorization: Bearer`).
    pub api_key: String,
    /// Azure resource name — the `{resource}` in `https://{resource}.openai.azure.com`.
    pub resource_name: String,
    /// Deployment name; on Azure this selects the deployed model (e.g. `"gpt-4o"`).
    pub deployment_name: String,
    /// Value of the `api-version` query parameter (defaults to `"2024-10-21"`).
    pub api_version: String,
    /// Optional cap sent as `max_tokens` in the request body.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; omitted from the body when `None`.
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling parameter; omitted from the body when `None`.
    pub top_p: Option<f64>,
    /// Optional stop sequences; omitted from the body when `None`.
    pub stop: Option<Vec<String>>,
}
39
40impl AzureOpenAiConfig {
41    /// Create a new Azure OpenAI config.
42    ///
43    /// The `deployment_name` typically corresponds to the model you deployed
44    /// (e.g. `"gpt-4"`, `"gpt-4o"`).
45    pub fn new(
46        api_key: impl Into<String>,
47        resource_name: impl Into<String>,
48        deployment_name: impl Into<String>,
49    ) -> Self {
50        Self {
51            api_key: api_key.into(),
52            resource_name: resource_name.into(),
53            deployment_name: deployment_name.into(),
54            api_version: "2024-10-21".to_string(),
55            max_tokens: None,
56            temperature: None,
57            top_p: None,
58            stop: None,
59        }
60    }
61
62    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
63        self.api_version = version.into();
64        self
65    }
66
67    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
68        self.max_tokens = Some(max_tokens);
69        self
70    }
71
72    pub fn with_temperature(mut self, temperature: f64) -> Self {
73        self.temperature = Some(temperature);
74        self
75    }
76
77    pub fn with_top_p(mut self, top_p: f64) -> Self {
78        self.top_p = Some(top_p);
79        self
80    }
81
82    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
83        self.stop = Some(stop);
84        self
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Chat model
90// ---------------------------------------------------------------------------
91
/// Azure OpenAI chat model.
///
/// Sends chat-completion requests through a pluggable [`ProviderBackend`],
/// targeting the Azure deployment described by [`AzureOpenAiConfig`].
pub struct AzureOpenAiChatModel {
    // Endpoint, credentials, and default sampling parameters.
    config: AzureOpenAiConfig,
    // Transport abstraction that actually performs the HTTP calls.
    backend: Arc<dyn ProviderBackend>,
}
97
98impl AzureOpenAiChatModel {
99    pub fn new(config: AzureOpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
100        Self { config, backend }
101    }
102
103    /// Build a `ProviderRequest` targeting the Azure chat completions endpoint.
104    pub fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
105        let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
106
107        let mut body = json!({
108            "messages": messages,
109            "stream": stream,
110        });
111
112        if let Some(max_tokens) = self.config.max_tokens {
113            body["max_tokens"] = json!(max_tokens);
114        }
115        if let Some(temp) = self.config.temperature {
116            body["temperature"] = json!(temp);
117        }
118        if let Some(top_p) = self.config.top_p {
119            body["top_p"] = json!(top_p);
120        }
121        if let Some(ref stop) = self.config.stop {
122            body["stop"] = json!(stop);
123        }
124        if !request.tools.is_empty() {
125            body["tools"] = json!(request
126                .tools
127                .iter()
128                .map(tool_def_to_openai)
129                .collect::<Vec<_>>());
130        }
131        if let Some(ref choice) = request.tool_choice {
132            body["tool_choice"] = match choice {
133                ToolChoice::Auto => json!("auto"),
134                ToolChoice::Required => json!("required"),
135                ToolChoice::None => json!("none"),
136                ToolChoice::Specific(name) => json!({
137                    "type": "function",
138                    "function": {"name": name}
139                }),
140            };
141        }
142
143        let url = format!(
144            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
145            self.config.resource_name, self.config.deployment_name, self.config.api_version,
146        );
147
148        ProviderRequest {
149            url,
150            headers: vec![
151                ("api-key".to_string(), self.config.api_key.clone()),
152                ("Content-Type".to_string(), "application/json".to_string()),
153            ],
154            body,
155        }
156    }
157}
158
#[async_trait]
impl ChatModel for AzureOpenAiChatModel {
    /// Send a non-streaming chat request and parse the full response.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        parse_response(&resp)
    }

    /// Stream chat completions as server-sent events (SSE).
    ///
    /// Yields one item per parsed chunk; the stream ends when the server
    /// emits the OpenAI-style `[DONE]` sentinel or on the first SSE error.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // Surface transport-level failure as a single yielded Err, then stop.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // Adapt the backend's error type to io::Error so the SSE decoder
            // can consume the byte stream.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // "[DONE]" is the end-of-stream sentinel, not JSON — stop here.
                        if ev.data == "[DONE]" {
                            break;
                        }
                        // Events that don't parse into a chunk are skipped
                        // silently (presumably keep-alives or empty deltas —
                        // see parse_stream_chunk).
                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        // A malformed SSE frame terminates the stream after
                        // reporting the error once.
                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}
206
207// ---------------------------------------------------------------------------
208// Embeddings config
209// ---------------------------------------------------------------------------
210
/// Configuration for Azure OpenAI embeddings.
#[derive(Debug, Clone)]
pub struct AzureOpenAiEmbeddingsConfig {
    /// Azure API key; sent as the `api-key` header.
    pub api_key: String,
    /// Azure resource name — the `{resource}` in `https://{resource}.openai.azure.com`.
    pub resource_name: String,
    /// Deployment name of the embeddings model on Azure.
    pub deployment_name: String,
    /// Value of the `api-version` query parameter (defaults to `"2024-10-21"`).
    pub api_version: String,
    /// Model name sent in the request body (defaults to `"text-embedding-3-small"`).
    pub model: String,
}
220
221impl AzureOpenAiEmbeddingsConfig {
222    /// Create a new Azure OpenAI embeddings config.
223    pub fn new(
224        api_key: impl Into<String>,
225        resource_name: impl Into<String>,
226        deployment_name: impl Into<String>,
227    ) -> Self {
228        Self {
229            api_key: api_key.into(),
230            resource_name: resource_name.into(),
231            deployment_name: deployment_name.into(),
232            api_version: "2024-10-21".to_string(),
233            model: "text-embedding-3-small".to_string(),
234        }
235    }
236
237    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
238        self.api_version = version.into();
239        self
240    }
241
242    pub fn with_model(mut self, model: impl Into<String>) -> Self {
243        self.model = model.into();
244        self
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Embeddings
250// ---------------------------------------------------------------------------
251
/// Azure OpenAI embeddings.
///
/// Issues embeddings requests through a pluggable [`ProviderBackend`]
/// against the deployment described by [`AzureOpenAiEmbeddingsConfig`].
pub struct AzureOpenAiEmbeddings {
    // Endpoint, credentials, and model settings.
    config: AzureOpenAiEmbeddingsConfig,
    // Transport abstraction that actually performs the HTTP calls.
    backend: Arc<dyn ProviderBackend>,
}
257
258impl AzureOpenAiEmbeddings {
259    pub fn new(config: AzureOpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
260        Self { config, backend }
261    }
262
263    fn build_request(&self, input: Vec<String>) -> ProviderRequest {
264        let url = format!(
265            "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
266            self.config.resource_name, self.config.deployment_name, self.config.api_version,
267        );
268
269        ProviderRequest {
270            url,
271            headers: vec![
272                ("api-key".to_string(), self.config.api_key.clone()),
273                ("Content-Type".to_string(), "application/json".to_string()),
274            ],
275            body: json!({
276                "model": self.config.model,
277                "input": input,
278            }),
279        }
280    }
281}
282
283#[async_trait]
284impl Embeddings for AzureOpenAiEmbeddings {
285    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
286        let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
287        let request = self.build_request(input);
288        let response = self.backend.send(request).await?;
289
290        if response.status != 200 {
291            return Err(SynapticError::Embedding(format!(
292                "Azure OpenAI API error ({}): {}",
293                response.status, response.body
294            )));
295        }
296
297        parse_embeddings_response(&response.body)
298    }
299
300    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
301        let mut results = self.embed_documents(&[text]).await?;
302        results
303            .pop()
304            .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
305    }
306}