synaptic_openai/
embeddings.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::{Embeddings, SynapticError};
6use synaptic_models::{ProviderBackend, ProviderRequest};
7
/// Configuration for [`OpenAiEmbeddings`].
///
/// [`OpenAiEmbeddingsConfig::new`] defaults the model to `text-embedding-3-small`
/// and the base URL to `https://api.openai.com/v1`.
///
/// The `Debug` implementation redacts `api_key` so the secret is not exposed when
/// the config is printed or logged.
#[derive(Clone)]
pub struct OpenAiEmbeddingsConfig {
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    pub api_key: String,
    /// Embedding model identifier placed in the request body.
    pub model: String,
    /// API root; `/embeddings` is appended to form the request URL.
    pub base_url: String,
}

impl std::fmt::Debug for OpenAiEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself.
        f.debug_struct("OpenAiEmbeddingsConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
13
14impl OpenAiEmbeddingsConfig {
15 pub fn new(api_key: impl Into<String>) -> Self {
16 Self {
17 api_key: api_key.into(),
18 model: "text-embedding-3-small".to_string(),
19 base_url: "https://api.openai.com/v1".to_string(),
20 }
21 }
22
23 pub fn with_model(mut self, model: impl Into<String>) -> Self {
24 self.model = model.into();
25 self
26 }
27
28 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
29 self.base_url = base_url.into();
30 self
31 }
32}
33
34pub struct OpenAiEmbeddings {
35 config: OpenAiEmbeddingsConfig,
36 backend: Arc<dyn ProviderBackend>,
37}
38
39impl OpenAiEmbeddings {
40 pub fn new(config: OpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
41 Self { config, backend }
42 }
43
44 fn build_request(&self, input: Vec<String>) -> ProviderRequest {
45 ProviderRequest {
46 url: format!("{}/embeddings", self.config.base_url),
47 headers: vec![
48 (
49 "Authorization".to_string(),
50 format!("Bearer {}", self.config.api_key),
51 ),
52 ("Content-Type".to_string(), "application/json".to_string()),
53 ],
54 body: json!({
55 "model": self.config.model,
56 "input": input,
57 }),
58 }
59 }
60
61 fn parse_response(&self, body: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
62 parse_embeddings_response(body)
63 }
64}
65
66pub(crate) fn parse_embeddings_response(
67 body: &serde_json::Value,
68) -> Result<Vec<Vec<f32>>, SynapticError> {
69 let data = body
70 .get("data")
71 .and_then(|d| d.as_array())
72 .ok_or_else(|| SynapticError::Embedding("missing 'data' field in response".to_string()))?;
73
74 let mut embeddings = Vec::with_capacity(data.len());
75 for item in data {
76 let embedding = item
77 .get("embedding")
78 .and_then(|e| e.as_array())
79 .ok_or_else(|| SynapticError::Embedding("missing 'embedding' field".to_string()))?
80 .iter()
81 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
82 .collect();
83 embeddings.push(embedding);
84 }
85
86 Ok(embeddings)
87}
88
89#[async_trait]
90impl Embeddings for OpenAiEmbeddings {
91 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
92 let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
93 let request = self.build_request(input);
94 let response = self.backend.send(request).await?;
95
96 if response.status != 200 {
97 return Err(SynapticError::Embedding(format!(
98 "OpenAI API error ({}): {}",
99 response.status, response.body
100 )));
101 }
102
103 self.parse_response(&response.body)
104 }
105
106 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
107 let mut results = self.embed_documents(&[text]).await?;
108 results
109 .pop()
110 .ok_or_else(|| SynapticError::Embedding("empty response".to_string()))
111 }
112}