swiftide_integrations/treesitter/
metadata_qa_code.rs1use anyhow::Result;
4use async_trait::async_trait;
5use swiftide_core::{Transformer, indexing::TextNode};
6
/// Indexing transformer that asks a prompt client to generate questions and answers about a
/// code chunk, and stores the reply in the chunk's metadata.
///
/// By default the prompt is loaded from `prompts/metadata_qa_code.prompt.md`. The response is
/// written to the node's metadata under the field name `"Questions and Answers (code)"`.
///
/// NOTE(review): the `indexing_transformer` macro presumably adds the prompt-template, client
/// and concurrency fields, the builder, and the `NAME` / `default_prompt` items. Confirm this
/// against `swiftide_macros`.
#[swiftide_macros::indexing_transformer(
    metadata_field_name = "Questions and Answers (code)",
    default_prompt_file = "prompts/metadata_qa_code.prompt.md"
)]
pub struct MetadataQACode {
    /// Number of questions to request. It is passed to the prompt template as `questions` and
    /// defaults to 5 when not set on the builder.
    #[builder(default = "5")]
    num_questions: usize,
}
18
19#[async_trait]
20impl Transformer for MetadataQACode {
21 type Input = String;
22 type Output = String;
23 #[tracing::instrument(skip_all, name = "transformers.metadata_qa_code")]
42 async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
43 let mut prompt = self
44 .prompt_template
45 .clone()
46 .with_node(&node)
47 .with_context_value("questions", self.num_questions);
48
49 if let Some(outline) = node.metadata.get("Outline") {
50 prompt = prompt.with_context_value("outline", outline.as_str());
51 }
52
53 let response = self.prompt(prompt).await?;
54
55 node.metadata.insert(NAME, response);
56
57 Ok(node)
58 }
59
60 fn concurrency(&self) -> Option<usize> {
61 self.concurrency
62 }
63}
64
#[cfg(test)]
mod test {
    use swiftide_core::{MockSimplePrompt, assert_default_prompt_snapshot};

    use super::*;

    // Snapshots the default prompt for a node with content "test" and 5 questions.
    // NOTE(review): exact expansion lives in `swiftide_core::assert_default_prompt_snapshot`.
    assert_default_prompt_snapshot!("test", "questions" => 5);

    /// Snapshots the default prompt rendered with an outline present in the context.
    ///
    /// The function name is kept stable because insta derives the snapshot name from it.
    #[tokio::test]
    async fn test_template_with_outline() {
        // `default_prompt()` already yields an owned prompt, so no clone is needed.
        let prompt = default_prompt()
            .with_node(&TextNode::new("test"))
            .with_context_value("questions", 5)
            .with_context_value("outline", "Test outline");
        insta::assert_snapshot!(prompt.render().unwrap());
    }

    /// The client's reply is stored verbatim under the transformer's metadata field.
    #[tokio::test]
    async fn test_metadata_qacode() {
        let mut client = MockSimplePrompt::new();

        // Pin the call count: each node should trigger exactly one prompt. Without `times`,
        // a regression that re-prompts the client would still pass.
        client
            .expect_prompt()
            .times(1)
            .returning(|_| Ok("Q1: Hello\nA1: World".to_string()));

        let transformer = MetadataQACode::builder().client(client).build().unwrap();
        let node = TextNode::new("Some text");

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(
            result.metadata.get("Questions and Answers (code)").unwrap(),
            "Q1: Hello\nA1: World"
        );
    }
}