shader_prepper/
lib.rs

1//! **shader-prepper** is a shader include parser and crawler. It is mostly aimed at GLSL
2//! which doesn't provide include directive support out of the box.
3//!
4//! This crate does not implement a full C-like preprocessor, only `#include` scanning.
5//! Other directives are instead copied into the expanded code, so they can be subsequently
6//! handled by the shader compiler.
7//!
//! The API supports user-driven include file providers, which enable custom
//! virtual file systems and include paths, and allow build systems to track dependencies.
10//!
11//! Source files are not concatenated together, but returned as a Vec of [`SourceChunk`].
12//! If a single string is needed, a `join` over the source strings can be used.
13//! Otherwise, the individual chunks can be passed to the graphics API, and source info
14//! contained within `SourceChunk` can then remap the compiler's errors back to
15//! the original code.
16//!
17//! # Example
18//!
19//! ```rust
20//! use failure;
21//!
22//! struct FileIncludeProvider;
23//! impl shader_prepper::IncludeProvider for FileIncludeProvider {
//!     type IncludeContext = ();
25//!
26//!     fn get_include(
27//!         &mut self,
28//!         path: &str,
29//!         _context: &Self::IncludeContext,
30//!     ) -> Result<(String, Self::IncludeContext), failure::Error> {
31//!         std::fs::read_to_string(path)
32//!             .map_err(|e| failure::format_err!("{}", e))
33//!             .map(|res| (res, ()))
34//!     }
35//! }
36//!
37//! // ...
38//!
39//! let chunks = shader_prepper::process_file("myfile.glsl", &mut FileIncludeProvider, ());
40//! ```
41
42#[macro_use]
43extern crate failure;
44
45use std::collections::HashSet;
46use std::iter::Peekable;
47use std::str::Chars;
48
49use failure::Error;
50
/// Errors that can occur while expanding `#include` directives.
#[derive(Debug, Fail)]
pub enum PrepperError {
    /// Any error reported by the user-supplied `IncludeProvider`
    #[fail(
        display = "include provider error: \"{}\" when trying to include {}",
        cause, file
    )]
    IncludeProviderError {
        /// Path that was passed to `IncludeProvider::get_include`
        file: String,
        /// The error returned by the include provider
        #[cause]
        cause: Error,
    },

    /// Recursively included file, along with information about where it was encountered
    #[fail(
        display = "file {} is recursively included; triggered in {} ({})",
        file, from, from_line
    )]
    RecursiveInclude {
        /// File which was included recursively
        file: String,

        /// File which included the recursively included one
        from: String,

        /// Line in the `from` file on which the include happened
        from_line: usize,
    },

    /// Error parsing an include directive
    ///
    /// Raised when `#include` is not followed by a `"..."` or `<...>` path that is
    /// closed before the end of the line.
    #[fail(display = "parse error: {} ({})", file, line)]
    ParseError {
        /// File containing the malformed directive
        file: String,
        /// 1-based line on which the directive's `#` appears
        line: usize,
    },
}
84
/// User-supplied include reader
///
/// The preprocessor calls `get_include` for every file it needs, including the
/// top-level file passed to [`process_file`].
pub trait IncludeProvider {
    /// Per-file state that travels along the include tree.
    ///
    /// The value returned alongside a file's text is the context passed back when
    /// resolving that file's own includes. It could carry, for example, the directory
    /// against which relative paths are resolved.
    type IncludeContext;

