use async_trait::async_trait;

use swiftide_core::{
    EmbeddingModel, Embeddings,
    chat_completion::{Usage, errors::LanguageModelError},
};

use super::GenericOpenAI;
use crate::openai::openai_error_to_language_model_error;

11#[async_trait]
12impl<
13 C: async_openai::config::Config
14 + std::default::Default
15 + Sync
16 + Send
17 + std::fmt::Debug
18 + Clone
19 + 'static,
20> EmbeddingModel for GenericOpenAI<C>
21{
22 async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
23 let model = self
24 .default_options
25 .embed_model
26 .as_ref()
27 .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;
28
29 let request = self
30 .embed_request_defaults()
31 .model(model)
32 .input(&input)
33 .build()
34 .map_err(LanguageModelError::permanent)?;
35
36 tracing::debug!(
37 num_chunks = input.len(),
38 model = &model,
39 "[Embed] Request to openai"
40 );
41 let response = self
42 .client
43 .embeddings()
44 .create(request.clone())
45 .await
46 .map_err(openai_error_to_language_model_error)?;
47
48 let usage = Usage {
49 prompt_tokens: response.usage.prompt_tokens,
50 completion_tokens: 0,
51 total_tokens: response.usage.total_tokens,
52 };
53
54 self.track_completion(model, Some(&usage), Some(&request), Some(&response));
55
56 let num_embeddings = response.data.len();
57 tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");
58
59 Ok(response.data.into_iter().map(|d| d.embedding).collect())
61 }
62}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::OpenAI;
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    /// Calling `embed` without a configured embed model must fail permanently.
    #[test_log::test(tokio::test)]
    async fn test_embed_returns_error_when_model_missing() {
        let openai = OpenAI::builder().build().unwrap();
        let err = openai.embed(vec!["text".into()]).await.unwrap_err();
        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }

    /// Responder that validates the outgoing embeddings payload before
    /// answering with the canned JSON body it wraps.
    struct EmbeddingRequestValidator(serde_json::Value);

    impl Respond for EmbeddingRequestValidator {
        fn respond(&self, request: &Request) -> ResponseTemplate {
            let payload: serde_json::Value = serde_json::from_slice(&request.body).unwrap();
            assert_eq!(payload["model"], "text-embedding-3-small");
            assert!(payload["input"].is_array());
            ResponseTemplate::new(200).set_body_json(self.0.clone())
        }
    }

    /// Happy path: build a client against a wiremock server and verify the
    /// embeddings are decoded from the mocked response.
    #[test_log::test(tokio::test)]
    async fn test_embed_success() {
        let server = MockServer::start().await;

        let canned_response = json!({
            "data": [{
                "embedding": [0.1, 0.2],
                "index": 0,
                "object": "embedding"
            }],
            "model": "text-embedding-3-small",
            "object": "list",
            "usage": {"prompt_tokens": 5, "total_tokens": 5}
        });

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(EmbeddingRequestValidator(canned_response))
            .mount(&server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(server.uri());
        let client = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_embed_model("text-embedding-3-small")
            .build()
            .unwrap();

        let embeddings = openai
            .embed(vec!["Hello".into(), "World".into()])
            .await
            .unwrap();

        // The mock returns a single embedding regardless of input size.
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0], vec![0.1, 0.2]);
    }
}