swiftide_integrations/openai/
embed.rs

1use async_trait::async_trait;
2
3use swiftide_core::{
4    EmbeddingModel, Embeddings,
5    chat_completion::{Usage, errors::LanguageModelError},
6};
7
8use super::GenericOpenAI;
9use crate::openai::openai_error_to_language_model_error;
10
11#[async_trait]
12impl<
13    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
14> EmbeddingModel for GenericOpenAI<C>
15{
16    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
17        let model = self
18            .default_options
19            .embed_model
20            .as_ref()
21            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;
22
23        let request = self
24            .embed_request_defaults()
25            .model(model)
26            .input(&input)
27            .build()
28            .map_err(LanguageModelError::permanent)?;
29
30        tracing::debug!(
31            num_chunks = input.len(),
32            model = &model,
33            "[Embed] Request to openai"
34        );
35        let response = self
36            .client
37            .embeddings()
38            .create(request)
39            .await
40            .map_err(openai_error_to_language_model_error)?;
41
42        #[cfg(feature = "metrics")]
43        {
44            swiftide_core::metrics::emit_usage(
45                model,
46                response.usage.prompt_tokens.into(),
47                0,
48                response.usage.total_tokens.into(),
49                self.metric_metadata.as_ref(),
50            );
51        }
52
53        if let Some(callback) = self.on_usage.as_ref() {
54            let usage = Usage {
55                prompt_tokens: response.usage.prompt_tokens,
56                completion_tokens: 0,
57                total_tokens: response.usage.total_tokens,
58            };
59
60            callback(&usage).await?;
61        }
62
63        let num_embeddings = response.data.len();
64        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");
65
66        // WARN: Naively assumes that the order is preserved. Might not always be the case.
67        Ok(response.data.into_iter().map(|d| d.embedding).collect())
68    }
69}