serdes_ai_embeddings/openai.rs

use crate::embedding::Embedding;
use crate::error::{EmbeddingError, EmbeddingResult};
use crate::model::{EmbedInput, EmbeddingModel, EmbeddingOutput, EmbeddingSettings};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

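/// Embedding model backed by the OpenAI `/embeddings` HTTP API.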
#[derive(Clone)]
pub struct OpenAIEmbeddingModel {
    model_name: String,
    client: Client,
    api_key: String,
    base_url: String,
    default_dimensions: usize,
}

impl OpenAIEmbeddingModel {
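    /// Creates a model for the given OpenAI embedding model name (for example
    /// `"text-embedding-3-small"`) and infers its default output dimensions.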
    pub fn new(model_name: impl Into<String>, api_key: impl Into<String>) -> Self {
        let name = model_name.into();
        let dimensions = Self::model_dimensions(&name);

        Self {
            model_name: name,
            client: Client::new(),
            api_key: api_key.into(),
            base_url: "https://api.openai.com/v1".to_string(),
            default_dimensions: dimensions,
        }
    }

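    /// Builds a model using the API key from the `OPENAI_API_KEY` environment variable.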
    pub fn from_env(model_name: impl Into<String>) -> EmbeddingResult<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| EmbeddingError::config("OPENAI_API_KEY not set"))?;
        Ok(Self::new(model_name, api_key))
    }

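    /// Overrides the API base URL (default `https://api.openai.com/v1`), e.g. to
    /// target an OpenAI-compatible endpoint.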
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

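    /// Replaces the default `reqwest::Client`, e.g. to configure timeouts or proxies.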
    pub fn with_client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

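    /// Default output dimensions for known OpenAI embedding models; unknown
    /// model names fall back to 1536.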
    fn model_dimensions(name: &str) -> usize {
        match name {
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" => 1536,
            _ => 1536,
        }
    }

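    /// Maximum input tokens per text; the listed models all share an
    /// 8191-token limit, which is also used as the fallback.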
    fn model_max_tokens(name: &str) -> usize {
        match name {
            "text-embedding-3-small" => 8191,
            "text-embedding-3-large" => 8191,
            "text-embedding-ada-002" => 8191,
            _ => 8191,
        }
    }
}

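// Manual `Debug` impl so the API key never appears in debug output.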
impl std::fmt::Debug for OpenAIEmbeddingModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbeddingModel")
            .field("model_name", &self.model_name)
            .field("base_url", &self.base_url)
            .finish()
    }
}

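/// Request body for `POST {base_url}/embeddings`.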
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}

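/// Successful response payload from the embeddings endpoint.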
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    model: String,
    usage: OpenAIUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIUsage {
    prompt_tokens: u64,
    total_tokens: u64,
}

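/// Error payload returned by the OpenAI API on non-2xx responses.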
#[derive(Debug, Deserialize)]
struct OpenAIErrorResponse {
    error: OpenAIError,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIError {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
    code: Option<String>,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    fn name(&self) -> &str {
        &self.model_name
    }

    fn dimensions(&self) -> usize {
        self.default_dimensions
    }

    fn max_tokens(&self) -> usize {
        Self::model_max_tokens(&self.model_name)
    }

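    /// Sends the input texts to the embeddings endpoint and maps the response
    /// back into `Embedding` values in the same order as the inputs.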
    async fn embed(
        &self,
        input: EmbedInput,
        settings: &EmbeddingSettings,
    ) -> EmbeddingResult<EmbeddingOutput> {
        let texts = input.into_texts();

        let request = OpenAIEmbeddingRequest {
            model: self.model_name.clone(),
            input: texts.clone(),
            dimensions: settings.dimensions,
            user: settings.user.clone(),
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::Api(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();

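            // Prefer the structured OpenAI error message when the body parses;
            // otherwise fall back to the raw HTTP status and body below.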
            if let Ok(error_resp) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
                if status.as_u16() == 429 {
                    return Err(EmbeddingError::RateLimited { retry_after: None });
                }
                return Err(EmbeddingError::Api(error_resp.error.message));
            }

            return Err(EmbeddingError::Http {
                status: status.as_u16(),
                body,
            });
        }

        let resp: OpenAIEmbeddingResponse = response
            .json()
            .await
            .map_err(|e| EmbeddingError::Api(e.to_string()))?;

        let mut data = resp.data;
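        // Sort by `index` so the embeddings line up with the input texts even if
        // the API returns the `data` array out of order.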
        data.sort_by_key(|d| d.index);

        let embeddings: Vec<Embedding> = data
            .into_iter()
            .map(|d| {
                let text = texts.get(d.index).cloned();
                let mut emb = Embedding::new(d.embedding)
                    .with_model(&resp.model)
                    .with_index(d.index);
                if let Some(t) = text {
                    emb = emb.with_text(t);
                }
                emb
            })
            .collect();

        Ok(EmbeddingOutput::new(embeddings, &resp.model).with_tokens(resp.usage.total_tokens))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        assert_eq!(
            OpenAIEmbeddingModel::model_dimensions("text-embedding-3-small"),
            1536
        );
        assert_eq!(
            OpenAIEmbeddingModel::model_dimensions("text-embedding-3-large"),
            3072
        );
    }

    #[test]
    fn test_model_creation() {
        let model = OpenAIEmbeddingModel::new("text-embedding-3-small", "test-key");
        assert_eq!(model.name(), "text-embedding-3-small");
        assert_eq!(model.dimensions(), 1536);
    }

    #[test]
    fn test_with_base_url() {
        let model =
            OpenAIEmbeddingModel::new("test", "key").with_base_url("https://custom.api.com");
        assert_eq!(model.base_url, "https://custom.api.com");
    }

    #[test]
    fn test_debug() {
        let model = OpenAIEmbeddingModel::new("test", "secret-key");
        let debug = format!("{:?}", model);
        assert!(debug.contains("test"));
        assert!(!debug.contains("secret"));
    }
}