Skip to main content

sqry_lang_haskell/
lib.rs

1//! Haskell language plugin.
2//!
3//! Provides graph-native extraction via `HaskellGraphBuilder`, AST parsing,
4//! scope extraction, and literate (`.lhs`) preprocessing.
5
6mod preprocess;
7pub mod relations;
8
9pub use relations::HaskellGraphBuilder;
10
11use preprocess::preprocess_content;
12use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
13use sqry_core::plugin::LanguageMetadata;
14use sqry_core::plugin::LanguagePlugin;
15use sqry_core::plugin::error::ScopeError;
16use std::borrow::Cow;
17use std::path::Path;
18use tree_sitter::{Language, Node, Tree};
19
/// Identifier reported as [`LanguageMetadata::id`] for this plugin.
const LANGUAGE_ID: &str = "haskell";

/// Display name reported as [`LanguageMetadata::name`].
const LANGUAGE_NAME: &str = "Haskell";

/// Grammar version reported as [`LanguageMetadata::tree_sitter_version`].
///
/// NOTE(review): this is a hand-written literal, not derived from the build;
/// keep it in sync with the `tree-sitter-haskell` dependency manually.
const TREE_SITTER_VERSION: &str = "0.23";
23
24/// Haskell language plugin implementation.
25pub struct HaskellPlugin {
26    graph_builder: HaskellGraphBuilder,
27}
28
29impl HaskellPlugin {
30    /// Creates a new Haskell plugin instance.
31    #[must_use]
32    pub fn new() -> Self {
33        Self {
34            graph_builder: HaskellGraphBuilder::default(),
35        }
36    }
37}
38
39impl Default for HaskellPlugin {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl LanguagePlugin for HaskellPlugin {
46    fn metadata(&self) -> LanguageMetadata {
47        LanguageMetadata {
48            id: LANGUAGE_ID,
49            name: LANGUAGE_NAME,
50            version: env!("CARGO_PKG_VERSION"),
51            author: "Verivus Pty Ltd",
52            description: "Haskell language support for sqry",
53            tree_sitter_version: TREE_SITTER_VERSION,
54        }
55    }
56
57    fn extensions(&self) -> &'static [&'static str] {
58        &["hs", "lhs", "hs-boot"]
59    }
60
61    fn language(&self) -> Language {
62        tree_sitter_haskell::LANGUAGE.into()
63    }
64
65    fn preprocess<'a>(&self, content: &'a [u8]) -> Cow<'a, [u8]> {
66        preprocess_content(content)
67    }
68
69    fn extract_scopes(
70        &self,
71        tree: &Tree,
72        content: &[u8],
73        file_path: &Path,
74    ) -> Result<Vec<Scope>, ScopeError> {
75        let processed = self.preprocess(content);
76        Ok(extract_haskell_scopes(tree, processed.as_ref(), file_path))
77    }
78
79    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
80        Some(&self.graph_builder)
81    }
82}
83
84/// Extract scopes from Haskell source using AST traversal.
85fn extract_haskell_scopes(tree: &Tree, content: &[u8], file_path: &Path) -> Vec<Scope> {
86    let mut scopes = Vec::new();
87    let root = tree.root_node();
88
89    let mut root_cursor = root.walk();
90    for child in root.children(&mut root_cursor) {
91        if child.kind() == "header" {
92            if let Some(module_name) = extract_module_name_from_header(child, content) {
93                let start = child.start_position();
94                let end = root.end_position();
95                scopes.push(Scope {
96                    id: ScopeId::new(0),
97                    scope_type: "module".to_string(),
98                    name: module_name,
99                    file_path: file_path.to_path_buf(),
100                    start_line: start.row + 1,
101                    start_column: start.column,
102                    end_line: end.row + 1,
103                    end_column: end.column,
104                    parent_id: None,
105                });
106            }
107            break;
108        }
109    }
110
111    if let Some(decls) = root.child_by_field_name("declarations") {
112        collect_declaration_scopes(decls, content, file_path, &mut scopes);
113    }
114
115    scopes.sort_by_key(|s| (s.start_line, s.start_column));
116    link_nested_scopes(&mut scopes);
117    scopes
118}
119
120fn collect_declaration_scopes(
121    node: Node<'_>,
122    content: &[u8],
123    file_path: &Path,
124    scopes: &mut Vec<Scope>,
125) {
126    let mut cursor = node.walk();
127    for child in node.children(&mut cursor) {
128        let (scope_type, name_field) = match child.kind() {
129            "function" | "bind" => ("function", Some("name")),
130            "data_type" | "newtype" | "type_synomym" => ("type", Some("name")),
131            "class" => ("class", Some("name")),
132            "instance" => ("instance", Some("name")),
133            "pattern_synonym" => ("function", Some("synonym")),
134            _ => continue,
135        };
136
137        let name = name_field
138            .and_then(|field| child.child_by_field_name(field))
139            .and_then(|n| n.utf8_text(content).ok())
140            .map_or_else(|| format!("<{}>", child.kind()), |s| s.trim().to_string());
141
142        let start = child.start_position();
143        let end = child.end_position();
144
145        scopes.push(Scope {
146            id: ScopeId::new(0),
147            scope_type: scope_type.to_string(),
148            name,
149            file_path: file_path.to_path_buf(),
150            start_line: start.row + 1,
151            start_column: start.column,
152            end_line: end.row + 1,
153            end_column: end.column,
154            parent_id: None,
155        });
156    }
157}
158
159fn extract_module_name_from_header(header: Node<'_>, content: &[u8]) -> Option<String> {
160    let mut cursor = header.walk();
161    for child in header.children(&mut cursor) {
162        if matches!(child.kind(), "module" | "module_id")
163            && let Ok(text) = child.utf8_text(content)
164            && text != "module"
165        {
166            return Some(text.to_string());
167        }
168    }
169    header
170        .utf8_text(content)
171        .ok()
172        .and_then(parse_module_name_from_text)
173}
174
/// Text-based fallback that recovers a module name from a module header,
/// used by `extract_module_name_from_header` when no child node yields one.
///
/// Scans whitespace-separated tokens for the `module` keyword and takes the
/// token that follows it. Only that token's leading run of module-name
/// characters is kept: alphanumerics, `_`, `'` and `.` for qualified names.
/// So an export list or terminator glued to the name is dropped:
///
/// * `module Foo.Bar(baz, qux) where` yields `Foo.Bar` (not `Foo.Bar(baz,`)
/// * `module Foo;` yields `Foo`
///
/// If the token after a `module` keyword contributes no name characters
/// (e.g. `module (`), scanning continues with the remaining tokens.
///
/// Returns `None` when no name can be found.
fn parse_module_name_from_text(text: &str) -> Option<String> {
    // Characters that may appear in a (possibly qualified) Haskell module name.
    let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | '\'' | '.');

    let mut tokens = text.split_whitespace();
    while let Some(token) = tokens.next() {
        if token != "module" {
            continue;
        }
        let Some(name_token) = tokens.next() else {
            break;
        };
        // `find` yields a byte index at a char boundary, so slicing is safe.
        let name_len = name_token
            .find(|c: char| !is_name_char(c))
            .unwrap_or(name_token.len());
        let name = &name_token[..name_len];
        if !name.is_empty() {
            return Some(name.to_string());
        }
    }
    None
}
189
#[cfg(test)]
mod tests {
    use super::*;
    use sqry_core::plugin::LanguagePlugin;
    use std::fs;
    use std::path::PathBuf;

