swiftide_query/query_transformers/sparse_embed.rs

1use std::sync::Arc;
2
3use swiftide_core::{
4    SparseEmbeddingModel,
5    prelude::*,
6    querying::{Query, TransformQuery, states},
7};
8
9/// Embed a query with a sparse embedding.
10#[derive(Debug, Clone)]
11pub struct SparseEmbed {
12    embed_model: Arc<dyn SparseEmbeddingModel>,
13}
14
15impl SparseEmbed {
16    pub fn from_client(client: impl SparseEmbeddingModel + 'static) -> SparseEmbed {
17        SparseEmbed {
18            embed_model: Arc::new(client),
19        }
20    }
21}
22
23#[async_trait]
24impl TransformQuery for SparseEmbed {
25    #[tracing::instrument(skip_all)]
26    async fn transform_query(
27        &self,
28        mut query: Query<states::Pending>,
29    ) -> Result<Query<states::Pending>> {
30        let Some(embedding) = self
31            .embed_model
32            .sparse_embed(vec![query.current().to_string()])
33            .await?
34            .pop()
35        else {
36            anyhow::bail!("Failed to embed query")
37        };
38
39        query.sparse_embedding = Some(embedding);
40
41        Ok(query)
42    }
43}