spec_ai_core/
embeddings.rs1use anyhow::{anyhow, Context, Result};
2use async_openai::{
3 config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client as OpenAIClient,
4};
5use async_trait::async_trait;
6use std::sync::Arc;
7
8#[async_trait]
10pub trait EmbeddingsService: Send + Sync + 'static {
11 async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
13}
14
15#[derive(Clone)]
17pub struct EmbeddingsClient {
18 model: String,
19 service: Arc<dyn EmbeddingsService>,
20}
21
22impl EmbeddingsClient {
23 pub fn new(model: impl Into<String>) -> Self {
25 Self::with_service(
26 model,
27 Arc::new(OpenAIEmbeddingsService::new()) as Arc<dyn EmbeddingsService>,
28 )
29 }
30
31 pub fn with_api_key(model: impl Into<String>, api_key: impl Into<String>) -> Self {
33 let service = OpenAIEmbeddingsService::with_api_key(api_key);
34 Self::with_service(model, Arc::new(service))
35 }
36
37 pub fn with_config(model: impl Into<String>, config: OpenAIConfig) -> Self {
39 let service = OpenAIEmbeddingsService::with_config(config);
40 Self::with_service(model, Arc::new(service))
41 }
42
43 pub fn with_service(model: impl Into<String>, service: Arc<dyn EmbeddingsService>) -> Self {
45 Self {
46 model: model.into(),
47 service,
48 }
49 }
50
51 pub async fn embed_batch<T>(&self, inputs: &[T]) -> Result<Vec<Vec<f32>>>
53 where
54 T: AsRef<str>,
55 {
56 if inputs.is_empty() {
57 return Ok(Vec::new());
58 }
59
60 let sanitized_inputs = inputs
61 .iter()
62 .map(|input| sanitize_embedding_input(input.as_ref()))
63 .collect::<Vec<_>>();
64
65 self.service
66 .create_embeddings(&self.model, sanitized_inputs)
67 .await
68 }
69
70 pub async fn embed(&self, input: &str) -> Result<Vec<f32>> {
72 let inputs = [input];
73 let mut embeddings = self.embed_batch(&inputs).await?;
74 Ok(embeddings.pop().unwrap_or_default())
75 }
76}
77
/// Escapes line breaks and backslashes in `input`, then caps the result at a fixed
/// byte budget.
///
/// `\` becomes `\\`, a carriage return becomes the two characters `\r`, and a line
/// feed becomes the two characters `\n`; every other character is copied unchanged.
/// If the escaped text would exceed `MAX_LEN` (4096) bytes, output stops after the
/// last whole character or escape sequence that fits and the literal marker
/// `\n[truncated]` is appended. The result is therefore at most
/// `MAX_LEN + "\\n[truncated]".len()` bytes and always valid UTF-8.
fn sanitize_embedding_input(input: &str) -> String {
    const MAX_LEN: usize = 4096;
    const TRUNCATION_MARKER: &str = "\\n[truncated]";

    let mut processed =
        String::with_capacity(input.len().min(MAX_LEN) + TRUNCATION_MARKER.len());
    // Scratch space for UTF-8 encoding a single char (at most 4 bytes).
    let mut char_buf = [0u8; 4];

    for ch in input.chars() {
        let piece: &str = match ch {
            '\\' => "\\\\",
            '\r' => "\\r",
            '\n' => "\\n",
            other => &*other.encode_utf8(&mut char_buf),
        };
        // Cut only between whole pieces: truncating at a raw byte offset would panic
        // inside a multi-byte UTF-8 char, or leave half of an escape sequence behind.
        if processed.len() + piece.len() > MAX_LEN {
            processed.push_str(TRUNCATION_MARKER);
            return processed;
        }
        processed.push_str(piece);
    }

    processed
}

#[cfg(test)]
mod embedding_sanitizer_tests {
    use super::sanitize_embedding_input;

    #[test]
    fn sanitizes_newlines_and_backslashes() {
        let raw = "line1\nline2\r\npath\\to\\file";
        let sanitized = sanitize_embedding_input(raw);
        assert_eq!(sanitized, "line1\\nline2\\r\\npath\\\\to\\\\file");
    }

