swiftide_query/query_transformers/
sparse_embed.rs1use std::sync::Arc;
2
3use swiftide_core::{
4 SparseEmbeddingModel,
5 prelude::*,
6 querying::{Query, TransformQuery, states},
7};
8
9#[derive(Debug, Clone)]
11pub struct SparseEmbed {
12 embed_model: Arc<dyn SparseEmbeddingModel>,
13}
14
15impl SparseEmbed {
16 pub fn from_client(client: impl SparseEmbeddingModel + 'static) -> SparseEmbed {
17 SparseEmbed {
18 embed_model: Arc::new(client),
19 }
20 }
21}
22
23#[async_trait]
24impl TransformQuery for SparseEmbed {
25 #[tracing::instrument(skip_all)]
26 async fn transform_query(
27 &self,
28 mut query: Query<states::Pending>,
29 ) -> Result<Query<states::Pending>> {
30 let Some(embedding) = self
31 .embed_model
32 .sparse_embed(vec![query.current().to_string()])
33 .await?
34 .pop()
35 else {
36 anyhow::bail!("Failed to embed query")
37 };
38
39 query.sparse_embedding = Some(embedding);
40
41 Ok(query)
42 }
43}