spec_ai_core/
embeddings.rs

1use anyhow::{anyhow, Context, Result};
2use async_openai::{
3    config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client as OpenAIClient,
4};
5use async_trait::async_trait;
6use std::sync::Arc;
7
8/// Trait that describes an embeddings-capable service.
9#[async_trait]
10pub trait EmbeddingsService: Send + Sync + 'static {
11    /// Generate embeddings for the provided inputs using the given model name.
12    async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
13}
14
15/// Client that wraps an embeddings service and keeps track of the model name.
16#[derive(Clone)]
17pub struct EmbeddingsClient {
18    model: String,
19    service: Arc<dyn EmbeddingsService>,
20}
21
22impl EmbeddingsClient {
23    /// Create a client that uses the default OpenAI configuration (OPENAI_API_KEY).
24    pub fn new(model: impl Into<String>) -> Self {
25        Self::with_service(
26            model,
27            Arc::new(OpenAIEmbeddingsService::new()) as Arc<dyn EmbeddingsService>,
28        )
29    }
30
31    /// Create a client that uses the provided API key.
32    pub fn with_api_key(model: impl Into<String>, api_key: impl Into<String>) -> Self {
33        let service = OpenAIEmbeddingsService::with_api_key(api_key);
34        Self::with_service(model, Arc::new(service))
35    }
36
37    /// Create a client that uses the provided OpenAI configuration.
38    pub fn with_config(model: impl Into<String>, config: OpenAIConfig) -> Self {
39        let service = OpenAIEmbeddingsService::with_config(config);
40        Self::with_service(model, Arc::new(service))
41    }
42
43    /// Create a client around a custom embeddings service implementation.
44    pub fn with_service(model: impl Into<String>, service: Arc<dyn EmbeddingsService>) -> Self {
45        Self {
46            model: model.into(),
47            service,
48        }
49    }
50
51    /// Ask the underlying service for embeddings for a batch of inputs.
52    pub async fn embed_batch<T>(&self, inputs: &[T]) -> Result<Vec<Vec<f32>>>
53    where
54        T: AsRef<str>,
55    {
56        if inputs.is_empty() {
57            return Ok(Vec::new());
58        }
59
60        let sanitized_inputs = inputs
61            .iter()
62            .map(|input| sanitize_embedding_input(input.as_ref()))
63            .collect::<Vec<_>>();
64
65        self.service
66            .create_embeddings(&self.model, sanitized_inputs)
67            .await
68    }
69
70    /// Ask the underlying service for an embedding for a single input.
71    pub async fn embed(&self, input: &str) -> Result<Vec<f32>> {
72        let inputs = [input];
73        let mut embeddings = self.embed_batch(&inputs).await?;
74        Ok(embeddings.pop().unwrap_or_default())
75    }
76}
77
/// Escape and bound a piece of text before it is sent to an embeddings service.
///
/// * Every backslash becomes `\\`, every carriage return becomes the two
///   characters `\r`, and every line feed becomes the two characters `\n`.
/// * If the escaped text is longer than 4096 bytes it is cut to at most 4096
///   bytes and the literal marker `\n[truncated]` is appended.
///
/// The cut never lands inside a multi-byte UTF-8 character: when byte 4096 is
/// not a character boundary the cut moves back to the closest preceding one, so
/// the function cannot panic on non-ASCII input.
fn sanitize_embedding_input(input: &str) -> String {
    const MAX_LEN: usize = 4096;
    const TRUNCATION_MARKER: &str = "\\n[truncated]";

    // Single pass instead of three chained `replace` calls: the escapes are
    // independent of each other, so the result is the same with one allocation.
    let mut processed = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '\\' => processed.push_str("\\\\"),
            '\r' => processed.push_str("\\r"),
            '\n' => processed.push_str("\\n"),
            other => processed.push(other),
        }
    }

    if processed.len() > MAX_LEN {
        // `String::truncate` panics when the new length is not a char boundary,
        // so walk back to one first. Index 0 is always a boundary, which
        // guarantees termination.
        let mut cut = MAX_LEN;
        while !processed.is_char_boundary(cut) {
            cut -= 1;
        }
        processed.truncate(cut);
        processed.push_str(TRUNCATION_MARKER);
    }

    processed
}
92
#[cfg(test)]
mod embedding_sanitizer_tests {
    use super::sanitize_embedding_input;

    /// Suffix the sanitizer appends when it shortens an input.
    const MARKER: &str = "\\n[truncated]";

