// shader_sense/symbols/symbol_parser.rs

1use std::path::{Path, PathBuf};
2
3use tree_sitter::{Node, QueryMatch};
4
5use crate::{
6    position::{ShaderFilePosition, ShaderPosition, ShaderRange},
7    shader::ShaderCompilationParams,
8    shader_error::ShaderError,
9    symbols::{
10        symbol_list::{ShaderSymbolList, ShaderSymbolListRef},
11        symbols::{ShaderSymbolData, ShaderSymbolMode},
12    },
13};
14
15use super::{
16    prepocessor::{ShaderPreprocessor, ShaderPreprocessorContext, ShaderRegion},
17    shader_module::{ShaderModule, ShaderSymbols},
18    symbol_provider::{SymbolIncludeCallback, SymbolProvider},
19    symbols::{ShaderScope, ShaderSymbol},
20};
21
22pub(super) fn get_name<'a>(shader_content: &'a str, node: Node) -> &'a str {
23    let range = node.range();
24    &shader_content[range.start_byte..range.end_byte]
25}
26
27impl From<tree_sitter::Range> for ShaderRange {
28    fn from(value: tree_sitter::Range) -> Self {
29        ShaderRange::new(
30            ShaderPosition::new(
31                value.start_point.row as u32,
32                value.start_point.column as u32,
33            ),
34            ShaderPosition::new(value.end_point.row as u32, value.end_point.column as u32),
35        )
36    }
37}
38
39impl From<tree_sitter::Point> for ShaderPosition {
40    fn from(point: tree_sitter::Point) -> Self {
41        ShaderPosition::new(point.row as u32, point.column as u32)
42    }
43}
44
45pub struct ShaderSymbolListBuilder<'a> {
46    shader_symbol_list: ShaderSymbolList,
47    filter_callback: Box<&'a dyn Fn(&ShaderSymbol) -> bool>,
48}
49impl<'a> ShaderSymbolListBuilder<'a> {
50    pub fn new(filter_callback: &'a dyn Fn(&ShaderSymbol) -> bool) -> Self {
51        Self {
52            shader_symbol_list: ShaderSymbolList::default(),
53            filter_callback: Box::new(filter_callback),
54        }
55    }
56    pub fn add_call_expression(&mut self, shader_symbol: ShaderSymbol) {
57        if (self.filter_callback)(&shader_symbol) {
58            self.shader_symbol_list.call_expression.push(shader_symbol);
59        }
60    }
61    pub fn add_variable(&mut self, shader_symbol: ShaderSymbol) {
62        if (self.filter_callback)(&shader_symbol) {
63            self.shader_symbol_list.variables.push(shader_symbol);
64        }
65    }
66    pub fn add_type(&mut self, shader_symbol: ShaderSymbol) {
67        if (self.filter_callback)(&shader_symbol) {
68            self.shader_symbol_list.types.push(shader_symbol);
69        }
70    }
71    pub fn add_function(&mut self, shader_symbol: ShaderSymbol) {
72        if (self.filter_callback)(&shader_symbol) {
73            self.shader_symbol_list.functions.push(shader_symbol);
74        }
75    }
76    pub fn get_shader_symbol_list(self) -> ShaderSymbolList {
77        self.shader_symbol_list
78    }
79}
80
81#[derive(Clone, Debug)]
82pub struct ShaderWordRange {
83    parent: Option<Box<ShaderWordRange>>, // Box to avoid recursive struct
84    word: String,
85    range: ShaderRange,
86}
87
88impl ShaderWordRange {
89    pub fn new(word: String, range: ShaderRange, parent: Option<ShaderWordRange>) -> Self {
90        Self {
91            parent: match parent {
92                Some(parent) => Some(Box::new(parent)),
93                None => None,
94            },
95            word,
96            range,
97        }
98    }
99    pub fn get_word(&self) -> &str {
100        &self.word
101    }
102    pub fn get_range(&self) -> &ShaderRange {
103        &self.range
104    }
105    pub fn get_parent(&self) -> Option<&ShaderWordRange> {
106        self.parent.as_ref().map(|p| p.as_ref())
107    }
108    fn get_parent_mut(&mut self) -> Option<&mut ShaderWordRange> {
109        self.parent.as_mut().map(|p| p.as_mut())
110    }
111    pub fn set_root_parent(&mut self, root_parent: ShaderWordRange) {
112        // Use a raw pointer to traverse without holding a mutable borrow
113        let mut parent: *mut ShaderWordRange = self;
114        unsafe {
115            while let Some(p) = (*parent).get_parent_mut() {
116                parent = p;
117            }
118            // Now parent is the deepest node, safe to assign
119            (*parent).parent = Some(Box::new(root_parent));
120        }
121    }
122    pub fn get_word_stack(&self) -> Vec<&ShaderWordRange> {
123        let mut current_word = self;
124        let mut stack = Vec::new();
125        stack.push(self);
126        while let Some(parent) = &current_word.parent {
127            stack.push(parent.as_ref());
128            current_word = parent.as_ref();
129        }
130        stack
131    }
132    pub fn set_parent(&mut self, parent: ShaderWordRange) {
133        self.parent = Some(Box::new(parent));
134    }
135    pub fn is_field(&self) -> bool {
136        self.parent.is_some()
137    }
138    // Look for matching symbol in symbol_list
139    pub fn find_symbol_from_parent(
140        &self,
141        file_path: PathBuf,
142        symbol_list: &ShaderSymbolListRef,
143    ) -> Vec<ShaderSymbol> {
144        if self.parent.is_none() {
145            // Could be either a variable, a link, or a type.
146            symbol_list
147                .find_symbols_at(&self.word, &self.range.end.clone_into_file(file_path))
148                .iter()
149                .map(|s| (*s).clone())
150                .collect()
151        } else {
152            // Will be a variable or function (root only), method, or member if chained.
153            let stack = self.get_word_stack();
154            let mut rev_stack = stack.iter().rev();
155            // TODO: SHould not require file path & filter here...
156            let symbol_list =
157                symbol_list.filter_scoped_symbol(&self.range.end.clone_into_file(file_path));
158            // Look for root symbol (either a function or variable)
159            let root_symbol = match rev_stack.next() {
160                Some(current_word) => match symbol_list.find_symbol(&current_word.word) {
161                    Some(symbol) => {
162                        match &symbol.data {
163                            ShaderSymbolData::CallExpression {
164                                label,
165                                range: _,
166                                parameters: _,
167                            } => {
168                                match symbol_list.find_function_symbol(label) {
169                                    Some(function) => {
170                                        if let ShaderSymbolData::Functions { signatures: _ } =
171                                            &function.data
172                                        {
173                                            symbol
174                                        } else {
175                                            return vec![]; // Not a valid function
176                                        }
177                                    }
178                                    None => return vec![], // No matching function found
179                                }
180                            }
181                            ShaderSymbolData::Functions { signatures: _ } => symbol,
182                            ShaderSymbolData::Variables { ty: _, count: _ } => symbol,
183                            _ => return vec![], // Symbol found is not a variable nor a function.
184                        }
185                    }
186                    None => {
187                        return vec![]; // No variable found for main parent.
188                    }
189                },
190                None => unreachable!("Should always have at least one symbol on this path."),
191            };
192            // Now loop over child for matching member elements
193            let mut current_symbols = vec![root_symbol.clone()];
194            while let Some(next_item) = &rev_stack.next() {
195                // TODO: for now, we naively pick the first signature.
196                // But we should pick instead the one closest by analyzing parameters.
197                let ty = match &current_symbols[0].data {
198                    // CallExpression & variable will only be called on first iteration
199                    ShaderSymbolData::CallExpression {
200                        label,
201                        range: _,
202                        parameters: _,
203                    } => {
204                        match symbol_list.find_function_symbol(label) {
205                            Some(function) => {
206                                if let ShaderSymbolData::Functions { signatures } = &function.data {
207                                    &signatures[0].returnType
208                                } else {
209                                    return vec![]; // Not a valid function
210                                }
211                            }
212                            None => return vec![], // No matching function found
213                        }
214                    }
215                    ShaderSymbolData::Functions { signatures } => &signatures[0].returnType,
216                    ShaderSymbolData::Variables { ty, count: _ } => &ty,
217                    // Method & parameter will only be called after first iteration
218                    ShaderSymbolData::Method {
219                        context: _,
220                        signatures,
221                    } => &signatures[0].returnType,
222                    ShaderSymbolData::Parameter {
223                        context: _,
224                        ty,
225                        count: _,
226                    } => &ty,
227                    _ => return vec![], // Invalid type
228                };
229                // Find the type symbol of the variable / method.
230                let symbol_ty = match symbol_list.find_type_symbol(&ty) {
231                    Some(ty_symbol) => ty_symbol,
232                    None => return vec![], // No matching type found
233                };
234                // Find the variable chained from the type.
235                let symbols: Vec<ShaderSymbol> = match &symbol_ty.data {
236                    ShaderSymbolData::Struct {
237                        constructors: _,
238                        members,
239                        methods,
240                    } => {
241                        let file_path = if let ShaderSymbolMode::Runtime(runtime) = &symbol_ty.mode
242                        {
243                            Some(runtime.file_path.clone())
244                        } else {
245                            None
246                        };
247                        let member_symbols: Vec<ShaderSymbol> = members
248                            .iter()
249                            .filter(|m| m.parameters.label == next_item.word)
250                            .map(|m| m.as_symbol(file_path.clone()))
251                            .collect();
252                        let method_symbols: Vec<ShaderSymbol> = methods
253                            .iter()
254                            .filter(|m| m.label == next_item.word)
255                            .map(|m| m.as_symbol(file_path.clone()))
256                            .collect();
257                        [member_symbols, method_symbols].concat()
258                    }
259                    ShaderSymbolData::Types { constructors: _ } => {
260                        return vec![]; // Cannot chain a default type.
261                    }
262                    _ => return vec![], // Data useless.
263                };
264                if symbols.is_empty() {
265                    return vec![]; // No matching member / methods found.
266                } else {
267                    current_symbols = symbols;
268                }
269            }
270            current_symbols
271        }
272    }
273}
274
275pub trait SymbolTreeParser {
276    // The query to match tree node
277    fn get_query(&self) -> String;
278    // Process the match & convert it to symbol
279    fn process_match(
280        &self,
281        matches: QueryMatch,
282        file_path: &Path,
283        shader_content: &str,
284        scopes: &Vec<ShaderScope>,
285        symbols: &mut ShaderSymbolListBuilder,
286    );
287    fn compute_scope_stack(
288        &self,
289        scopes: &Vec<ShaderScope>,
290        range: &ShaderRange,
291    ) -> Vec<ShaderScope> {
292        scopes
293            .iter()
294            .filter_map(|e| {
295                if e.contain_bounds(&range) {
296                    Some(e.clone())
297                } else {
298                    None
299                }
300            })
301            .collect::<Vec<ShaderScope>>()
302    }
303}
304
305pub trait SymbolRegionFinder {
306    fn query_regions_in_node<'a>(
307        &self,
308        shader_module: &ShaderModule,
309        symbol_provider: &SymbolProvider,
310        shader_params: &ShaderCompilationParams,
311        node: tree_sitter::Node,
312        preprocessor: &mut ShaderPreprocessor,
313        context: &'a mut ShaderPreprocessorContext,
314        include_callback: &'a mut SymbolIncludeCallback<'a>,
315        old_symbols: Option<ShaderSymbols>,
316    ) -> Result<Vec<ShaderRegion>, ShaderError>;
317}
318
319pub trait SymbolTreePreprocessorParser {
320    // The query to match tree node
321    fn get_query(&self) -> String;
322    // Process the match & convert it to preprocessor
323    fn process_match(
324        &self,
325        matches: QueryMatch,
326        file_path: &Path,
327        shader_content: &str,
328        preprocessor: &mut ShaderPreprocessor,
329        context: &mut ShaderPreprocessorContext,
330    );
331}
332
333pub trait SymbolWordProvider {
334    fn find_word_at_position_in_node(
335        &self,
336        shader_module: &ShaderModule,
337        node: Node,
338        position: &ShaderFilePosition,
339    ) -> Result<ShaderWordRange, ShaderError>;
340}