    /// Returns the text of `path` together with the context to use for its includes.
    ///
    /// `path` is the string taken from the `#include` directive (or the `file_path`
    /// given to `process_file` for the top-level file). `context` is the context of
    /// the file containing the directive (for the top-level file, the context given
    /// to `process_file`).
    ///
    /// Any error returned here is reported as `PrepperError::IncludeProviderError`.
    fn get_include(
        &mut self,
        path: &str,
        context: &Self::IncludeContext,
    ) -> Result<(String, Self::IncludeContext), Error>;
}
95
/// Chunk of source code along with information pointing back at the origin
///
/// Joining the `source` of every chunk returned by [`process_file`], in order,
/// yields the fully expanded code.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct SourceChunk {
    /// Source text
    pub source: String,

    /// File the code came from
    pub file: String,

    /// Zero-based line index in `file` at which this snippet starts
    pub line_offset: usize,
}
108
109/// Process a single file, and then any code recursively referenced.
110///
111/// `include_provider` is used to read all of the files, including the one at `file_path`.
112pub fn process_file<IncludeContext>(
113    file_path: &str,
114    include_provider: &mut dyn IncludeProvider<IncludeContext = IncludeContext>,
115    include_context: IncludeContext,
116) -> Result<Vec<SourceChunk>, Error> {
117    let mut prior_includes = HashSet::new();
118    let mut scanner = Scanner::new(
119        "",
120        String::new(),
121        &mut prior_includes,
122        include_provider,
123        include_context,
124    );
125    scanner.include_child(file_path, 1)?;
126    Ok(scanner.chunks)
127}
128
/// Iterator adaptor that pairs each character with the number of the line it is on,
/// counting from the initial value of `line`.
///
/// A `'\n'` is reported on the line it terminates; the counter only advances after
/// the newline has been yielded.
#[derive(Clone)]
struct LocationTracking<I> {
    /// Underlying character source
    iter: I,
    /// Line number attached to the next character yielded
    line: u32,
}

