swiftide_integrations/treesitter/
metadata_qa_code.rs

//! Generates questions and answers for code chunks and adds them to each node's metadata.
2
3use anyhow::Result;
4use async_trait::async_trait;
5use swiftide_core::{Transformer, indexing::TextNode};
6
7/// `MetadataQACode` is responsible for generating questions and answers based on code chunks.
8/// This struct integrates with the indexing pipeline to enhance the metadata of each code chunk
9/// by adding relevant questions and answers.
10#[swiftide_macros::indexing_transformer(
11    metadata_field_name = "Questions and Answers (code)",
12    default_prompt_file = "prompts/metadata_qa_code.prompt.md"
13)]
14pub struct MetadataQACode {
15    #[builder(default = "5")]
16    num_questions: usize,
17}
18
19#[async_trait]
20impl Transformer for MetadataQACode {
21    type Input = String;
22    type Output = String;
23    /// Asynchronously transforms a `TextNode` by generating questions and answers for its code
24    /// chunk.
25    ///
26    /// This method uses the `SimplePrompt` client to generate questions and answers based on the
27    /// code chunk and adds this information to the node's metadata.
28    ///
29    /// # Arguments
30    ///
31    /// * `node` - The `TextNode` to be transformed.
32    ///
33    /// # Returns
34    ///
35    /// A result containing the transformed `TextNode` or an error if the transformation fails.
36    ///
37    /// # Errors
38    ///
39    /// This function will return an error if the `SimplePrompt` client fails to generate a
40    /// response.
41    #[tracing::instrument(skip_all, name = "transformers.metadata_qa_code")]
42    async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
43        let mut prompt = self
44            .prompt_template
45            .clone()
46            .with_node(&node)
47            .with_context_value("questions", self.num_questions);
48
49        if let Some(outline) = node.metadata.get("Outline") {
50            prompt = prompt.with_context_value("outline", outline.as_str());
51        }
52
53        let response = self.prompt(prompt).await?;
54
55        node.metadata.insert(NAME, response);
56
57        Ok(node)
58    }
59
60    fn concurrency(&self) -> Option<usize> {
61        self.concurrency
62    }
63}
64
#[cfg(test)]
mod test {
    use swiftide_core::{MockSimplePrompt, assert_default_prompt_snapshot};

    use super::*;

    assert_default_prompt_snapshot!("test", "questions" => 5);

    // Building and rendering a prompt is synchronous, so no async runtime is needed here.
    #[test]
    fn test_template_with_outline() {
        let prompt = default_prompt()
            .with_node(&TextNode::new("test"))
            .with_context_value("questions", 5)
            .with_context_value("outline", "Test outline");
        insta::assert_snapshot!(prompt.render().unwrap());
    }

    #[tokio::test]
    async fn test_metadata_qacode() {
        let mut client = MockSimplePrompt::new();

        // Each node should trigger exactly one LLM call.
        client
            .expect_prompt()
            .times(1)
            .returning(|_| Ok("Q1: Hello\nA1: World".to_string()));

        let transformer = MetadataQACode::builder().client(client).build().unwrap();
        let node = TextNode::new("Some text");

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(
            result.metadata.get("Questions and Answers (code)").unwrap(),
            "Q1: Hello\nA1: World"
        );
    }

    // Exercises the branch that forwards an existing `Outline` metadata entry to the prompt.
    #[tokio::test]
    async fn test_metadata_qacode_with_outline() {
        let mut client = MockSimplePrompt::new();

        client
            .expect_prompt()
            .times(1)
            .returning(|_| Ok("Q1: Hello\nA1: World".to_string()));

        let transformer = MetadataQACode::builder().client(client).build().unwrap();
        let mut node = TextNode::new("Some text");
        node.metadata.insert("Outline", "fn main() {}");

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(
            result.metadata.get("Questions and Answers (code)").unwrap(),
            "Q1: Hello\nA1: World"
        );
        // The outline is only read, never overwritten.
        assert_eq!(result.metadata.get("Outline").unwrap(), "fn main() {}");
    }
}