swiftide_integrations/treesitter/
compress_code_outline.rs1use std::sync::OnceLock;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use swiftide_core::{Transformer, indexing::TextNode};
8
/// `CompressCodeOutline` rewrites the `"Outline"` metadata field of a node. It replaces the
/// field with an LLM's response to a prompt rendered from that node. As the name suggests, the
/// intent is to compress the outline. The default prompt is
/// `prompts/compress_code_outline.prompt.md`.
///
/// NOTE(review): the `indexing_transformer` attribute presumably generates the plumbing used in
/// this file (`NAME`, `prompt_template`, `prompt`, `concurrency`, `builder()`,
/// `default_prompt()`). Confirm against `swiftide_macros::indexing_transformer`.
#[swiftide_macros::indexing_transformer(
    metadata_field_name = "Outline",
    default_prompt_file = "prompts/compress_code_outline.prompt.md"
)]
pub struct CompressCodeOutline {}
17
18fn extract_markdown_codeblock(text: String) -> String {
19 static REGEX: OnceLock<regex::Regex> = OnceLock::new();
20
21 let re = REGEX.get_or_init(|| regex::Regex::new(r"(?sm)```\w*\n(.*?)```").unwrap());
22 let captures = re.captures(text.as_str());
23 captures
24 .map(|c| c.get(1).unwrap().as_str().to_string())
25 .unwrap_or(text)
26}
27
28#[async_trait]
29impl Transformer for CompressCodeOutline {
30 type Input = String;
31 type Output = String;
32 #[tracing::instrument(skip_all, name = "transformers.compress_code_outline")]
51 async fn transform_node(&self, mut node: TextNode) -> Result<TextNode> {
52 if node.metadata.get(NAME).is_none() {
53 return Ok(node);
54 }
55
56 let prompt = self.prompt_template.clone().with_node(&node);
57
58 let response = extract_markdown_codeblock(self.prompt(prompt).await?);
59
60 node.metadata.insert(NAME, response);
61
62 Ok(node)
63 }
64
65 fn concurrency(&self) -> Option<usize> {
66 self.concurrency
67 }
68}
69
#[cfg(test)]
mod test {
    use swiftide_core::MockSimplePrompt;

    use super::*;

    /// Snapshots the default prompt rendered for a node that carries an outline.
    #[test_log::test(tokio::test)]
    async fn test_compress_code_template() {
        let template = default_prompt();

        let outline = "Relevant Outline";
        let code = "Code using outline";
        let mut node = TextNode::new(code);
        node.metadata.insert("Outline", outline);

        let prompt = template.clone().with_node(&node);

        insta::assert_snapshot!(prompt.render().unwrap());
    }

    /// The outline is replaced by the LLM response, and the LLM is prompted exactly once.
    #[tokio::test]
    async fn test_compress_code_outline() {
        let mut client = MockSimplePrompt::new();

        client
            .expect_prompt()
            .times(1)
            .returning(|_| Ok("RelevantOutline".to_string()));

        let transformer = CompressCodeOutline::builder()
            .client(client)
            .build()
            .unwrap();
        let mut node = TextNode::new("Some text");
        node.offset = 0;
        node.original_size = 100;

        node.metadata
            .insert("Outline".to_string(), "Some outline".to_string());

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(result.chunk, "Some text");
        assert_eq!(result.metadata.get("Outline").unwrap(), "RelevantOutline");
    }

    /// Nodes without an outline pass through untouched and never reach the LLM.
    #[tokio::test]
    async fn test_compress_code_outline_skips_nodes_without_outline() {
        let mut client = MockSimplePrompt::new();

        client.expect_prompt().never();

        let transformer = CompressCodeOutline::builder()
            .client(client)
            .build()
            .unwrap();
        let node = TextNode::new("Some text");

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(result.chunk, "Some text");
        assert!(result.metadata.get("Outline").is_none());
    }

    /// A fenced code block in the LLM response is unwrapped before it is stored.
    #[tokio::test]
    async fn test_compress_code_outline_strips_code_fences() {
        let mut client = MockSimplePrompt::new();

        client
            .expect_prompt()
            .times(1)
            .returning(|_| Ok("```rust\nRelevantOutline\n```".to_string()));

        let transformer = CompressCodeOutline::builder()
            .client(client)
            .build()
            .unwrap();
        let mut node = TextNode::new("Some text");
        node.metadata
            .insert("Outline".to_string(), "Some outline".to_string());

        let result = transformer.transform_node(node).await.unwrap();

        assert_eq!(result.metadata.get("Outline").unwrap(), "RelevantOutline\n");
    }

    /// Fence stripping extracts the block body and leaves unfenced text as-is.
    #[test]
    fn test_extract_markdown_codeblock() {
        assert_eq!(
            extract_markdown_codeblock("```rust\nfn outline() {}\n```".to_string()),
            "fn outline() {}\n"
        );
        assert_eq!(
            extract_markdown_codeblock("Here:\n```\nplain\n```\ndone".to_string()),
            "plain\n"
        );
        assert_eq!(
            extract_markdown_codeblock("no fences here".to_string()),
            "no fences here"
        );
    }
}