// swiftide_indexing/transformers/embed.rs
use std::{collections::VecDeque, sync::Arc};
use anyhow::bail;
use async_trait::async_trait;
use swiftide_core::{
    indexing::{IndexingStream, Node},
    BatchableTransformer, EmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
};
/// Batch transformer that attaches embedding vectors to indexing nodes.
///
/// Holds a shared embedding model plus two optional settings that are reported
/// back through the `BatchableTransformer` implementation.
#[derive(Clone)]
pub struct Embed {
    /// Model used to produce the embeddings; reference-counted, so clones share it.
    embed_model: Arc<dyn EmbeddingModel>,
    /// Value returned by `concurrency()`; `None` until set via `with_concurrency`.
    concurrency: Option<usize>,
    /// Value returned by `batch_size()`; `None` until set via `with_batch_size`.
    batch_size: Option<usize>,
}
/// Hand-written `Debug` that prints only the two optional settings.
///
/// The `embed_model` field is deliberately left out of the output.
impl std::fmt::Debug for Embed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            concurrency,
            batch_size,
            ..
        } = self;

        let mut builder = f.debug_struct("Embed");
        builder.field("concurrency", concurrency);
        builder.field("batch_size", batch_size);
        builder.finish()
    }
}
impl Embed {
    /// Creates an `Embed` transformer around the given embedding model.
    ///
    /// The model is moved behind an `Arc`; concurrency and batch size start unset.
    pub fn new(model: impl EmbeddingModel + 'static) -> Self {
        let embed_model: Arc<dyn EmbeddingModel> = Arc::new(model);
        Self {
            embed_model,
            concurrency: None,
            batch_size: None,
        }
    }

    /// Returns the transformer with its concurrency set to `concurrency`.
    ///
    /// The value is later reported by `BatchableTransformer::concurrency`.
    #[must_use]
    pub fn with_concurrency(self, concurrency: usize) -> Self {
        Self {
            concurrency: Some(concurrency),
            ..self
        }
    }

    /// Returns the transformer with its batch size set to `batch_size`.
    ///
    /// The value is later reported by `BatchableTransformer::batch_size`.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self {
            batch_size: Some(batch_size),
            ..self
        }
    }
}
// Empty impl: `Embed` overrides nothing and relies on whatever the trait provides by default.
impl WithBatchIndexingDefaults for Embed {}
// Empty impl: likewise, no items of `WithIndexingDefaults` are overridden.
impl WithIndexingDefaults for Embed {}
#[async_trait]
impl BatchableTransformer for Embed {
    /// Embeds every embeddable field of the given nodes with a single model call.
    ///
    /// The embeddable texts of all nodes are flattened into one batch and sent to the
    /// embedding model once. The returned vectors are then handed back to the nodes in
    /// the same order, each paired with the key of the field it was requested for.
    ///
    /// # Errors
    ///
    /// When the model call fails, the error is converted into the returned stream.
    /// A node is yielded as an error when the model returned fewer embeddings than
    /// were requested, instead of being passed on with its vectors silently unset.
    #[tracing::instrument(skip_all, name = "transformers.embed")]
    async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
        // One group of keys per node (in node order), so the flat list of embeddings
        // can be split back up per node afterwards.
        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
        let mut embeddables_data = Vec::new();

        for node in &mut nodes {
            let embeddables = node.as_embeddables();
            let mut embeddables_keys = Vec::with_capacity(embeddables.len());
            for (embeddable_key, embeddable_data) in embeddables {
                embeddables_keys.push(embeddable_key);
                embeddables_data.push(embeddable_data);
            }
            embeddings_keys_groups.push_back(embeddables_keys);
        }

        let mut embeddings = match self.embed_model.embed(embeddables_data).await {
            Ok(embeddings) => VecDeque::from(embeddings),
            Err(err) => return err.into(),
        };

        let nodes_iter = nodes.into_iter().map(move |mut node| {
            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
                bail!("Missing embedding data");
            };

            // Collecting `Option` items yields `None` as soon as the embeddings run out.
            node.vectors = embedding_keys
                .into_iter()
                .map(|embedded_field| {
                    embeddings
                        .pop_front()
                        .map(|embedding| (embedded_field, embedding))
                })
                .collect();

            // A shortfall would otherwise leave the node without vectors but still `Ok`.
            if node.vectors.is_none() {
                bail!("Embedding model returned fewer embeddings than requested");
            }

            Ok(node)
        });

        IndexingStream::iter(nodes_iter)
    }

    /// Returns the configured concurrency, or `None` if `with_concurrency` was never called.
    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }

    /// Returns the configured batch size, or `None` if `with_batch_size` was never called.
    fn batch_size(&self) -> Option<usize> {
        self.batch_size
    }
}
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
    use swiftide_core::{BatchableTransformer, MockEmbeddingModel};
    use super::Embed;
    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    /// One input node together with what the transformer is expected to do with it.
    #[derive(Clone)]
    struct TestData<'a> {
        /// Embed mode assigned to the node under test.
        pub embed_mode: EmbedMode,
        /// Chunk text of the node under test.
        pub chunk: &'a str,
        /// Metadata of the node under test.
        pub metadata: Metadata,
        /// Texts expected to be passed to the embedding model for this node, in order.
        pub expected_embedables: Vec<&'a str>,
        /// Vectors the mocked model returns for this node, keyed by the field they belong to.
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    // Verifies that `batch_transform` calls the model exactly once with the flattened
    // embeddables of all nodes and assigns the returned vectors to the matching nodes.
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        // Nodes fed into the transformer.
        let test_nodes: Vec<Node> = test_data
            .iter()
            .map(|data| Node {
                chunk: data.chunk.into(),
                metadata: data.metadata.clone(),
                embed_mode: data.embed_mode,
                ..Default::default()
            })
            .collect();

        // The same nodes with the expected vectors attached.
        let expected_nodes: Vec<Node> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.vectors = Some(test_data.expected_vectors.iter().cloned().collect());
                expected_node
            })
            .collect();

        // Flattened model input and output across all nodes, in node order.
        let expected_embeddables_batch = test_data
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();
        let expected_vectors_batch: Vec<Vec<f32>> = test_data
            .iter()
            .flat_map(|d| d.expected_vectors.iter().map(|(_, v)| v).cloned())
            .collect();

        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = Embed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");
            // `assert_eq!` rather than `debug_assert_eq!`: the latter is compiled out
            // when debug assertions are disabled, which would make this test vacuous.
            assert_eq!(ingested_node, expected_node);
        }

        assert!(
            stream.next().await.is_none(),
            "IngestionStream yields no items beyond expected_nodes"
        );
    }

    // Verifies that a failing model call surfaces as an error item on the stream.
    #[tokio::test]
    async fn test_returns_error_properly_if_embed_fails() {
        let test_nodes = vec![Node::new("chunk")];
        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .times(1)
            .returning(|_| Err(anyhow::anyhow!("error")));

        let embed = Embed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;

        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");
        assert_eq!(error.to_string(), "error");
    }
}