synaptic_embeddings/
openai.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::SynapseError;
6use synaptic_models::backend::{ProviderBackend, ProviderRequest};
7
8use crate::Embeddings;
9
/// Connection settings for the OpenAI embeddings endpoint.
///
/// [`OpenAiEmbeddingsConfig::new`] fills in the model `text-embedding-3-small`
/// and the base URL `https://api.openai.com/v1`; override them with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct OpenAiEmbeddingsConfig {
    /// Secret API key, sent as a `Bearer` token in the `Authorization` header.
    pub api_key: String,
    /// Embedding model identifier sent in the request body.
    pub model: String,
    /// API root; `/embeddings` is appended to it to form the request URL.
    pub base_url: String,
}

// `Debug` is written by hand (rather than derived) so the API key can never
// leak into logs or panic messages.
impl std::fmt::Debug for OpenAiEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiEmbeddingsConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl OpenAiEmbeddingsConfig {
    /// Creates a config for `api_key` using the default model and base URL.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    /// Returns the config with `model` replacing the current embedding model.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns the config with `base_url` replacing the current API root
    /// (e.g. to target a proxy or an OpenAI-compatible server).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}
35
36pub struct OpenAiEmbeddings {
37 config: OpenAiEmbeddingsConfig,
38 backend: Arc<dyn ProviderBackend>,
39}
40
41impl OpenAiEmbeddings {
42 pub fn new(config: OpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
43 Self { config, backend }
44 }
45
46 fn build_request(&self, input: Vec<String>) -> ProviderRequest {
47 ProviderRequest {
48 url: format!("{}/embeddings", self.config.base_url),
49 headers: vec![
50 (
51 "Authorization".to_string(),
52 format!("Bearer {}", self.config.api_key),
53 ),
54 ("Content-Type".to_string(), "application/json".to_string()),
55 ],
56 body: json!({
57 "model": self.config.model,
58 "input": input,
59 }),
60 }
61 }
62
63 fn parse_response(&self, body: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapseError> {
64 let data = body.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
65 SynapseError::Embedding("missing 'data' field in response".to_string())
66 })?;
67
68 let mut embeddings = Vec::with_capacity(data.len());
69 for item in data {
70 let embedding = item
71 .get("embedding")
72 .and_then(|e| e.as_array())
73 .ok_or_else(|| SynapseError::Embedding("missing 'embedding' field".to_string()))?
74 .iter()
75 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
76 .collect();
77 embeddings.push(embedding);
78 }
79
80 Ok(embeddings)
81 }
82}
83
84#[async_trait]
85impl Embeddings for OpenAiEmbeddings {
86 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapseError> {
87 let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
88 let request = self.build_request(input);
89 let response = self.backend.send(request).await?;
90
91 if response.status != 200 {
92 return Err(SynapseError::Embedding(format!(
93 "OpenAI API error ({}): {}",
94 response.status, response.body
95 )));
96 }
97
98 self.parse_response(&response.body)
99 }
100
101 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapseError> {
102 let mut results = self.embed_documents(&[text]).await?;
103 results
104 .pop()
105 .ok_or_else(|| SynapseError::Embedding("empty response".to_string()))
106 }
107}