serdes_ai_embeddings/openai.rs

//! OpenAI embedding model implementation.
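//!
//! # Example
//!
//! A minimal construction sketch. The API key is a placeholder and the `use`
//! paths assume the `openai` and `model` modules are publicly exported:
//!
//! ```ignore
//! use serdes_ai_embeddings::model::EmbeddingModel;
//! use serdes_ai_embeddings::openai::OpenAIEmbeddingModel;
//!
//! let model = OpenAIEmbeddingModel::new("text-embedding-3-small", "sk-placeholder");
//! assert_eq!(model.name(), "text-embedding-3-small");
//! assert_eq!(model.dimensions(), 1536);
//! ```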

use crate::embedding::Embedding;
use crate::error::{EmbeddingError, EmbeddingResult};
use crate::model::{EmbedInput, EmbeddingModel, EmbeddingOutput, EmbeddingSettings};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// OpenAI embedding model.
#[derive(Clone)]
pub struct OpenAIEmbeddingModel {
    model_name: String,
    client: Client,
    api_key: String,
    base_url: String,
    default_dimensions: usize,
}

impl OpenAIEmbeddingModel {
    /// Create a new model from a model name and an API key.
    pub fn new(model_name: impl Into<String>, api_key: impl Into<String>) -> Self {
        let name = model_name.into();
        let dimensions = Self::model_dimensions(&name);

        Self {
            model_name: name,
            client: Client::new(),
            api_key: api_key.into(),
            base_url: "https://api.openai.com/v1".to_string(),
            default_dimensions: dimensions,
        }
    }

    /// Create a model using the API key from the `OPENAI_API_KEY` environment variable.
    pub fn from_env(model_name: impl Into<String>) -> EmbeddingResult<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| EmbeddingError::config("OPENAI_API_KEY not set"))?;
        Ok(Self::new(model_name, api_key))
    }

    /// Override the default API base URL.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Use a custom `reqwest` [`Client`], e.g. to configure timeouts or proxies.
    pub fn with_client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

    // Default embedding dimensions for known models; unknown names fall back to 1536.
    fn model_dimensions(name: &str) -> usize {
        match name {
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" => 1536,
            _ => 1536,
        }
    }

    // Maximum input tokens; the known models (and the fallback) all use 8191.
    fn model_max_tokens(name: &str) -> usize {
        match name {
            "text-embedding-3-small" => 8191,
            "text-embedding-3-large" => 8191,
            "text-embedding-ada-002" => 8191,
            _ => 8191,
        }
    }
}

// Manual `Debug` impl so the API key is never included in debug output.
impl std::fmt::Debug for OpenAIEmbeddingModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbeddingModel")
            .field("model_name", &self.model_name)
            .field("base_url", &self.base_url)
            .finish()
    }
}

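// Wire types for `POST {base_url}/embeddings`. The JSON bodies they map to
// look roughly like the following (an illustrative sketch, not a captured
// payload):
//
//   request:  {"model": "...", "input": ["..."], "dimensions": 256}
//   response: {"data": [{"embedding": [0.1, ...], "index": 0}],
//              "model": "...", "usage": {"prompt_tokens": 1, "total_tokens": 1}}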
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    model: String,
    usage: OpenAIUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIUsage {
    prompt_tokens: u64,
    total_tokens: u64,
}

#[derive(Debug, Deserialize)]
struct OpenAIErrorResponse {
    error: OpenAIError,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIError {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
    code: Option<String>,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    fn name(&self) -> &str {
        &self.model_name
    }

    fn dimensions(&self) -> usize {
        self.default_dimensions
    }

    fn max_tokens(&self) -> usize {
        Self::model_max_tokens(&self.model_name)
    }

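    // `embed` builds the JSON request, POSTs it to `{base_url}/embeddings`,
    // maps HTTP and API-level failures onto `EmbeddingError`, and then
    // re-assembles the returned vectors in input order.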
    async fn embed(
        &self,
        input: EmbedInput,
        settings: &EmbeddingSettings,
    ) -> EmbeddingResult<EmbeddingOutput> {
        let texts = input.into_texts();

        let request = OpenAIEmbeddingRequest {
            model: self.model_name.clone(),
            input: texts.clone(),
            dimensions: settings.dimensions,
            user: settings.user.clone(),
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::Api(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();

            // Report rate limiting even when the error body is not valid JSON.
            if status.as_u16() == 429 {
                return Err(EmbeddingError::RateLimited { retry_after: None });
            }

            // Prefer the structured API error message when the body parses.
            if let Ok(error_resp) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
                return Err(EmbeddingError::Api(error_resp.error.message));
            }

            return Err(EmbeddingError::Http {
                status: status.as_u16(),
                body,
            });
        }

        let resp: OpenAIEmbeddingResponse = response
            .json()
            .await
            .map_err(|e| EmbeddingError::Api(e.to_string()))?;

        // Re-order by index so each embedding lines up with its input text.
        let mut data = resp.data;
        data.sort_by_key(|d| d.index);

        let embeddings: Vec<Embedding> = data
            .into_iter()
            .map(|d| {
                let text = texts.get(d.index).cloned();
                let mut emb = Embedding::new(d.embedding)
                    .with_model(&resp.model)
                    .with_index(d.index);
                if let Some(t) = text {
                    emb = emb.with_text(t);
                }
                emb
            })
            .collect();

        Ok(EmbeddingOutput::new(embeddings, &resp.model).with_tokens(resp.usage.total_tokens))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        assert_eq!(
            OpenAIEmbeddingModel::model_dimensions("text-embedding-3-small"),
            1536
        );
        assert_eq!(
            OpenAIEmbeddingModel::model_dimensions("text-embedding-3-large"),
            3072
        );
    }
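
    // Additional check grounded in the fallback arms of `model_dimensions`
    // and `model_max_tokens`: unrecognized model names default to 1536
    // dimensions and an 8191-token limit.
    #[test]
    fn test_unknown_model_defaults() {
        assert_eq!(
            OpenAIEmbeddingModel::model_dimensions("some-future-model"),
            1536
        );
        assert_eq!(
            OpenAIEmbeddingModel::model_max_tokens("some-future-model"),
            8191
        );
    }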

    #[test]
    fn test_model_creation() {
        let model = OpenAIEmbeddingModel::new("text-embedding-3-small", "test-key");
        assert_eq!(model.name(), "text-embedding-3-small");
        assert_eq!(model.dimensions(), 1536);
    }

    #[test]
    fn test_with_base_url() {
        let model =
            OpenAIEmbeddingModel::new("test", "key").with_base_url("https://custom.api.com");
        assert_eq!(model.base_url, "https://custom.api.com");
    }

    #[test]
    fn test_debug() {
        let model = OpenAIEmbeddingModel::new("test", "secret-key");
        let debug = format!("{:?}", model);
        assert!(debug.contains("test"));
        assert!(!debug.contains("secret")); // Key should not be in debug
    }
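
    // Sketches exercising the serde wire types above. The response JSON is
    // illustrative (shaped after the structs in this file), not a payload
    // captured from the API.
    #[test]
    fn test_request_skips_optional_fields() {
        let request = OpenAIEmbeddingRequest {
            model: "text-embedding-3-small".to_string(),
            input: vec!["hello".to_string()],
            dimensions: None,
            user: None,
        };
        let value = serde_json::to_value(&request).unwrap();
        // `dimensions` and `user` are omitted when `None`.
        assert!(value.get("dimensions").is_none());
        assert!(value.get("user").is_none());
    }

    #[test]
    fn test_response_deserialization() {
        let json = r#"{
            "data": [{"embedding": [0.1, 0.2], "index": 0}],
            "model": "text-embedding-3-small",
            "usage": {"prompt_tokens": 2, "total_tokens": 2}
        }"#;
        let resp: OpenAIEmbeddingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0].index, 0);
        assert_eq!(resp.usage.total_tokens, 2);
    }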
}