swiftide_integrations/openai/embed.rs

1use async_trait::async_trait;
2
3use swiftide_core::{
4    EmbeddingModel, Embeddings,
5    chat_completion::{Usage, errors::LanguageModelError},
6};
7
8use super::GenericOpenAI;
9use crate::openai::openai_error_to_language_model_error;
10
11#[async_trait]
12impl<
13    C: async_openai::config::Config
14        + std::default::Default
15        + Sync
16        + Send
17        + std::fmt::Debug
18        + Clone
19        + 'static,
20> EmbeddingModel for GenericOpenAI<C>
21{
22    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
23        let model = self
24            .default_options
25            .embed_model
26            .as_ref()
27            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;
28
29        let request = self
30            .embed_request_defaults()
31            .model(model)
32            .input(&input)
33            .build()
34            .map_err(LanguageModelError::permanent)?;
35
36        tracing::debug!(
37            num_chunks = input.len(),
38            model = &model,
39            "[Embed] Request to openai"
40        );
41        let response = self
42            .client
43            .embeddings()
44            .create(request.clone())
45            .await
46            .map_err(openai_error_to_language_model_error)?;
47
48        let usage = Usage {
49            prompt_tokens: response.usage.prompt_tokens,
50            completion_tokens: 0,
51            total_tokens: response.usage.total_tokens,
52        };
53
54        self.track_completion(model, Some(&usage), Some(&request), Some(&response));
55
56        let num_embeddings = response.data.len();
57        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");
58
59        // WARN: Naively assumes that the order is preserved. Might not always be the case.
60        Ok(response.data.into_iter().map(|d| d.embedding).collect())
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::OpenAI;
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    /// Without a default embed model configured, embedding must fail with a
    /// permanent error rather than hitting the network.
    #[test_log::test(tokio::test)]
    async fn test_embed_returns_error_when_model_missing() {
        let client = OpenAI::builder().build().unwrap();
        let error = client.embed(vec!["text".into()]).await.unwrap_err();
        assert!(matches!(error, LanguageModelError::PermanentError(_)));
    }

    /// Happy path: the outgoing request carries the configured model and an
    /// array input, and the mocked embedding comes back intact.
    #[allow(clippy::items_after_statements)]
    #[test_log::test(tokio::test)]
    async fn test_embed_success() {
        let server = MockServer::start().await;

        // Canned OpenAI embeddings payload containing a single vector.
        let canned_response = json!({
            "data": [{
                "embedding": [0.1, 0.2],
                "index": 0,
                "object": "embedding"
            }],
            "model": "text-embedding-3-small",
            "object": "list",
            "usage": {"prompt_tokens": 5, "total_tokens": 5}
        });

        // Responder that validates the outgoing request body before replying
        // with the canned payload.
        struct ValidateEmbeddingRequest(serde_json::Value);

        impl Respond for ValidateEmbeddingRequest {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let payload: serde_json::Value =
                    serde_json::from_slice(&request.body).unwrap();
                assert_eq!(payload["model"], "text-embedding-3-small");
                assert!(payload["input"].is_array());
                ResponseTemplate::new(200).set_body_json(self.0.clone())
            }
        }

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ValidateEmbeddingRequest(canned_response))
            .mount(&server)
            .await;

        // Point the async-openai client at the mock server.
        let config = async_openai::config::OpenAIConfig::new().with_api_base(server.uri());
        let api_client = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(api_client)
            .default_embed_model("text-embedding-3-small")
            .build()
            .unwrap();

        let embeddings = openai
            .embed(vec!["Hello".into(), "World".into()])
            .await
            .unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0], vec![0.1, 0.2]);
    }
}
131}