swiftide_integrations/openai/embed.rs

use async_trait::async_trait;

use swiftide_core::{
    EmbeddingModel, Embeddings,
    chat_completion::{Usage, errors::LanguageModelError},
};

use super::GenericOpenAI;
use crate::openai::openai_error_to_language_model_error;
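
/// Implements `EmbeddingModel` for `GenericOpenAI`, delegating to the
/// embeddings endpoint of the wrapped `async_openai` client.
///
/// A minimal usage sketch (not compiled), assuming a hypothetical `client:
/// GenericOpenAI<_>` that already has an `embed_model` set in its default
/// options and `EmbeddingModel` in scope:
///
/// ```ignore
/// let embeddings = client.embed(vec!["hello world".to_string()]).await?;
/// ```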
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> EmbeddingModel for GenericOpenAI<C>
{
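    /// Embeds a batch of input strings with the configured `embed_model` and
    /// returns one embedding vector per item in the API response.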
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
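        // The embedding model comes from the default options; a missing model is
        // a configuration problem and is reported as a permanent error.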
        let model = self
            .default_options
            .embed_model
            .as_ref()
            .ok_or(LanguageModelError::PermanentError("Model not set".into()))?;

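        // Build the request from the client's embedding defaults, the configured
        // model, and the full batch of inputs. Builder errors are permanent.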
        let request = self
            .embed_request_defaults()
            .model(model)
            .input(&input)
            .build()
            .map_err(LanguageModelError::permanent)?;

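        // Log the outgoing batch and send it to the embeddings endpoint of the
        // underlying client, mapping API errors to `LanguageModelError`.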
        tracing::debug!(
            num_chunks = input.len(),
            model = &model,
            "[Embed] Request to openai"
        );
        let response = self
            .client
            .embeddings()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

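        // With the `metrics` feature enabled, emit token usage for this request.
        // Embedding calls only consume prompt tokens, so completion tokens are 0.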
        #[cfg(feature = "metrics")]
        {
            swiftide_core::metrics::emit_usage(
                model,
                response.usage.prompt_tokens.into(),
                0,
                response.usage.total_tokens.into(),
                self.metric_metadata.as_ref(),
            );
        }

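        // Report token usage to the optional user-provided callback.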
        if let Some(callback) = self.on_usage.as_ref() {
            let usage = Usage {
                prompt_tokens: response.usage.prompt_tokens,
                completion_tokens: 0,
                total_tokens: response.usage.total_tokens,
            };

            callback(&usage).await?;
        }

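        // Log the response size and return the embedding vectors in the order
        // they appear in the response.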
        let num_embeddings = response.data.len();
        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");

        Ok(response.data.into_iter().map(|d| d.embedding).collect())
    }
}