Skip to main content

synaptic_embeddings/
openai.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5use synaptic_core::SynapseError;
6use synaptic_models::backend::{ProviderBackend, ProviderRequest};
7
8use crate::Embeddings;
9
/// Connection settings for [`OpenAiEmbeddings`].
///
/// The `Debug` implementation redacts `api_key`, so a config can be logged
/// without exposing the secret.
#[derive(Clone)]
pub struct OpenAiEmbeddingsConfig {
    /// Secret sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// Embedding model name sent in the request body's `model` field.
    pub model: String,
    /// API root; embedding requests are sent to `{base_url}/embeddings`.
    pub base_url: String,
}

impl OpenAiEmbeddingsConfig {
    /// Creates a config for `api_key`, using the model `text-embedding-3-small`
    /// and the base URL `https://api.openai.com/v1`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    /// Returns the config with `model` replacing the current model name.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns the config with `base_url` replacing the current API root,
    /// e.g. to target a proxy or another OpenAI-compatible server.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}

impl std::fmt::Debug for OpenAiEmbeddingsConfig {
    // Hand-written rather than derived so the API key never appears in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiEmbeddingsConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
35
36pub struct OpenAiEmbeddings {
37    config: OpenAiEmbeddingsConfig,
38    backend: Arc<dyn ProviderBackend>,
39}
40
41impl OpenAiEmbeddings {
42    pub fn new(config: OpenAiEmbeddingsConfig, backend: Arc<dyn ProviderBackend>) -> Self {
43        Self { config, backend }
44    }
45
46    fn build_request(&self, input: Vec<String>) -> ProviderRequest {
47        ProviderRequest {
48            url: format!("{}/embeddings", self.config.base_url),
49            headers: vec![
50                (
51                    "Authorization".to_string(),
52                    format!("Bearer {}", self.config.api_key),
53                ),
54                ("Content-Type".to_string(), "application/json".to_string()),
55            ],
56            body: json!({
57                "model": self.config.model,
58                "input": input,
59            }),
60        }
61    }
62
63    fn parse_response(&self, body: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapseError> {
64        let data = body.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
65            SynapseError::Embedding("missing 'data' field in response".to_string())
66        })?;
67
68        let mut embeddings = Vec::with_capacity(data.len());
69        for item in data {
70            let embedding = item
71                .get("embedding")
72                .and_then(|e| e.as_array())
73                .ok_or_else(|| SynapseError::Embedding("missing 'embedding' field".to_string()))?
74                .iter()
75                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
76                .collect();
77            embeddings.push(embedding);
78        }
79
80        Ok(embeddings)
81    }
82}
83
84#[async_trait]
85impl Embeddings for OpenAiEmbeddings {
86    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapseError> {
87        let input: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
88        let request = self.build_request(input);
89        let response = self.backend.send(request).await?;
90
91        if response.status != 200 {
92            return Err(SynapseError::Embedding(format!(
93                "OpenAI API error ({}): {}",
94                response.status, response.body
95            )));
96        }
97
98        self.parse_response(&response.body)
99    }
100
101    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapseError> {
102        let mut results = self.embed_documents(&[text]).await?;
103        results
104            .pop()
105            .ok_or_else(|| SynapseError::Embedding("empty response".to_string()))
106    }
107}