swiftide_integrations/openai/embed.rs

1use async_trait::async_trait;
2
3use swiftide_core::{EmbeddingModel, Embeddings, chat_completion::errors::LanguageModelError};
4
5use super::GenericOpenAI;
6use crate::openai::openai_error_to_language_model_error;
7
8#[async_trait]
9impl<
10    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
11> EmbeddingModel for GenericOpenAI<C>
12{
13    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
14        let model = self
15            .default_options
16            .embed_model
17            .as_ref()
18            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;
19
20        let request = self
21            .embed_request_defaults()
22            .model(model)
23            .input(&input)
24            .build()
25            .map_err(LanguageModelError::permanent)?;
26
27        tracing::debug!(
28            num_chunks = input.len(),
29            model = &model,
30            "[Embed] Request to openai"
31        );
32        let response = self
33            .client
34            .embeddings()
35            .create(request)
36            .await
37            .map_err(openai_error_to_language_model_error)?;
38
39        #[cfg(feature = "metrics")]
40        {
41            swiftide_core::metrics::emit_usage(
42                model,
43                response.usage.prompt_tokens.into(),
44                0,
45                response.usage.total_tokens.into(),
46                self.metric_metadata.as_ref(),
47            );
48        }
49
50        let num_embeddings = response.data.len();
51        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");
52
53        // WARN: Naively assumes that the order is preserved. Might not always be the case.
54        Ok(response.data.into_iter().map(|d| d.embedding).collect())
55    }
56}