impl<I> Iterator for LocationTracking<I>
where
    I: Iterator<Item = char>,
{
    type Item = (u32, char);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let ch = self.iter.next()?;
        let current_line = self.line;
        if ch == '\n' {
            // Overflow would need more than u32::MAX lines. With default build
            // settings that panics in debug builds and wraps in release builds.
            self.line += 1;
        }
        Some((current_line, ch))
    }
}
151        })
152    }
153}
154
// Inspired by JayKickliter/monkey
/// Scans the text of one file, copying it into `SourceChunk`s and recursing into
/// `#include`d files through `include_provider`.
struct Scanner<'a, 'b, 'c, IncludeContext> {
    /// Reader used to fetch the files this file includes
    include_provider: &'b mut dyn IncludeProvider<IncludeContext = IncludeContext>,
    /// Context passed to `include_provider` when resolving this file's includes
    include_context: IncludeContext,
    /// Remaining input, annotated with line numbers (starting at 1)
    input_iter: Peekable<LocationTracking<Chars<'a>>>,
    /// Name of the file being scanned; recorded in emitted chunks and errors
    this_file: String,
    /// Paths currently being expanded (the include chain leading to this file),
    /// shared with parent scanners; used to detect recursive includes
    prior_includes: &'c mut HashSet<String>,
    /// Finished chunks in output order: this file's text interleaved with the
    /// chunks of the files it includes
    chunks: Vec<SourceChunk>,
    /// Text accumulated since the last flush
    current_chunk: String,
    /// Line (starting at 1) on which `current_chunk` begins
    current_chunk_first_line: u32,
}
166
167impl<'a, 'b, 'c, IncludeContext> Scanner<'a, 'b, 'c, IncludeContext> {
168    fn new(
169        input: &'a str,
170        this_file: String,
171        prior_includes: &'c mut HashSet<String>,
172        include_provider: &'b mut dyn IncludeProvider<IncludeContext = IncludeContext>,
173        include_context: IncludeContext,
174    ) -> Scanner<'a, 'b, 'c, IncludeContext> {
175        Scanner {
176            include_provider,
177            include_context,
178            input_iter: LocationTracking {
179                iter: input.chars(),
180                line: 1,
181            }
182            .peekable(),
183            this_file,
184            prior_includes,
185            chunks: Vec::new(),
186            current_chunk: String::new(),
187            current_chunk_first_line: 1,
188        }
189    }
190
191    fn read_char(&mut self) -> Option<(u32, char)> {
192        self.input_iter.next()
193    }
194
195    fn peek_char(&mut self) -> Option<&(u32, char)> {
196        self.input_iter.peek()
197    }
198
199    fn skip_whitespace_until_eol(&mut self) {
200        while let Some(&(_, c)) = self.peek_char() {
201            if c == '\n' {
202                break;
203            } else if c.is_whitespace() {
204                let _ = self.read_char();
205            } else if c == '\\' {
206                let mut peek_next = self.input_iter.clone();
207                let _ = peek_next.next();
208                if let Some(&(_, '\n')) = peek_next.peek() {
209                    let _ = self.read_char();
210                    let _ = self.read_char();
211                } else {
212                    break;
213                }
214            } else if c == '/' {
215                let mut next_peek = self.input_iter.clone();
216                let _ = next_peek.next();
217
218                if let Some(&(_, '*')) = next_peek.peek() {
219                    // Block comment. Skip it.
220                    let _ = self.read_char();
221                    let _ = self.read_char();
222
223                    self.input_iter = Self::skip_block_comment(self.input_iter.clone()).1;
224                }
225            } else {
226                break;
227            }
228        }
229    }
230
231    fn read_string(&mut self, right_delim: char) -> Option<String> {
232        let mut s = String::new();
233
234        while let Some(&(_, c)) = self.peek_char() {
235            if c == '\n' {
236                break;
237            } else if c == '\\' {
238                let _ = self.read_char();
239                let _ = self.read_char();
240            } else if c == right_delim {
241                let _ = self.read_char();
242                return Some(s);
243            } else {
244                s.push(c);
245                let _ = self.read_char();
246            }
247        }
248
249        None
250    }
251
252    fn skip_block_comment(
253        mut it: Peekable<LocationTracking<Chars<'a>>>,
254    ) -> (String, Peekable<LocationTracking<Chars<'a>>>) {
255        let mut s = String::new();
256
257        while let Some((_, c)) = it.next() {
258            if c == '*' {
259                s.push(' ');
260                if let Some(&(_, '/')) = it.peek() {
261                    let _ = it.next();
262                    s.push(' ');
263                    break;
264                }
265            } else if c == '\n' {
266                s.push('\n');
267            } else {
268                s.push(' ');
269            }
270        }
271
272        (s, it)
273    }
274
275    fn skip_line(&mut self) {
276        while let Some((_, c)) = self.read_char() {
277            if c == '\n' {
278                self.current_chunk.push('\n');
279                break;
280            } else if c == '\\' {
281                if let Some((_, '\n')) = self.read_char() {
282                    self.current_chunk.push('\n');
283                }
284            }
285        }
286    }
287
288    fn peek_preprocessor_ident(
289        &mut self,
290    ) -> Option<(String, Peekable<LocationTracking<Chars<'a>>>)> {
291        let mut token = String::new();
292        let mut it = self.input_iter.clone();
293
294        while let Some(&(_, c)) = it.peek() {
295            if '\n' == c || '\r' == c {
296                break;
297            } else if c.is_alphabetic() {
298                let _ = it.next();
299                token.push(c);
300            } else if c.is_whitespace() {
301                if !token.is_empty() {
302                    // Already found some chars, and this ends the identifier
303                    break;
304                } else {
305                    // Still haven't found anything. Continue scanning.
306                    let _ = it.next();
307                }
308            } else if '\\' == c {
309                let _ = it.next();
310                let next = it.next();
311
312                if let Some((_, '\n')) = next {
313                    // Continue scanning on next line
314                    continue;
315                } else if let (Some((_, '\r')), Some(&(_, '\n'))) = (next, it.peek()) {
316                    // ditto, but Windows-special
317                    let _ = it.next();
318                    continue;
319                } else {
320                    // Unrecognized escape sequence. Abort.
321                    return None;
322                }
323            } else if '/' == c {
324                if !token.is_empty() {
325                    // Already found some chars, and this ends the identifier
326                    break;
327                }
328
329                let mut next_peek = it.clone();
330                let _ = next_peek.next();
331
332                if let Some(&(_, '*')) = next_peek.peek() {
333                    // Block comment. Skip it.
334                    let _ = it.next();
335                    let _ = it.next();
336
337                    it = Self::skip_block_comment(it).1;
338                } else {
339                    // Something other than a block comment. End the identifier.
340                    break;
341                }
342            } else {
343                // Some other character. This finishes the identifier.
344                break;
345            }
346        }
347
348        Some((token, it))
349    }
350
351    fn flush_current_chunk(&mut self) {
352        if !self.current_chunk.is_empty() {
353            self.chunks.push(SourceChunk {
354                file: self.this_file.clone(),
355                line_offset: (self.current_chunk_first_line - 1) as usize,
356                source: self.current_chunk.clone(),
357            });
358            self.current_chunk.clear();
359        }
360
361        if let Some(&(line, _)) = self.peek_char() {
362            self.current_chunk_first_line = line;
363        }
364    }
365
366    fn include_child(&mut self, path: &str, included_on_line: u32) -> Result<(), PrepperError> {
367        if self.prior_includes.contains(path) {
368            return Err(PrepperError::RecursiveInclude {
369                file: path.to_string(),
370                from: self.this_file.clone(),
371                from_line: included_on_line as usize,
372            });
373        }
374
375        self.flush_current_chunk();
376
377        let (child_code, child_include_context) = self
378            .include_provider
379            .get_include(path, &self.include_context)
380            .map_err(|e| PrepperError::IncludeProviderError {
381                file: path.to_string(),
382                cause: e,
383            })?;
384
385        self.prior_includes.insert(path.to_string());
386
387        self.chunks.append(&mut {
388            let mut child_scanner = Scanner::new(
389                &child_code,
390                path.to_string(),
391                &mut self.prior_includes,
392                self.include_provider,
393                child_include_context,
394            );
395            child_scanner.process_input()?;
396            child_scanner.chunks
397        });
398
399        self.prior_includes.remove(path);
400
401        Ok(())
402    }
403
404    fn process_input(&mut self) -> Result<(), PrepperError> {
405        while let Some((c_line, c)) = self.read_char() {
406            match c {
407                '/' => {
408                    let next = self.peek_char();
409
410                    if let Some(&(_, '*')) = next {
411                        let _ = self.read_char();
412                        self.current_chunk.push_str("  ");
413                        let (white, it) = Self::skip_block_comment(self.input_iter.clone());
414
415                        self.input_iter = it;
416                        self.current_chunk.push_str(&white);
417                    } else if let Some(&(_, '/')) = next {
418                        let _ = self.read_char();
419                        self.skip_line();
420                    } else {
421                        self.current_chunk.push(c);
422                    }
423                }
424                '#' => {
425                    if let Some(preprocessor_ident) = self.peek_preprocessor_ident() {
426                        if "include" == preprocessor_ident.0 {
427                            self.input_iter = preprocessor_ident.1;
428                            self.skip_whitespace_until_eol();
429
430                            let left_delim = self.read_char();
431
432                            let right_delim = match left_delim {
433                                Some((_, '"')) => Some('"'),
434                                Some((_, '<')) => Some('>'),
435                                _ => None,
436                            };
437
438                            let path = right_delim
439                                .map(|right_delim| self.read_string(right_delim))
440                                .unwrap_or_default();
441
442                            if let Some(ref path) = path {
443                                self.include_child(path, c_line)?;
444                            } else {
445                                return Err(PrepperError::ParseError {
446                                    file: self.this_file.clone(),
447                                    line: c_line as usize,
448                                });
449                            }
450                        } else {
451                            self.current_chunk.push(c);
452                        }
453                    } else {
454                        self.current_chunk.push(c);
455                    }
456                }
457                _ => {
458                    self.current_chunk.push(c);
459                }
460            }
461        }
462
463        self.flush_current_chunk();
464        Ok(())
465    }
466}
467
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    /// Provider that "reads" any path as the text `[path]`.
    struct DummyIncludeProvider;
    impl crate::IncludeProvider for DummyIncludeProvider {
        type IncludeContext = ();

        fn get_include(
            &mut self,
            path: &str,
            _context: &Self::IncludeContext,
        ) -> Result<(String, Self::IncludeContext), crate::Error> {
            Ok((String::from("[") + path + "]", ()))
        }
    }

    /// Provider backed by an in-memory map; panics on unknown paths.
    struct HashMapIncludeProvider(HashMap<String, String>);
    impl crate::IncludeProvider for HashMapIncludeProvider {
        type IncludeContext = ();

        fn get_include(
            &mut self,
            path: &str,
            _context: &Self::IncludeContext,
        ) -> Result<(String, Self::IncludeContext), crate::Error> {
            Ok((self.0.get(path).unwrap().clone(), ()))
        }
    }

    /// Scans `s` directly, as if it were the contents of a file named `no-file`,
    /// and joins the resulting chunks.
    fn preprocess_into_string<IncludeContext>(
        s: &str,
        include_provider: &mut dyn crate::IncludeProvider<IncludeContext = IncludeContext>,
        include_context: IncludeContext,
    ) -> Result<String, crate::PrepperError> {
        let mut prior_includes = HashSet::new();
        let mut scanner = crate::Scanner::new(
            s,
            "no-file".to_string(),
            &mut prior_includes,
            include_provider,
            include_context,
        );
        scanner.process_input()?;
        Ok(scanner
            .chunks
            .into_iter()
            .map(|chunk| chunk.source)
            .collect::<Vec<_>>()
            .join(""))
    }

    fn test_string(s: &str, s2: &str) {
        match preprocess_into_string(s, &mut DummyIncludeProvider, ()) {
            Ok(r) => assert_eq!(r, s2),
            other => panic!("{:?}", other),
        };
    }

    #[test]
    fn ignore_unrecognized() {
        test_string("*/ */ \t/ /", "*/ */ \t/ /");
        test_string("int foo;", "int foo;");
        test_string("#version 430\n#pragma stuff", "#version 430\n#pragma stuff");
    }

    #[test]
    fn basic_block_comment() {
        test_string("foo /* bar */ baz", "foo           baz");
        test_string("foo /* /* bar */ baz", "foo              baz");
    }

    #[test]
    fn basic_line_comment() {
        test_string("foo // baz", "foo ");
        test_string("// foo /* bar */ baz", "");
    }

    #[test]
    fn continued_line_comment() {
        test_string("foo // baz\nbar", "foo \nbar");
        test_string("foo // baz\\\nbar", "foo \n");
    }

    #[test]
    fn mixed_comments() {
        test_string("/*\nfoo\n/*/\nbar\n//*/", "  \n   \n   \nbar\n");
        test_string("//*\nfoo\n/*/\nbar\n//*/", "\nfoo\n   \n   \n    ");
    }

    #[test]
    fn basic_preprocessor() {
        test_string("#", "#");
        test_string("#in/**/clude", "#in    clude");
        test_string("#in\nclude", "#in\nclude");
    }

    #[test]
    fn basic_include() {
        test_string(r#"#include"foo""#, "[foo]");
        test_string(r#"#include "foo""#, "[foo]");
        test_string("#include <foo>", "[foo]");
        test_string("#include <foo/bar/baz>", "[foo/bar/baz]");
        test_string("#include <foo\\\nbar\\\nbaz>", "[foobarbaz]");
        test_string("#include <foo>//\n", "[foo]\n");
        test_string("# include <foo>", "[foo]");
        test_string("#  include <foo>", "[foo]");
        test_string("#/**/include <foo>", "[foo]");
        test_string("#include /**/ <foo>", "[foo]");
    }

    #[test]
    fn multi_line_include() {
        match preprocess_into_string("#inc\\\nlude", &mut DummyIncludeProvider, ()) {
            Err(crate::PrepperError::ParseError { file: _, line: 1 }) => (),
            _ => panic!(),
        }

        test_string("#inc\\\nlude <foo>", "[foo]");
        test_string("#\\\ninc\\\n\\\nlude <foo>", "[foo]");
        test_string("#\\\n   inc\\\n\\\nlude <foo>", "[foo]");
    }

    #[test]
    fn multi_level_include() {
        let mut include_provider = HashMapIncludeProvider(
            [
                (
                    "foo",
                    "double rainbow;\n#include <bar>\nint spam;\n#include <baz>\nvoid ham();",
                ),
                ("bar", "int bar;"),
                ("baz", "int baz;"),
            ]
            .iter()
            .map(|(a, b)| (a.to_string(), b.to_string()))
            .collect(),
        );

        assert_eq!(
            preprocess_into_string("#include <bar>", &mut include_provider, ()).unwrap(),
            "int bar;"
        );
        assert_eq!(
            preprocess_into_string("#include <foo>", &mut include_provider, ()).unwrap(),
            "double rainbow;\nint bar;\nint spam;\nint baz;\nvoid ham();"
        );

        assert_eq!(
            crate::process_file("foo", &mut include_provider, ()).unwrap(),
            vec![
                crate::SourceChunk {
                    file: "foo".to_string(),
                    line_offset: 0,
                    source: "double rainbow;\n".to_string()
                },
                crate::SourceChunk {
                    file: "bar".to_string(),
                    line_offset: 0,
                    source: "int bar;".to_string()
                },
                crate::SourceChunk {
                    file: "foo".to_string(),
                    line_offset: 1,
                    source: "\nint spam;\n".to_string()
                },
                crate::SourceChunk {
                    file: "baz".to_string(),
                    line_offset: 0,
                    source: "int baz;".to_string()
                },
                crate::SourceChunk {
                    file: "foo".to_string(),
                    line_offset: 3,
                    source: "\nvoid ham();".to_string()
                },
            ]
        );
    }

    #[test]
    fn include_err() {
        match preprocess_into_string("#include", &mut DummyIncludeProvider, ()) {
            Err(crate::PrepperError::ParseError { file: _, line: 1 }) => (),
            other => panic!("{:?}", other),
        }

        match preprocess_into_string("#include @", &mut DummyIncludeProvider, ()) {
            Err(crate::PrepperError::ParseError { file: _, line: 1 }) => (),
            other => panic!("{:?}", other),
        }

        match preprocess_into_string("#include <foo", &mut DummyIncludeProvider, ()) {
            Err(crate::PrepperError::ParseError { file: _, line: 1 }) => (),
            other => panic!("{:?}", other),
        }

        let mut recursive_include_provider = HashMapIncludeProvider(
            [
                ("foo", "#include <bar>"),
                ("bar", "#include <baz>"),
                ("baz", "#include <foo>"),
            ]
            .iter()
            .map(|(a, b)| (a.to_string(), b.to_string()))
            .collect(),
        );

        match &preprocess_into_string("#include <foo>", &mut recursive_include_provider, ()) {
            Err(crate::PrepperError::RecursiveInclude {
                file: fname,
                from: fsrc,
                from_line: 1,
            }) if fname == "foo" && fsrc == "baz" => (),
            other => panic!("{:?}", other),
        }
    }

    /// Provider that reads paths straight from the filesystem.
    struct FileIncludeProvider;
    impl crate::IncludeProvider for FileIncludeProvider {
        type IncludeContext = ();

        fn get_include(
            &mut self,
            path: &str,
            _context: &Self::IncludeContext,
        ) -> Result<(String, Self::IncludeContext), failure::Error> {
            std::fs::read_to_string(path)
                .map_err(|e| failure::format_err!("{}", e))
                .map(|res| (res, ()))
        }
    }

    #[test]
    fn include_file() {
        // Exercise the real filesystem: write a shader to a temporary file and expand
        // it with `process_file`, which reads the top-level file via the provider.
        let path = std::env::temp_dir().join(format!(
            "shader_prepper_include_file_{}.glsl",
            std::process::id()
        ));
        std::fs::write(&path, "int x;\n").unwrap();
        let path_str = path.to_str().unwrap().to_string();

        let result = crate::process_file(&path_str, &mut FileIncludeProvider, ());
        let _ = std::fs::remove_file(&path);

        assert_eq!(
            result.unwrap(),
            vec![crate::SourceChunk {
                file: path_str,
                line_offset: 0,
                source: "int x;\n".to_string(),
            }]
        );
    }

    #[test]
    fn include_missing_file() {
        assert!(crate::process_file(
            "shader_prepper_this_file_does_not_exist.glsl",
            &mut FileIncludeProvider,
            ()
        )
        .is_err());
    }
}