shader_sense/symbols/
shader_module.rs

//! Shader module storing the [`tree-sitter`](tree_sitter) AST.
2use std::{
3    cell::RefCell,
4    path::{Path, PathBuf},
5    rc::Rc,
6};
7
8use tree_sitter::{Tree, TreeCursor};
9
10use crate::{
11    shader::ShaderContextParams,
12    symbols::symbol_list::{ShaderSymbolList, ShaderSymbolListRef},
13};
14
15use super::prepocessor::{
16    ShaderPreprocessor, ShaderPreprocessorContext, ShaderPreprocessorInclude,
17};
18
19/// Shader module holding the tree-sitter AST.
20/// Need to be created with [`ShaderModuleParser`]
21#[derive(Debug, Clone)]
22pub struct ShaderModule {
23    pub file_path: PathBuf,
24    pub content: String,
25    pub tree: Tree,
26}
27
28pub type ShaderModuleHandle = Rc<RefCell<ShaderModule>>;
29
30#[derive(Debug, Default, Clone)]
31pub struct ShaderSymbols {
32    pub(super) preprocessor: ShaderPreprocessor,
33    pub(super) symbol_list: ShaderSymbolList,
34}
35impl ShaderSymbols {
36    pub fn new(file_path: &Path, shader_params: ShaderContextParams) -> Self {
37        Self {
38            preprocessor: ShaderPreprocessor::new(ShaderPreprocessorContext::main(
39                file_path,
40                shader_params,
41            )),
42            symbol_list: ShaderSymbolList::default(),
43        }
44    }
45    pub fn get_all_symbols<'a>(&'a self) -> ShaderSymbolListRef<'a> {
46        let mut symbols = self.get_local_symbols();
47        for include in &self.preprocessor.includes {
48            assert!(
49                include.cache.is_some(),
50                "Include {} do not have cache, but is being queried.\n{}",
51                include.get_relative_path(),
52                self.dump_dependency_tree(&PathBuf::from("oui"))
53            );
54            symbols.append(include.get_cache().get_all_symbols());
55        }
56        symbols
57    }
58    pub fn get_local_symbols<'a>(&'a self) -> ShaderSymbolListRef<'a> {
59        self.preprocessor.preprocess_symbols(&self.symbol_list)
60    }
61    pub fn get_context(&self) -> &ShaderPreprocessorContext {
62        &self.preprocessor.context
63    }
64    // TODO: should abstract this.
65    pub fn get_preprocessor(&self) -> &ShaderPreprocessor {
66        &self.preprocessor
67    }
68    pub fn get_preprocessor_mut(&mut self) -> &mut ShaderPreprocessor {
69        &mut self.preprocessor
70    }
71    pub fn visit_includes<F: FnMut(&ShaderPreprocessorInclude)>(&self, callback: &mut F) {
72        for include in &self.preprocessor.includes {
73            callback(&include);
74            include.get_cache().visit_includes(callback);
75        }
76    }
77    pub fn visit_includes_mut<F: FnMut(&mut ShaderPreprocessorInclude)>(
78        &mut self,
79        callback: &mut F,
80    ) {
81        for include in &mut self.preprocessor.includes {
82            callback(include);
83            include.cache.as_mut().unwrap().visit_includes_mut(callback);
84        }
85    }
86    pub fn find_include_stack<F: FnMut(&ShaderPreprocessorInclude) -> bool>(
87        &self,
88        callback: &mut F,
89    ) -> Option<Vec<&ShaderPreprocessorInclude>> {
90        for include in &self.preprocessor.includes {
91            if callback(&include) {
92                return Some(vec![&include]);
93            } else {
94                match include.get_cache().find_include_stack(callback) {
95                    Some(mut stack) => {
96                        stack.insert(0, include);
97                        return Some(stack);
98                    }
99                    None => {}
100                }
101            }
102        }
103        None
104    }
105    pub fn find_include<F: FnMut(&ShaderPreprocessorInclude) -> bool>(
106        &self,
107        callback: &mut F,
108    ) -> Option<&ShaderPreprocessorInclude> {
109        for include in &self.preprocessor.includes {
110            if callback(&include) {
111                return Some(&include);
112            } else {
113                match include.get_cache().find_include(callback) {
114                    Some(include) => {
115                        return Some(&include);
116                    }
117                    None => {}
118                }
119            }
120        }
121        None
122    }
123    pub fn find_direct_includer(&self, include_path: &Path) -> Option<&ShaderPreprocessorInclude> {
124        match self.find_include_stack(&mut |include| *include.get_absolute_path() == *include_path)
125        {
126            Some(stack) => {
127                assert!(!stack.is_empty());
128                Some(stack[0])
129            }
130            None => None,
131        }
132    }
133    pub fn has_dependency(&self, dependency_to_find_path: &Path) -> bool {
134        self.find_include(&mut |e| *e.get_absolute_path() == *dependency_to_find_path)
135            .is_some()
136    }
137    fn dump_dependency_node(
138        &self,
139        include: &ShaderPreprocessorInclude,
140        header: String,
141        is_last: bool,
142    ) -> String {
143        let mut dependency_tree = format!(
144            "{}{} {} ({})\n",
145            header,
146            if is_last { "└─" } else { "├─" },
147            include.get_absolute_path().display(),
148            match &include.cache {
149                Some(cache) => format!("Mode: {:?}", cache.preprocessor.mode),
150                None => "Missing cache".into(),
151            }
152        );
153        let childs_header = format!("{}{}", header, if is_last { "  " } else { "|  " });
154        let mut deps_iter = match &include.cache {
155            Some(data) => data.preprocessor.includes.iter().peekable(),
156            None => {
157                return dependency_tree;
158            }
159        };
160        while let Some(included_include) = deps_iter.next() {
161            dependency_tree.push_str(
162                self.dump_dependency_node(
163                    included_include,
164                    childs_header.clone(),
165                    deps_iter.peek().is_none(),
166                )
167                .as_str(),
168            );
169        }
170        dependency_tree
171    }
172    pub fn dump_dependency_tree(&self, absolute_path: &PathBuf) -> String {
173        let mut dependency_tree = format!("{}\n", absolute_path.display());
174        let mut deps_iter = self.preprocessor.includes.iter().peekable();
175        while let Some(include) = deps_iter.next() {
176            dependency_tree.push_str(
177                self.dump_dependency_node(include, "   ".into(), deps_iter.peek().is_none())
178                    .as_str(),
179            );
180        }
181        dependency_tree
182    }
183}
184
185impl ShaderModule {
186    // Dump AST from tree
187    pub fn dump_ast(&self) -> String {
188        Self::dump_ast_node(self.tree.root_node())
189    }
190    pub fn dump_ast_node(node: tree_sitter::Node) -> String {
191        fn format_debug_cursor(cursor: &mut TreeCursor, depth: usize) -> String {
192            let mut debug_tree = String::new();
193            loop {
194                debug_tree.push_str(&match cursor.field_name() {
195                    Some(field_name) => format!(
196                        "{}{}: {} [{}, {}] - [{}, {}]\n",
197                        " ".repeat(depth * 2),
198                        field_name,
199                        if cursor.node().is_named() {
200                            cursor.node().kind().into()
201                        } else {
202                            format!("\"{}\"", cursor.node().kind())
203                        },
204                        cursor.node().range().start_point.row,
205                        cursor.node().range().start_point.column,
206                        cursor.node().range().end_point.row,
207                        cursor.node().range().end_point.column,
208                    ),
209                    None => format!(
210                        "{}{} [{}, {}] - [{}, {}]\n",
211                        " ".repeat(depth * 2),
212                        if cursor.node().is_named() {
213                            cursor.node().kind().into()
214                        } else {
215                            format!("\"{}\"", cursor.node().kind())
216                        },
217                        cursor.node().range().start_point.row,
218                        cursor.node().range().start_point.column,
219                        cursor.node().range().end_point.row,
220                        cursor.node().range().end_point.column,
221                    ),
222                });
223                if cursor.goto_first_child() {
224                    debug_tree.push_str(format_debug_cursor(cursor, depth + 1).as_str());
225                    cursor.goto_parent();
226                }
227                if !cursor.goto_next_sibling() {
228                    break;
229                }
230            }
231            debug_tree
232        }
233        format_debug_cursor(&mut node.walk(), 0)
234    }
235}