shader_sense/symbols/
symbol_parser.rs

//! Traits and helpers for parsing the tree-sitter AST.
2use std::path::{Path, PathBuf};
3
4use tree_sitter::{Node, QueryMatch};
5
6use crate::{
7    position::{ShaderPosition, ShaderRange},
8    shader::ShaderCompilationParams,
9    shader_error::ShaderError,
10    symbols::{
11        symbol_list::{ShaderSymbolList, ShaderSymbolListRef},
12        symbols::{ShaderSymbolData, ShaderSymbolMode},
13    },
14};
15
16use super::{
17    prepocessor::{ShaderPreprocessor, ShaderPreprocessorContext, ShaderRegion},
18    shader_module::{ShaderModule, ShaderSymbols},
19    symbol_provider::{SymbolIncludeCallback, SymbolProvider},
20    symbols::{ShaderScope, ShaderSymbol},
21};
22
23pub(super) fn get_name<'a>(shader_content: &'a str, node: Node) -> &'a str {
24    let range = node.range();
25    &shader_content[range.start_byte..range.end_byte]
26}
27
28impl From<tree_sitter::Range> for ShaderRange {
29    fn from(value: tree_sitter::Range) -> Self {
30        ShaderRange::new(
31            ShaderPosition::new(
32                value.start_point.row as u32,
33                value.start_point.column as u32,
34            ),
35            ShaderPosition::new(value.end_point.row as u32, value.end_point.column as u32),
36        )
37    }
38}
39
40impl From<tree_sitter::Point> for ShaderPosition {
41    fn from(point: tree_sitter::Point) -> Self {
42        ShaderPosition::new(point.row as u32, point.column as u32)
43    }
44}
45
46pub struct ShaderSymbolListBuilder<'a> {
47    shader_symbol_list: ShaderSymbolList,
48    filter_callback: Box<&'a dyn Fn(&ShaderSymbol) -> bool>,
49}
50impl<'a> ShaderSymbolListBuilder<'a> {
51    pub fn new(filter_callback: &'a dyn Fn(&ShaderSymbol) -> bool) -> Self {
52        Self {
53            shader_symbol_list: ShaderSymbolList::default(),
54            filter_callback: Box::new(filter_callback),
55        }
56    }
57    pub fn add_call_expression(&mut self, shader_symbol: ShaderSymbol) {
58        if (self.filter_callback)(&shader_symbol) {
59            self.shader_symbol_list.call_expression.push(shader_symbol);
60        }
61    }
62    pub fn add_variable(&mut self, shader_symbol: ShaderSymbol) {
63        if (self.filter_callback)(&shader_symbol) {
64            self.shader_symbol_list.variables.push(shader_symbol);
65        }
66    }
67    pub fn add_type(&mut self, shader_symbol: ShaderSymbol) {
68        if (self.filter_callback)(&shader_symbol) {
69            self.shader_symbol_list.types.push(shader_symbol);
70        }
71    }
72    pub fn add_function(&mut self, shader_symbol: ShaderSymbol) {
73        if (self.filter_callback)(&shader_symbol) {
74            self.shader_symbol_list.functions.push(shader_symbol);
75        }
76    }
77    pub fn get_shader_symbol_list(self) -> ShaderSymbolList {
78        self.shader_symbol_list
79    }
80}
81
82#[derive(Clone, Debug)]
83pub struct ShaderWordRange {
84    parent: Option<Box<ShaderWordRange>>, // Box to avoid recursive struct
85    word: String,
86    range: ShaderRange,
87}
88
89impl ShaderWordRange {
90    pub fn new(word: String, range: ShaderRange, parent: Option<ShaderWordRange>) -> Self {
91        Self {
92            parent: match parent {
93                Some(parent) => Some(Box::new(parent)),
94                None => None,
95            },
96            word,
97            range,
98        }
99    }
100    pub fn get_word(&self) -> &str {
101        &self.word
102    }
103    pub fn get_range(&self) -> &ShaderRange {
104        &self.range
105    }
106    pub fn get_parent(&self) -> Option<&ShaderWordRange> {
107        self.parent.as_ref().map(|p| p.as_ref())
108    }
109    fn get_parent_mut(&mut self) -> Option<&mut ShaderWordRange> {
110        self.parent.as_mut().map(|p| p.as_mut())
111    }
112    pub fn set_root_parent(&mut self, root_parent: ShaderWordRange) {
113        // Use a raw pointer to traverse without holding a mutable borrow
114        let mut parent: *mut ShaderWordRange = self;
115        unsafe {
116            while let Some(p) = (*parent).get_parent_mut() {
117                parent = p;
118            }
119            // Now parent is the deepest node, safe to assign
120            (*parent).parent = Some(Box::new(root_parent));
121        }
122    }
123    pub fn get_word_stack(&self) -> Vec<&ShaderWordRange> {
124        let mut current_word = self;
125        let mut stack = Vec::new();
126        stack.push(self);
127        while let Some(parent) = &current_word.parent {
128            stack.push(parent.as_ref());
129            current_word = parent.as_ref();
130        }
131        stack
132    }
133    pub fn set_parent(&mut self, parent: ShaderWordRange) {
134        self.parent = Some(Box::new(parent));
135    }
136    pub fn is_field(&self) -> bool {
137        self.parent.is_some()
138    }
139    // Look for matching symbol in symbol_list
140    pub fn find_symbol_from_parent(
141        &self,
142        file_path: PathBuf,
143        symbol_list: &ShaderSymbolListRef,
144    ) -> Vec<ShaderSymbol> {
145        if self.parent.is_none() {
146            // Could be either a variable, a link, or a type.
147            symbol_list
148                .find_symbols_at(&self.word, &self.range.end.clone_into_file(file_path))
149                .iter()
150                .map(|s| (*s).clone())
151                .collect()
152        } else {
153            // Will be a variable or function (root only), method, or member if chained.
154            let stack = self.get_word_stack();
155            let mut rev_stack = stack.iter().rev();
156            // TODO: SHould not require file path & filter here...
157            let symbol_list = symbol_list
158                .filter_scoped_symbol(&self.range.end.clone_into_file(file_path.clone()));
159            // Look for root symbol (either a function or variable)
160            let root_symbol = match rev_stack.next() {
161                Some(current_word) => match symbol_list.find_symbol(&current_word.word) {
162                    Some(symbol) => {
163                        match &symbol.data {
164                            ShaderSymbolData::CallExpression {
165                                label,
166                                range: _,
167                                parameters: _,
168                            } => {
169                                match symbol_list.find_function_symbol(label) {
170                                    Some(function) => {
171                                        if let ShaderSymbolData::Functions { signatures: _ } =
172                                            &function.data
173                                        {
174                                            symbol
175                                        } else {
176                                            return vec![]; // Not a valid function
177                                        }
178                                    }
179                                    None => return vec![], // No matching function found
180                                }
181                            }
182                            ShaderSymbolData::Functions { signatures: _ } => symbol,
183                            ShaderSymbolData::Variables { ty: _, count: _ } => symbol,
184                            ShaderSymbolData::Enum { values: _ } => symbol,
185                            _ => return vec![], // Symbol found is not a variable nor a function.
186                        }
187                    }
188                    None => {
189                        return vec![]; // No variable found for main parent.
190                    }
191                },
192                None => unreachable!("Should always have at least one symbol on this path."),
193            };
194            // Now loop over child for matching member elements
195            let mut current_symbols = vec![root_symbol.clone()];
196            while let Some(next_item) = &rev_stack.next() {
197                // TODO: for now, we naively pick the first signature.
198                // But we should pick instead the one closest by analyzing parameters.
199                let ty = match &current_symbols[0].data {
200                    // CallExpression & variable will only be called on first iteration
201                    ShaderSymbolData::CallExpression {
202                        label,
203                        range: _,
204                        parameters: _,
205                    } => {
206                        match symbol_list.find_function_symbol(label) {
207                            Some(function) => {
208                                if let ShaderSymbolData::Functions { signatures } = &function.data {
209                                    &signatures[0].returnType
210                                } else {
211                                    return vec![]; // Not a valid function
212                                }
213                            }
214                            None => return vec![], // No matching function found
215                        }
216                    }
217                    ShaderSymbolData::Enum { values } => match values
218                        .iter()
219                        .find(|v| v.label == next_item.get_word())
220                        .map(|v| v.as_symbol(Some(file_path.clone()), &current_symbols[0].label))
221                    {
222                        Some(symbol) => return vec![symbol],
223                        None => return vec![],
224                    },
225                    ShaderSymbolData::Functions { signatures } => &signatures[0].returnType,
226                    ShaderSymbolData::Variables { ty, count: _ } => &ty,
227                    // Method & parameter will only be called after first iteration
228                    ShaderSymbolData::Method {
229                        context: _,
230                        signatures,
231                    } => &signatures[0].returnType,
232                    ShaderSymbolData::Parameter {
233                        context: _,
234                        ty,
235                        count: _,
236                    } => &ty,
237                    _ => return vec![], // Invalid type
238                };
239                // Find the type symbol of the variable / method.
240                let symbol_ty = match symbol_list.find_type_symbol(&ty) {
241                    Some(ty_symbol) => ty_symbol,
242                    None => return vec![], // No matching type found
243                };
244                // Find the variable chained from the type.
245                let symbols: Vec<ShaderSymbol> = match &symbol_ty.data {
246                    ShaderSymbolData::Struct {
247                        constructors: _,
248                        members,
249                        methods,
250                    } => {
251                        let file_path = if let ShaderSymbolMode::Runtime(runtime) = &symbol_ty.mode
252                        {
253                            Some(runtime.file_path.clone())
254                        } else {
255                            None
256                        };
257                        let member_symbols: Vec<ShaderSymbol> = members
258                            .iter()
259                            .filter(|m| m.parameters.label == next_item.word)
260                            .map(|m| m.as_symbol(file_path.clone()))
261                            .collect();
262                        let method_symbols: Vec<ShaderSymbol> = methods
263                            .iter()
264                            .filter(|m| m.label == next_item.word)
265                            .map(|m| m.as_symbol(file_path.clone()))
266                            .collect();
267                        [member_symbols, method_symbols].concat()
268                    }
269                    ShaderSymbolData::Types { constructors: _ } => {
270                        return vec![]; // Cannot chain a default type.
271                    }
272                    _ => return vec![], // Data useless.
273                };
274                if symbols.is_empty() {
275                    return vec![]; // No matching member / methods found.
276                } else {
277                    current_symbols = symbols;
278                }
279            }
280            current_symbols
281        }
282    }
283}
284
285pub trait SymbolTreeParser {
286    // The query to match tree node
287    fn get_query(&self) -> String;
288    // Process the match & convert it to symbol
289    fn process_match(
290        &self,
291        symbol_match: &QueryMatch,
292        file_path: &Path,
293        shader_content: &str,
294        scopes: &Vec<ShaderScope>,
295        symbols: &mut ShaderSymbolListBuilder,
296    );
297    fn compute_scope_stack(
298        &self,
299        scopes: &Vec<ShaderScope>,
300        range: &ShaderRange,
301    ) -> Vec<ShaderScope> {
302        scopes
303            .iter()
304            .filter_map(|e| {
305                if e.contain_bounds(&range) {
306                    Some(e.clone())
307                } else {
308                    None
309                }
310            })
311            .collect::<Vec<ShaderScope>>()
312    }
313}
314
315pub trait SymbolRegionFinder {
316    fn query_regions_in_node<'a>(
317        &self,
318        shader_module: &ShaderModule,
319        symbol_provider: &SymbolProvider,
320        shader_params: &ShaderCompilationParams,
321        node: tree_sitter::Node,
322        preprocessor: &mut ShaderPreprocessor,
323        context: &'a mut ShaderPreprocessorContext,
324        include_callback: &'a mut SymbolIncludeCallback<'a>,
325        old_symbols: Option<ShaderSymbols>,
326    ) -> Result<Vec<ShaderRegion>, ShaderError>;
327}
328
329pub trait SymbolTreePreprocessorParser {
330    // The query to match tree node
331    fn get_query(&self) -> String;
332    // Process the match & convert it to preprocessor
333    fn process_match(
334        &self,
335        symbol_match: &QueryMatch,
336        file_path: &Path,
337        shader_content: &str,
338        preprocessor: &mut ShaderPreprocessor,
339        context: &mut ShaderPreprocessorContext,
340    );
341}
342
343pub trait SymbolWordProvider {
344    fn find_word_at_position_in_node(
345        &self,
346        shader_module: &ShaderModule,
347        node: Node,
348        position: &ShaderPosition,
349    ) -> Result<ShaderWordRange, ShaderError>;
350}