    #[test]
    fn truncates_long_payloads() {
        let raw = "a".repeat(5000);
        let sanitized = sanitize_embedding_input(&raw);
        assert!(sanitized.ends_with("\\n[truncated]"));
        assert!(sanitized.len() <= 4096 + "\\n[truncated]".len());
    }

    #[test]
    fn keeps_payload_exactly_at_limit() {
        let raw = "a".repeat(4096);
        assert_eq!(sanitize_embedding_input(&raw), raw);
    }

    #[test]
    fn truncation_respects_multibyte_char_boundaries() {
        // '€' is 3 bytes and 4096 is not a multiple of 3, so a raw byte cut would panic.
        let raw = "€".repeat(2000);
        let sanitized = sanitize_embedding_input(&raw);
        assert_eq!(sanitized, format!("{}\\n[truncated]", "€".repeat(1365)));
    }

    #[test]
    fn truncation_does_not_split_escape_sequences() {
        let raw = format!("{}\n{}", "a".repeat(4095), "tail");
        let sanitized = sanitize_embedding_input(&raw);
        assert_eq!(sanitized, format!("{}\\n[truncated]", "a".repeat(4095)));
    }
}
112
113#[derive(Clone)]
115pub struct OpenAIEmbeddingsService {
116 client: OpenAIClient<OpenAIConfig>,
117}
118
119impl Default for OpenAIEmbeddingsService {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl OpenAIEmbeddingsService {
126 pub fn new() -> Self {
128 Self {
129 client: OpenAIClient::new(),
130 }
131 }
132
133 pub fn with_api_key(api_key: impl Into<String>) -> Self {
135 let config = OpenAIConfig::new().with_api_key(api_key);
136 Self::with_config(config)
137 }
138
139 pub fn with_config(config: OpenAIConfig) -> Self {
141 Self {
142 client: OpenAIClient::with_config(config),
143 }
144 }
145}
146
147#[async_trait]
148impl EmbeddingsService for OpenAIEmbeddingsService {
149 async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>> {
150 if inputs.is_empty() {
151 return Ok(Vec::new());
152 }
153
154 let request = CreateEmbeddingRequestArgs::default()
155 .model(model)
156 .input(inputs)
157 .build()
158 .context("Failed to build embedding request")?;
159
160 let response = self
161 .client
162 .embeddings()
163 .create(request)
164 .await
165 .context("OpenAI embeddings request failed")?;
166
167 let embeddings = response
168 .data
169 .into_iter()
170 .map(|item| item.embedding)
171 .collect::<Vec<_>>();
172
173 if embeddings.is_empty() {
174 Err(anyhow!("OpenAI embeddings response was empty"))
175 } else {
176 Ok(embeddings)
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Test double: either replays a canned list of embeddings or always fails.
    #[derive(Clone)]
    enum DummyService {
        Replay(Vec<Vec<f32>>),
        Fail,
    }

    impl DummyService {
        fn ok_single(embedding: Vec<f32>) -> Self {
            Self::ok_batch(vec![embedding])
        }

        fn ok_batch(embeddings: Vec<Vec<f32>>) -> Self {
            Self::Replay(embeddings)
        }

        fn err() -> Self {
            Self::Fail
        }
    }

    #[async_trait]
    impl EmbeddingsService for DummyService {
        async fn create_embeddings(
            &self,
            _model: &str,
            _inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            match self {
                Self::Fail => Err(anyhow!("boom")),
                Self::Replay(canned) => Ok(canned.clone()),
            }
        }
    }

    #[tokio::test]
    async fn embed_returns_the_service_embedding() {
        let expected = vec![0.1, 0.2];
        let client = EmbeddingsClient::with_service(
            "model",
            Arc::new(DummyService::ok_single(expected.clone())),
        );

        let actual = client.embed("input").await.unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn embed_propagates_errors() {
        let client = EmbeddingsClient::with_service("model", Arc::new(DummyService::err()));

        let outcome = client.embed("input").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn embed_batch_returns_all_embeddings() {
        let canned = DummyService::ok_batch(vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
        let client = EmbeddingsClient::with_service("model", Arc::new(canned));

        let vectors = client.embed_batch(&["first", "second"]).await.unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![0.1, 0.2]);
        assert_eq!(vectors[1], vec![0.3, 0.4]);
    }
}