    /// Reads `<crate root>/tests/fixtures/<name>` and returns its bytes
    /// together with the path that was read.
    ///
    /// The path is anchored at `CARGO_MANIFEST_DIR` (baked in at compile time),
    /// so fixtures are found regardless of the working directory the test
    /// binary is launched from.
    fn load_fixture(name: &str) -> (Vec<u8>, PathBuf) {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name);
        let content = fs::read(&path).unwrap_or_else(|err| {
            panic!("failed to read fixture {}: {}", path.display(), err)
        });
        (content, path)
    }

    /// Parses the named fixture with `plugin` and returns the extracted scopes.
    fn extract_scopes_from_fixture(plugin: &HaskellPlugin, name: &str) -> Vec<Scope> {
        let (content, path) = load_fixture(name);
        let tree = plugin.parse_ast(&content).expect("parse fixture");
        plugin
            .extract_scopes(&tree, &content, &path)
            .expect("extract scopes")
    }

    /// True when some scope has exactly the given type and name.
    fn has_scope(scopes: &[Scope], scope_type: &str, name: &str) -> bool {
        scopes
            .iter()
            .any(|scope| scope.scope_type == scope_type && scope.name == name)
    }

    #[test]
    fn extracts_scopes_from_basic_fixture() {
        let plugin = HaskellPlugin::default();
        let scopes = extract_scopes_from_fixture(&plugin, "basic.hs");

        assert!(has_scope(&scopes, "module", "Sample"));
        assert!(has_scope(&scopes, "function", "foo"));
        assert!(has_scope(&scopes, "function", "bar"));
        assert!(has_scope(&scopes, "class", "Run"));
    }

    #[test]
    fn parses_literate_haskell() {
        let plugin = HaskellPlugin::default();
        let scopes = extract_scopes_from_fixture(&plugin, "literate.lhs");

        assert!(has_scope(&scopes, "module", "Literate"));
        assert!(has_scope(&scopes, "function", "answer"));
    }
}