swiftide_integrations/openai/embed.rs

1use async_openai::types::CreateEmbeddingRequestArgs;
2use async_trait::async_trait;
3
4use swiftide_core::{chat_completion::errors::LanguageModelError, EmbeddingModel, Embeddings};
5
6use super::GenericOpenAI;
7use crate::openai::openai_error_to_language_model_error;
8
9#[async_trait]
10impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
11    EmbeddingModel for GenericOpenAI<C>
12{
13    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
14        let model = self
15            .default_options
16            .embed_model
17            .as_ref()
18            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;
19
20        let request = CreateEmbeddingRequestArgs::default()
21            .model(model)
22            .input(&input)
23            .build()
24            .map_err(LanguageModelError::permanent)?;
25
26        tracing::debug!(
27            num_chunks = input.len(),
28            model = &model,
29            "[Embed] Request to openai"
30        );
31        let response = self
32            .client
33            .embeddings()
34            .create(request)
35            .await
36            .map_err(openai_error_to_language_model_error)?;
37
38        let num_embeddings = response.data.len();
39        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");
40
41        // WARN: Naively assumes that the order is preserved. Might not always be the case.
42        Ok(response.data.into_iter().map(|d| d.embedding).collect())
43    }
44}