    /// Backslashes, carriage returns and line feeds come out as literal escapes.
    #[test]
    fn sanitizes_newlines_and_backslashes() {
        let expected = "line1\\nline2\\r\\npath\\\\to\\\\file";
        assert_eq!(
            sanitize_embedding_input("line1\nline2\r\npath\\to\\file"),
            expected
        );
    }

    /// Inputs over the 4096-byte limit are cut and tagged with the marker.
    #[test]
    fn truncates_long_payloads() {
        let oversized = "a".repeat(5000);
        let output = sanitize_embedding_input(&oversized);
        assert!(output.ends_with(MARKER));
        assert!(output.len() <= 4096 + MARKER.len());
    }
}
112
113/// Default service implementation that uses the async-openai client.
114#[derive(Clone)]
115pub struct OpenAIEmbeddingsService {
116    client: OpenAIClient<OpenAIConfig>,
117}
118
119impl Default for OpenAIEmbeddingsService {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl OpenAIEmbeddingsService {
126    /// Create a service with the default OpenAI configuration.
127    pub fn new() -> Self {
128        Self {
129            client: OpenAIClient::new(),
130        }
131    }
132
133    /// Create a service backed by a specific API key.
134    pub fn with_api_key(api_key: impl Into<String>) -> Self {
135        let config = OpenAIConfig::new().with_api_key(api_key);
136        Self::with_config(config)
137    }
138
139    /// Create a service with a custom OpenAI configuration.
140    pub fn with_config(config: OpenAIConfig) -> Self {
141        Self {
142            client: OpenAIClient::with_config(config),
143        }
144    }
145}
146
147#[async_trait]
148impl EmbeddingsService for OpenAIEmbeddingsService {
149    async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>> {
150        if inputs.is_empty() {
151            return Ok(Vec::new());
152        }
153
154        let request = CreateEmbeddingRequestArgs::default()
155            .model(model)
156            .input(inputs)
157            .build()
158            .context("Failed to build embedding request")?;
159
160        let response = self
161            .client
162            .embeddings()
163            .create(request)
164            .await
165            .context("OpenAI embeddings request failed")?;
166
167        let embeddings = response
168            .data
169            .into_iter()
170            .map(|item| item.embedding)
171            .collect::<Vec<_>>();
172
173        if embeddings.is_empty() {
174            Err(anyhow!("OpenAI embeddings response was empty"))
175        } else {
176            Ok(embeddings)
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Canned service: either fails with "boom" or returns a fixed set of
    /// embeddings, regardless of the model and inputs it is given.
    #[derive(Clone)]
    struct DummyService {
        embeddings: Vec<Vec<f32>>,
        fail: bool,
    }

    impl DummyService {
        /// Service that answers every call with exactly one embedding.
        fn ok_single(embedding: Vec<f32>) -> Self {
            Self::ok_batch(vec![embedding])
        }

        /// Service that answers every call with the given embeddings.
        fn ok_batch(embeddings: Vec<Vec<f32>>) -> Self {
            Self {
                embeddings,
                fail: false,
            }
        }

        /// Service that fails every call.
        fn err() -> Self {
            Self {
                embeddings: Vec::new(),
                fail: true,
            }
        }
    }

    #[async_trait]
    impl EmbeddingsService for DummyService {
        async fn create_embeddings(
            &self,
            _model: &str,
            _inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            if self.fail {
                Err(anyhow!("boom"))
            } else {
                // Cloning an empty list already yields an empty list, so no
                // separate empty branch is needed.
                Ok(self.embeddings.clone())
            }
        }
    }

    #[tokio::test]
    async fn embed_returns_the_service_embedding() {
        let expected = vec![0.1, 0.2];
        let client = EmbeddingsClient::with_service(
            "model",
            Arc::new(DummyService::ok_single(expected.clone())),
        );

        let actual = client.embed("input").await.unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn embed_propagates_errors() {
        let client = EmbeddingsClient::with_service("model", Arc::new(DummyService::err()));

        assert!(client.embed("input").await.is_err());
    }

    #[tokio::test]
    async fn embed_batch_returns_all_embeddings() {
        let canned = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let client = EmbeddingsClient::with_service(
            "model",
            Arc::new(DummyService::ok_batch(canned.clone())),
        );

        let result = client.embed_batch(&["first", "second"]).await.unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], canned[0]);
        assert_eq!(result[1], canned[1]);
    }
}
269}