// shader_sense/symbols/mod.rs

1//! Handle symbol inspection with tree-sitter
2
3mod glsl;
4mod hlsl;
5pub mod intrinsics;
6pub mod prepocessor;
7pub mod shader_module;
8pub mod shader_module_parser;
9pub mod symbol_list;
10mod symbol_parser;
11pub mod symbol_provider;
12pub mod symbols;
13mod wgsl;
14
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        path::{Path, PathBuf},
    };

    use regex::Regex;

    use crate::{
        include::IncludeHandler,
        position::{ShaderFilePosition, ShaderFileRange, ShaderPosition},
        shader::{
            GlslShadingLanguageTag, HlslShadingLanguageTag, ShaderCompilationParams, ShaderParams,
            ShaderStage, ShadingLanguage, ShadingLanguageTag, WgslShadingLanguageTag,
        },
        shader_error::ShaderError,
        symbols::{
            intrinsics::ShaderIntrinsics, shader_module_parser::ShaderModuleParser,
            symbol_list::ShaderSymbolList, symbols::ShaderSymbolData,
        },
    };

    use super::symbol_provider::{default_include_callback, SymbolProvider};

    /// Resolves the direct (non-recursive) quoted includes of `shader_content`.
    ///
    /// Only `#include "path"` directives are matched; `#include <path>` is not.
    /// Each captured path is looked up through `include_handler`, and includes it
    /// cannot locate are silently dropped. Paths are returned in source order.
    pub fn find_file_dependencies(
        include_handler: &mut IncludeHandler,
        shader_content: &str,
    ) -> Vec<PathBuf> {
        let include_regex = Regex::new("\\#include\\s+\"([\\w\\s\\\\/\\.\\-]+)\"").unwrap();
        include_regex
            .captures_iter(shader_content)
            // Group 1 is mandatory in the pattern, so it is present on every match.
            .filter_map(|captures| captures.get(1))
            .filter_map(|include| {
                include_handler.search_path_in_includes(Path::new(include.as_str()))
            })
            .collect()
    }

    /// Collects every transitive include of `shader_content` as `(content, path)` pairs.
    ///
    /// Each included file is visited at most once. This both avoids re-reading shared
    /// includes and terminates on circular includes, which the original unbounded
    /// recursion did not.
    ///
    /// # Panics
    /// Panics if a resolved include cannot be read from disk.
    pub fn find_dependencies(
        include_handler: &mut IncludeHandler,
        shader_content: &str,
    ) -> HashSet<(String, PathBuf)> {
        let mut visited = HashSet::new();
        let mut dependencies = HashSet::new();
        collect_dependencies(include_handler, shader_content, &mut visited, &mut dependencies);
        dependencies
    }

    /// Depth-first worker for [`find_dependencies`].
    ///
    /// It resolves and reads all new direct includes first, then recurses into each
    /// one in order. `visited` holds every include path already scheduled.
    fn collect_dependencies(
        include_handler: &mut IncludeHandler,
        shader_content: &str,
        visited: &mut HashSet<PathBuf>,
        dependencies: &mut HashSet<(String, PathBuf)>,
    ) {
        let direct_dependencies: Vec<(String, PathBuf)> =
            find_file_dependencies(include_handler, shader_content)
                .into_iter()
                // `insert` returns false for already-seen paths, which breaks include cycles.
                .filter(|path| visited.insert(path.clone()))
                .map(|path| {
                    let content = std::fs::read_to_string(&path).unwrap_or_else(|err| {
                        panic!("failed to read include {}: {}", path.display(), err)
                    });
                    (content, path)
                })
                .collect();
        for (content, _) in &direct_dependencies {
            collect_dependencies(include_handler, content, visited, dependencies);
        }
        dependencies.extend(direct_dependencies);
    }

    /// Builds the symbol list seen from `file_path`.
    ///
    /// The list holds the language intrinsics (default compilation params), then the
    /// symbols of the file itself, then those of every transitive quoted include.
    ///
    /// # Errors
    /// Propagates any [`ShaderError`] from module creation or symbol querying.
    fn get_all_preprocessed_symbols<T: ShadingLanguageTag>(
        shader_module_parser: &mut ShaderModuleParser,
        symbol_provider: &SymbolProvider,
        file_path: &Path,
        shader_content: &String,
    ) -> Result<ShaderSymbolList, ShaderError> {
        let mut include_handler = IncludeHandler::main_without_config(file_path);
        let dependencies = find_dependencies(&mut include_handler, shader_content);
        let mut all_symbols = ShaderIntrinsics::get(T::get_language())
            .get_intrinsics_symbol(&ShaderCompilationParams::default())
            .to_owned();
        // The main file first, then each dependency, all processed identically.
        let sources = std::iter::once((file_path, shader_content)).chain(
            dependencies
                .iter()
                .map(|(content, path)| (path.as_path(), content)),
        );
        for (path, content) in sources {
            let shader_module = shader_module_parser.create_module(path, content)?;
            let symbols = symbol_provider.query_symbols(
                &shader_module,
                ShaderParams::default(),
                &mut default_include_callback::<T>,
                None,
            )?;
            all_symbols.append(symbols.get_all_symbols().into());
        }
        Ok(all_symbols)
    }

    /// Reads `path` and panics if `parser` fails to build a module from it.
    fn assert_module_created(parser: &mut ShaderModuleParser, path: &Path) {
        let shader_content = std::fs::read_to_string(path).unwrap();
        let _module = parser
            .create_module(path, &shader_content)
            .expect("failed to create shader module");
    }

    #[test]
    fn intrinsics_glsl_ok() {
        // Ensure parsing of intrinsics is OK
        let _ = ShaderSymbolList::parse_from_json(String::from(include_str!(
            "glsl/glsl-intrinsics.json"
        )));
    }
    #[test]
    fn intrinsics_hlsl_ok() {
        // Ensure parsing of intrinsics is OK
        let _ = ShaderSymbolList::parse_from_json(String::from(include_str!(
            "hlsl/hlsl-intrinsics.json"
        )));
    }
    #[test]
    fn intrinsics_wgsl_ok() {
        // Ensure parsing of intrinsics is OK
        let _ = ShaderSymbolList::parse_from_json(String::from(include_str!(
            "wgsl/wgsl-intrinsics.json"
        )));
    }
    #[test]
    fn create_glsl_module_ok() {
        assert_module_created(
            &mut ShaderModuleParser::glsl(),
            Path::new("./test/glsl/ok.frag.glsl"),
        );
    }
    #[test]
    fn create_hlsl_module_ok() {
        assert_module_created(&mut ShaderModuleParser::hlsl(), Path::new("./test/hlsl/ok.hlsl"));
    }
    #[test]
    fn create_wgsl_module_ok() {
        assert_module_created(&mut ShaderModuleParser::wgsl(), Path::new("./test/wgsl/ok.wgsl"));
    }
    #[test]
    fn symbols_glsl_ok() {
        // Ensure parsing of symbols is OK
        let file_path = Path::new("./test/glsl/include-level.comp.glsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let mut shader_module_parser =
            ShaderModuleParser::from_shading_language(ShadingLanguage::Glsl);
        let symbol_provider = SymbolProvider::from_shading_language(ShadingLanguage::Glsl);
        let shader_module = shader_module_parser
            .create_module(file_path, &shader_content)
            .unwrap();
        let symbols = symbol_provider
            .query_symbols(
                &shader_module,
                ShaderParams::default(),
                &mut default_include_callback::<GlslShadingLanguageTag>,
                None,
            )
            .unwrap();
        let symbols = symbols.get_all_symbols();
        assert!(!symbols.functions.is_empty());
    }
    #[test]
    fn symbols_hlsl_ok() {
        // Ensure parsing of symbols is OK
        let file_path = Path::new("./test/hlsl/include-level.hlsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let mut shader_module_parser =
            ShaderModuleParser::from_shading_language(ShadingLanguage::Hlsl);
        let symbol_provider = SymbolProvider::from_shading_language(ShadingLanguage::Hlsl);
        let shader_module = shader_module_parser
            .create_module(file_path, &shader_content)
            .unwrap();
        let symbols = symbol_provider
            .query_symbols(
                &shader_module,
                ShaderParams::default(),
                &mut default_include_callback::<HlslShadingLanguageTag>,
                None,
            )
            .unwrap();
        let symbols = symbols.get_all_symbols();
        assert!(!symbols.functions.is_empty());
    }
    #[test]
    fn symbols_wgsl_ok() {
        // Ensure parsing of symbols is OK.
        // NOTE(review): this pins the current result that no functions are reported for
        // ok.wgsl. Presumably WGSL symbol extraction is not implemented yet; confirm.
        let file_path = Path::new("./test/wgsl/ok.wgsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let mut shader_module_parser =
            ShaderModuleParser::from_shading_language(ShadingLanguage::Wgsl);
        let symbol_provider = SymbolProvider::from_shading_language(ShadingLanguage::Wgsl);
        let shader_module = shader_module_parser
            .create_module(file_path, &shader_content)
            .unwrap();
        let symbols = symbol_provider
            .query_symbols(
                &shader_module,
                ShaderParams::default(),
                &mut default_include_callback::<WgslShadingLanguageTag>,
                None,
            )
            .unwrap();
        let symbols = symbols.get_all_symbols();
        assert!(symbols.functions.is_empty());
    }
    #[test]
    fn symbol_scope_glsl_ok() {
        let file_path = Path::new("./test/glsl/scopes.frag.glsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let mut shader_module_parser =
            ShaderModuleParser::from_shading_language(ShadingLanguage::Glsl);
        let symbol_provider = SymbolProvider::from_shading_language(ShadingLanguage::Glsl);
        let preprocessed_symbol_list = get_all_preprocessed_symbols::<GlslShadingLanguageTag>(
            &mut shader_module_parser,
            &symbol_provider,
            file_path,
            &shader_content,
        )
        .unwrap();
        let symbol_list = preprocessed_symbol_list.as_ref();
        // Only symbols whose scope encloses line 16, column 0 should remain.
        let symbols = symbol_list.filter_scoped_symbol(&ShaderFilePosition::new(
            PathBuf::from(file_path),
            16,
            0,
        ));
        let visible_variables = ["scopeRoot", "scope1", "scopeGlobal", "level1"];
        let hidden_variables = ["scope2", "testData"];
        for &variable_visible in &visible_variables {
            assert!(
                symbols
                    .variables
                    .iter()
                    .any(|e| e.label == variable_visible),
                "Failed to find variable {} {:#?}",
                variable_visible,
                symbols.variables
            );
        }
        for &variable_not_visible in &hidden_variables {
            assert!(
                !symbols
                    .variables
                    .iter()
                    .any(|e| e.label == variable_not_visible),
                "Found variable {}",
                variable_not_visible
            );
        }
    }
    #[test]
    fn uniform_glsl_ok() {
        // Ensure parsing of symbols is OK
        let file_path = Path::new("./test/glsl/uniforms.frag.glsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let mut shader_module_parser =
            ShaderModuleParser::from_shading_language(ShadingLanguage::Glsl);
        let symbol_provider = SymbolProvider::from_shading_language(ShadingLanguage::Glsl);
        let shader_module = shader_module_parser
            .create_module(file_path, &shader_content)
            .unwrap();
        let symbols = symbol_provider
            .query_symbols(
                &shader_module,
                ShaderParams::default(),
                &mut default_include_callback::<GlslShadingLanguageTag>,
                None,
            )
            .unwrap();
        let symbols = symbols.get_all_symbols();
        assert!(symbols.types.iter().any(|e| e.label == "MatrixHidden"));
        assert!(symbols.variables.iter().any(|e| e.label == "u_accessor"
            && matches!(
                &e.data,
                ShaderSymbolData::Variables { ty, .. } if ty == "MatrixHidden"
            )));
        assert!(symbols
            .variables
            .iter()
            .any(|e| e.label == "u_modelviewGlobal"));
        assert!(!symbols
            .variables
            .iter()
            .any(|e| e.label == "u_modelviewHidden"));
    }
    #[test]
    fn test_position_conversion() {
        /// Converts `position` to a byte offset and checks that the content from that
        /// offset onward equals `expected_content`. An empty `expected_content` means
        /// the offset must be the end of the file.
        fn test_to_byte_offset(
            shader_content: &str,
            expected_content: &str,
            position: &ShaderPosition,
        ) -> usize {
            let byte_offset = position.to_byte_offset(shader_content).unwrap();
            if expected_content.is_empty() {
                assert_eq!(byte_offset, shader_content.len());
            } else {
                let content_from_offset = &shader_content[byte_offset..];
                assert_eq!(
                    content_from_offset, expected_content,
                    "Offseted content {:?} with offset {} is incorrect.",
                    content_from_offset, byte_offset
                );
            }
            byte_offset
        }
        /// Checks that `byte_offset` converts back to `expected_position`.
        fn test_back_to_position(
            shader_content: &str,
            expected_position: &ShaderPosition,
            byte_offset: usize,
        ) {
            let converted_position =
                ShaderPosition::from_byte_offset(shader_content, byte_offset).unwrap();
            let converted_byte_offset = converted_position.to_byte_offset(shader_content).unwrap();
            assert!(converted_position == *expected_position, "Position {:#?} with byte offset {} is different from converted position: {:#?} with byte offset {}", expected_position, byte_offset, converted_position, converted_byte_offset);
        }

        // Testing file
        let utf8_file_path = Path::new("./test/hlsl/utf8.hlsl");
        let utf8_shader_content = std::fs::read_to_string(utf8_file_path).unwrap();
        // End of line are enforced to \n through gitattributes for hlsl / glsl / wgsl in this repo.
        let test_data = vec![
            ("\n}", ShaderPosition::new(5, 0), &utf8_shader_content),
            ("", ShaderPosition::new(6, 1), &utf8_shader_content),
            (
                "id main() {\n\n}",
                ShaderPosition::new(4, 2),
                &utf8_shader_content,
            ),
            (
                "にちは世界!\n\nvoid main() {\n\n}",
                ShaderPosition::new(2, 5),
                &utf8_shader_content,
            ),
        ];
        for (index, (expected_content, position, shader_content)) in test_data.iter().enumerate() {
            println!("Testing conversion {} for {:?}", index, position);
            println!(
                "Content: {:?} (len {})",
                shader_content,
                shader_content.len()
            );
            let byte_offset = test_to_byte_offset(shader_content, expected_content, position);
            println!("Found byte_offset {}", byte_offset);
            test_back_to_position(shader_content, position, byte_offset);
        }
    }
    #[test]
    fn test_end_range() {
        let file_path = Path::new("./test/hlsl/utf8.hlsl");
        let shader_content = std::fs::read_to_string(file_path).unwrap();
        let range = ShaderFileRange::whole(file_path.into(), &shader_content);
        println!("File range: {:#?}", range);
        let end_byte_offset = range.range.end.to_byte_offset(&shader_content).unwrap();
        assert_eq!(end_byte_offset, shader_content.len());
    }
    #[test]
    fn test_intrinsic_filtering() {
        let intrinsics = ShaderIntrinsics::get(ShadingLanguage::Hlsl);
        // Check with frag stage set
        let intrinsics_frag = intrinsics.get_intrinsics_symbol(&ShaderCompilationParams {
            shader_stage: Some(ShaderStage::Fragment),
            ..Default::default()
        });
        assert!(
            intrinsics_frag.find_symbol("clip").is_some(),
            "clip() should be available from fragment shader."
        );
        // Check without stage set
        let intrinsics_common =
            intrinsics.get_intrinsics_symbol(&ShaderCompilationParams::default());
        assert!(
            intrinsics_common.find_symbol("clip").is_some(),
            "clip() should be available if no shader given."
        );
        // Check with vert stage set
        let intrinsics_vert = intrinsics.get_intrinsics_symbol(&ShaderCompilationParams {
            shader_stage: Some(ShaderStage::Vertex),
            ..Default::default()
        });
        assert!(
            intrinsics_vert.find_symbol("clip").is_none(),
            "clip() should not be available from vertex shader."
        );
    }
}