Skip to main content

sem_core/parser/plugins/
erb.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use crate::model::entity::{build_entity_id, SemanticEntity};
5use crate::parser::plugin::SemanticParserPlugin;
6use crate::utils::hash::content_hash;
7
8thread_local! {
9    static ERB_PARSER: RefCell<tree_sitter::Parser> = RefCell::new({
10        let mut p = tree_sitter::Parser::new();
11        let lang: tree_sitter::Language = tree_sitter_embedded_template::LANGUAGE.into();
12        let _ = p.set_language(&lang);
13        p
14    });
15}
16
/// Semantic parser plugin for ERB (`.erb`) templates.
///
/// Stateless: all parsing state lives in a per-thread tree-sitter parser, so the
/// plugin value itself is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct ErbParserPlugin;
18
19impl SemanticParserPlugin for ErbParserPlugin {
20    fn id(&self) -> &str {
21        "erb"
22    }
23
24    fn extensions(&self) -> &[&str] {
25        &[".erb"]
26    }
27
28    fn extract_entities(&self, content: &str, file_path: &str) -> Vec<SemanticEntity> {
29        let lines: Vec<&str> = content.lines().collect();
30        if lines.is_empty() {
31            return Vec::new();
32        }
33
34        let mut entities = Vec::new();
35
36        // Top-level template entity
37        let template_name = extract_template_name(file_path);
38        let template_id = build_entity_id(file_path, "template", &template_name, None);
39        entities.push(SemanticEntity {
40            id: template_id.clone(),
41            file_path: file_path.to_string(),
42            entity_type: "template".to_string(),
43            name: template_name,
44            parent_id: None,
45            content: content.to_string(),
46            content_hash: content_hash(content),
47            structural_hash: None,
48            start_line: 1,
49            end_line: lines.len(),
50            metadata: None,
51        });
52
53        // Parse with tree-sitter and extract tags
54        let tags = ERB_PARSER.with(|parser| {
55            let mut parser = parser.borrow_mut();
56            match parser.parse(content.as_bytes(), None) {
57                Some(tree) => extract_tags_from_tree(&tree, content),
58                None => Vec::new(),
59            }
60        });
61
62        let mut block_stack: Vec<ErbTag> = Vec::new();
63        let mut name_counts: HashMap<String, usize> = HashMap::new();
64
65        for tag in tags {
66            match tag.kind {
67                TagKind::BlockOpen => {
68                    block_stack.push(tag);
69                }
70                TagKind::BlockClose => {
71                    if let Some(opener) = block_stack.pop() {
72                        let block_content =
73                            lines[opener.start_line - 1..tag.end_line].join("\n");
74                        let name = unique_name(&opener.name, &mut name_counts);
75                        entities.push(SemanticEntity {
76                            id: build_entity_id(
77                                file_path,
78                                "erb_block",
79                                &name,
80                                Some(&template_id),
81                            ),
82                            file_path: file_path.to_string(),
83                            entity_type: "erb_block".to_string(),
84                            name,
85                            parent_id: Some(template_id.clone()),
86                            content: block_content.clone(),
87                            content_hash: content_hash(&block_content),
88                            structural_hash: None,
89                            start_line: opener.start_line,
90                            end_line: tag.end_line,
91                            metadata: None,
92                        });
93                    }
94                }
95                TagKind::Expression => {
96                    let expr_content =
97                        lines[tag.start_line - 1..tag.end_line].join("\n");
98                    let name = unique_name(&tag.name, &mut name_counts);
99                    entities.push(SemanticEntity {
100                        id: build_entity_id(
101                            file_path,
102                            "erb_expression",
103                            &name,
104                            Some(&template_id),
105                        ),
106                        file_path: file_path.to_string(),
107                        entity_type: "erb_expression".to_string(),
108                        name,
109                        parent_id: Some(template_id.clone()),
110                        content: expr_content.clone(),
111                        content_hash: content_hash(&expr_content),
112                        structural_hash: None,
113                        start_line: tag.start_line,
114                        end_line: tag.end_line,
115                        metadata: None,
116                    });
117                }
118                // No separate Code variant needed; expressions cover all non-block tags
119            }
120        }
121
122        entities
123    }
124}
125
126// --- Internal types ---
127
/// How a single ERB code tag participates in block structure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TagKind {
    /// Opens a block that a later `end` tag closes (`if ...`, `... do |x|`).
    BlockOpen,
    /// An `end` tag, closing the innermost open block.
    BlockClose,
    /// Any other code tag: output or standalone code.
    Expression,
}

