1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use std::sync::Arc;

use swiftide_core::{
    indexing::EmbeddingModel,
    prelude::*,
    querying::{states, Query, TransformQuery},
};

/// Query transformer that attaches an embedding to a pending query.
///
/// The embedding model is held behind an `Arc`, so cloning an `Embed`
/// shares the same model instance instead of duplicating it.
#[derive(Debug, Clone)]
pub struct Embed {
    /// Model used to embed the query's current text.
    embed_model: Arc<dyn EmbeddingModel>,
}

impl Embed {
    /// Builds an [`Embed`] transformer around the given embedding client.
    ///
    /// The client is moved into an `Arc` and stored as an
    /// `EmbeddingModel` trait object, which is why it must be `'static`.
    pub fn from_client(client: impl EmbeddingModel + 'static) -> Self {
        // Erase the concrete client type up front so the struct literal stays trivial.
        let embed_model: Arc<dyn EmbeddingModel> = Arc::new(client);

        Self { embed_model }
    }
}

#[async_trait]
impl TransformQuery for Embed {
    /// Embeds the query's current text and stores the result on the query.
    ///
    /// The current query text is sent to the embedding model as a
    /// single-element batch. The last entry of the returned collection is
    /// written to `query.embedding`, and the query is handed back still in
    /// the `Pending` state.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the embedding model, and fails with
    /// "Failed to embed query" when the model returns no embeddings.
    #[tracing::instrument]
    async fn transform_query(
        &self,
        mut query: Query<states::Pending>,
    ) -> Result<Query<states::Pending>> {
        // The model takes a batch; wrap the single query text in one.
        let batch = vec![query.current().to_string()];
        let mut embeddings = self.embed_model.embed(batch).await?;

        match embeddings.pop() {
            Some(vector) => {
                query.embedding = Some(vector);
                Ok(query)
            }
            // An empty response leaves nothing to attach to the query.
            None => anyhow::bail!("Failed to embed query"),
        }
    }
}