swiftide_query/query_transformers/
embed.rs1use std::sync::Arc;
2
3use swiftide_core::{
4 indexing::EmbeddingModel,
5 prelude::*,
6 querying::{states, Query, TransformQuery},
7};
8
9#[derive(Debug, Clone)]
10pub struct Embed {
11 embed_model: Arc<dyn EmbeddingModel>,
12}
13
14impl Embed {
15 pub fn from_client(client: impl EmbeddingModel + 'static) -> Embed {
16 Embed {
17 embed_model: Arc::new(client),
18 }
19 }
20}
21
22#[async_trait]
23impl TransformQuery for Embed {
24 #[tracing::instrument(skip_all)]
25 async fn transform_query(
26 &self,
27 mut query: Query<states::Pending>,
28 ) -> Result<Query<states::Pending>> {
29 let Some(embedding) = self
30 .embed_model
31 .embed(vec![query.current().to_string()])
32 .await?
33 .pop()
34 else {
35 anyhow::bail!("Failed to embed query")
36 };
37
38 query.embedding = Some(embedding);
39
40 Ok(query)
41 }
42}