/// One classified code tag extracted from the parse tree.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ErbTag {
    /// Structural role of the tag.
    kind: TagKind,
    /// Entity name: the (truncated) tag code, or `"end"` for closing tags.
    name: String,
    /// First line of the tag, 1-indexed.
    start_line: usize,
    /// Last line of the tag, 1-indexed and inclusive.
    end_line: usize,
}
142
143// --- Helpers ---
144
/// Derive a template's entity name from its path.
///
/// Takes the text after the last `/` (the whole path if there is none) and removes
/// one trailing `.erb`, so `views/show.html.erb` becomes `show.html`.
fn extract_template_name(file_path: &str) -> String {
    let file_name = match file_path.rfind('/') {
        Some(slash) => &file_path[slash + 1..],
        None => file_path,
    };
    let stem = file_name.strip_suffix(".erb").unwrap_or(file_name);
    String::from(stem)
}
149
150/// Walk the tree-sitter AST and classify each directive node.
151fn extract_tags_from_tree(tree: &tree_sitter::Tree, source: &str) -> Vec<ErbTag> {
152    let mut tags = Vec::new();
153    let root = tree.root_node();
154    let mut cursor = root.walk();
155
156    for node in root.children(&mut cursor) {
157        let start_line = node.start_position().row + 1; // 1-indexed
158        let end_line = node.end_position().row + 1;
159
160        match node.kind() {
161            "directive" | "output_directive" => {
162                if let Some(code_text) = code_child_text(&node, source) {
163                    let trimmed = code_text.trim();
164                    if trimmed.is_empty() {
165                        continue;
166                    }
167
168                    if let Some(tag) = classify_code(trimmed, start_line, end_line) {
169                        tags.push(tag);
170                    }
171                }
172            }
173            // comment_directive, content -> skip
174            _ => {}
175        }
176    }
177
178    tags
179}
180
181/// Classify a code snippet from inside an ERB tag.
182/// Returns None for mid-block keywords (else, elsif, etc.) that should be skipped.
183fn classify_code(trimmed: &str, start_line: usize, end_line: usize) -> Option<ErbTag> {
184    let first_word = trimmed.split_whitespace().next().unwrap_or("");
185
186    if first_word == "end" {
187        Some(ErbTag {
188            kind: TagKind::BlockClose,
189            name: "end".to_string(),
190            start_line,
191            end_line,
192        })
193    } else if is_block_opener(trimmed) {
194        Some(ErbTag {
195            kind: TagKind::BlockOpen,
196            name: truncate_name(trimmed),
197            start_line,
198            end_line,
199        })
200    } else if is_mid_block_keyword(first_word) {
201        None
202    } else {
203        // Expression or standalone code
204        Some(ErbTag {
205            kind: TagKind::Expression,
206            name: truncate_name(trimmed),
207            start_line,
208            end_line,
209        })
210    }
211}
212
213fn code_child_text<'a>(node: &tree_sitter::Node, source: &'a str) -> Option<&'a str> {
214    let mut cursor = node.walk();
215    for child in node.children(&mut cursor) {
216        if child.kind() == "code" {
217            return child.utf8_text(source.as_bytes()).ok();
218        }
219    }
220    None
221}
222
/// Heuristically decide whether a tag's Ruby code opens a block that a later
/// `<% end %>` tag closes.
///
/// Two shapes are recognised:
/// * code starting with a block keyword (`if`, `unless`, `for`, `while`, `until`,
///   `case`, `begin`, also when followed directly by `(`) that does not finish
///   with `end` in the same tag. A one-liner like `if x then y end` is
///   self-contained;
/// * code ending in the `do` keyword, optionally followed by block parameters:
///   `3.times do`, `@items.each do |item|`, `form_with(model: @x) do |f|`.
///
/// Only a trailing, whole-word `do` counts. A `do` inside a string (`"to do"`),
/// in a symbol (`:do`), or in a one-line `each do |i| ... end` is not an opener.
fn is_block_opener(content: &str) -> bool {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let code = content.trim();

    let leading = code.split(|c: char| !is_ident_char(c)).next().unwrap_or("");
    if matches!(leading, "if" | "unless" | "for" | "while" | "until" | "case" | "begin") {
        let trailing = code
            .rsplit(|c: char| c.is_whitespace() || c == ';')
            .find(|token| !token.is_empty())
            .unwrap_or("");
        return trailing != "end";
    }

    // Drop trailing block parameters such as `|item|` or `|(key, value), index|`.
    let mut head = code;
    if let Some(without_close) = head.strip_suffix('|') {
        match without_close.rfind('|') {
            Some(open) => head = without_close[..open].trim_end(),
            None => return false,
        }
    }

    match head.strip_suffix("do") {
        // `do` must be a standalone keyword: `redo`, `todo`, `:do` do not count.
        Some(before) => {
            before.is_empty()
                || before.ends_with(|c: char| c.is_whitespace() || c == ')' || c == ']')
        }
        None => false,
    }
}
234
/// Return `true` for keywords that continue an already-open block without
/// opening or closing one.
///
/// Covers `else`, `elsif`, `when`, `in` (the branch keyword of pattern-matching
/// `case ... in`), `rescue`, and `ensure`. `word` is the leading keyword of the
/// tag's code, compared exactly.
fn is_mid_block_keyword(word: &str) -> bool {
    matches!(
        word,
        "else" | "elsif" | "when" | "in" | "rescue" | "ensure"
    )
}
238
/// Trim `s` and cap it at 60 characters for use as an entity name.
///
/// Longer names keep their first 57 characters followed by `...`. Lengths are
/// counted in `char`s, not bytes, so multi-byte text (e.g. non-ASCII strings
/// inside a tag) is never cut mid-character. Slicing a `str` at a non-boundary
/// byte offset would panic.
fn truncate_name(s: &str) -> String {
    const MAX_CHARS: usize = 60;
    const KEEP_CHARS: usize = MAX_CHARS - 3; // leave room for the "..." marker

    let s = s.trim();
    if s.char_indices().nth(MAX_CHARS).is_none() {
        // At most MAX_CHARS characters: fits as-is.
        return s.to_string();
    }
    // Byte offset just past the first KEEP_CHARS characters (always a char boundary).
    let cut = s
        .char_indices()
        .nth(KEEP_CHARS)
        .map(|(idx, _)| idx)
        .unwrap_or(s.len());
    format!("{}...", &s[..cut])
}
247
/// Return `base` the first time it is seen and `base#N` on its N-th occurrence,
/// recording every occurrence in `counts`.
///
/// Keeps entity names (which feed into entity ids) distinct when identical code
/// appears in several tags of one template.
fn unique_name(base: &str, counts: &mut HashMap<String, usize>) -> String {
    let seen = counts.entry(base.to_string()).or_default();
    *seen += 1;
    match *seen {
        1 => base.to_string(),
        n => format!("{}#{}", base, n),
    }
}
257
/// Unit tests for `ErbParserPlugin::extract_entities` and `extract_template_name`.
///
/// The line numbers asserted below are 1-indexed positions inside each raw-string
/// fixture, so editing a fixture shifts the expected `start_line`/`end_line`.
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end extraction covering:
    /// * the template entity;
    /// * an `if`/`else`/`end` block and a `do |item|` block;
    /// * output expressions and a standalone code tag;
    /// * skipping of comment and `else` tags.
    #[test]
    fn test_erb_extraction() {
        let erb = r#"<div class="container">
  <% if @user.admin? %>
    <h1>Admin Panel</h1>
    <%= @user.name %>
  <% else %>
    <p>Access denied</p>
  <% end %>

  <% @items.each do |item| %>
    <li><%= item.title %></li>
  <% end %>

  <%# This is a comment, should be skipped %>
  <% @count = @items.length %>
</div>
"#;
        let plugin = ErbParserPlugin;
        let entities = plugin.extract_entities(erb, "views/dashboard.html.erb");

        let names: Vec<&str> = entities.iter().map(|e| e.name.as_str()).collect();
        let types: Vec<&str> = entities.iter().map(|e| e.entity_type.as_str()).collect();
        // Debug dump of (name, type) pairs; the test harness captures it and shows it
        // on failure or with `cargo test -- --nocapture`.
        eprintln!(
            "ERB entities: {:?}",
            names.iter().zip(types.iter()).collect::<Vec<_>>()
        );

        // Template entity
        assert_eq!(entities[0].entity_type, "template");
        assert_eq!(entities[0].name, "dashboard.html");
        assert_eq!(entities[0].start_line, 1);

        // if block (lines 2-7)
        let if_block = entities.iter().find(|e| e.name == "if @user.admin?").unwrap();
        assert_eq!(if_block.entity_type, "erb_block");
        assert_eq!(if_block.start_line, 2);
        assert_eq!(if_block.end_line, 7);
        assert!(if_block.parent_id.is_some());

        // each block (lines 9-11)
        let each_block = entities
            .iter()
            .find(|e| e.name == "@items.each do |item|")
            .unwrap();
        assert_eq!(each_block.entity_type, "erb_block");
        assert_eq!(each_block.start_line, 9);
        assert_eq!(each_block.end_line, 11);

        // Expressions
        assert!(names.contains(&"@user.name"));
        assert!(names.contains(&"item.title"));
        let user_name = entities.iter().find(|e| e.name == "@user.name").unwrap();
        assert_eq!(user_name.entity_type, "erb_expression");
        assert_eq!(user_name.start_line, 4);

        // Standalone code shows as expression
        let code = entities
            .iter()
            .find(|e| e.name == "@count = @items.length")
            .unwrap();
        assert_eq!(code.entity_type, "erb_expression");
        assert_eq!(code.start_line, 14);

        // Comment should be skipped
        assert!(!names.iter().any(|n| n.contains("comment")));

        // else should be skipped (mid-block keyword)
        assert!(!names.iter().any(|n| *n == "else"));
    }

    /// Nested blocks pair each `end` with the innermost open block.
    #[test]
    fn test_erb_nested_blocks() {
        let erb = r#"<% if @show %>
  <% @items.each do |item| %>
    <%= item %>
  <% end %>
<% end %>
"#;
        let plugin = ErbParserPlugin;
        let entities = plugin.extract_entities(erb, "nested.html.erb");

        let blocks: Vec<&SemanticEntity> = entities
            .iter()
            .filter(|e| e.entity_type == "erb_block")
            .collect();
        assert_eq!(blocks.len(), 2, "Should have 2 blocks: {:?}",
            blocks.iter().map(|b| &b.name).collect::<Vec<_>>());

        // Inner block (each) closes first
        let each = blocks.iter().find(|b| b.name.contains("each")).unwrap();
        assert_eq!(each.start_line, 2);
        assert_eq!(each.end_line, 4);

        // Outer block (if) closes second
        let if_block = blocks.iter().find(|b| b.name.contains("if")).unwrap();
        assert_eq!(if_block.start_line, 1);
        assert_eq!(if_block.end_line, 5);
    }

    /// Template names drop the directory and one trailing `.erb`.
    #[test]
    fn test_erb_template_name() {
        assert_eq!(extract_template_name("views/best.html.erb"), "best.html");
        assert_eq!(extract_template_name("loading.erb"), "loading");
        assert_eq!(extract_template_name("app/views/_partial.html.erb"), "_partial.html");
    }

    /// `<%-` tags are classified the same as `<%` tags, and `end if` still closes
    /// the block.
    #[test]
    fn test_erb_dash_variant() {
        // <%- is the whitespace-stripping variant, should produce blocks like <%
        let erb = r#"<header>
  <%- if @show %>
    <%= @title %>
  <%- else %>
    <p>nope</p>
  <%- end if %>
</header>
"#;
        let plugin = ErbParserPlugin;
        let entities = plugin.extract_entities(erb, "test.html.erb");

        let names: Vec<&str> = entities.iter().map(|e| e.name.as_str()).collect();
        let types: Vec<&str> = entities.iter().map(|e| e.entity_type.as_str()).collect();
        // Debug dump of (name, type) pairs, captured by the test harness.
        eprintln!("Dash variant: {:?}",
            names.iter().zip(types.iter()).collect::<Vec<_>>());

        // <%- if %> ... <%- end if %> should create a block
        let if_block = entities.iter().find(|e| e.name == "if @show").unwrap();
        assert_eq!(if_block.entity_type, "erb_block");
        assert_eq!(if_block.start_line, 2);
        assert_eq!(if_block.end_line, 6);

        // else should be skipped
        assert!(!names.iter().any(|n| *n == "else"));
    }

    /// Identical expressions are disambiguated with a `#N` suffix in document order.
    #[test]
    fn test_erb_duplicate_expressions() {
        let erb = r#"<%= @title %>
<%= @title %>
"#;
        let plugin = ErbParserPlugin;
        let entities = plugin.extract_entities(erb, "test.erb");

        let exprs: Vec<&SemanticEntity> = entities
            .iter()
            .filter(|e| e.entity_type == "erb_expression")
            .collect();
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0].name, "@title");
        assert_eq!(exprs[1].name, "@title#2");
    }
}