swiftide_query/query_transformers/embed.rs

1use std::sync::Arc;
2
3use swiftide_core::{
4    indexing::EmbeddingModel,
5    prelude::*,
6    querying::{Query, TransformQuery, states},
7};
8
9#[derive(Debug, Clone)]
10pub struct Embed {
11    embed_model: Arc<dyn EmbeddingModel>,
12}
13
14impl Embed {
15    pub fn from_client(client: impl EmbeddingModel + 'static) -> Embed {
16        Embed {
17            embed_model: Arc::new(client),
18        }
19    }
20}
21
22#[async_trait]
23impl TransformQuery for Embed {
24    #[tracing::instrument(skip_all)]
25    async fn transform_query(
26        &self,
27        mut query: Query<states::Pending>,
28    ) -> Result<Query<states::Pending>> {
29        let Some(embedding) = self
30            .embed_model
31            .embed(vec![query.current().to_string()])
32            .await?
33            .pop()
34        else {
35            anyhow::bail!("Failed to embed query")
36        };
37
38        query.embedding = Some(embedding);
39
40        Ok(query)
41    }
42}