// spirv_compiler/lib.rs

1pub use shaderc::{
2    GlslProfile, Limit, OptimizationLevel, ResourceKind, ShaderKind, SourceLanguage, SpirvVersion,
3    TargetEnv,
4};
5use std::{
6    cmp::Ordering,
7    collections::HashMap,
8    error::Error,
9    ffi::OsString,
10    fmt::{Debug, Display},
11    fs::File,
12    io::{Read, Write},
13    path::{Path, PathBuf},
14    sync::{Arc, Mutex},
15};
16
/// Errors returned by [`Compiler`] operations.
#[derive(Debug, Clone)]
pub enum CompilerError {
    /// shaderc rejected the shader; see [`CompilationError`] for the details.
    Log(CompilationError),
    /// A shader source file could not be loaded; holds the underlying I/O error message.
    LoadError(String),
    /// A precompiled `.spv` file could not be written; holds the underlying I/O error message.
    WriteError(String),
}

impl Display for CompilerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write each variant straight into the formatter instead of building a
        // temporary `String` first; the rendered text is unchanged.
        match self {
            CompilerError::Log(e) => write!(f, "Error: {}", e),
            CompilerError::LoadError(e) => write!(f, "Error: could not load file: {}", e),
            CompilerError::WriteError(e) => write!(f, "Error: could not write file: {}", e),
        }
    }
}

impl Error for CompilerError {}

/// A shader that shaderc failed to compile.
#[derive(Debug, Clone)]
pub struct CompilationError {
    /// Source file being compiled, or `None` when compiling from an in-memory string.
    pub file: Option<PathBuf>,
    /// shaderc's error text.
    pub description: String,
}

impl From<CompilationError> for CompilerError {
    fn from(val: CompilationError) -> Self {
        CompilerError::Log(val)
    }
}

impl Display for CompilationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.file.as_ref() {
            Some(file) => write!(
                f,
                "file: {}, description: {}",
                file.display(),
                self.description
            ),
            None => write!(f, "description: {}", self.description),
        }
    }
}

// Lets a `CompilationError` be used on its own wherever a `dyn Error` is expected.
impl Error for CompilationError {}
67
68pub struct CompilerBuilder<'a> {
69    options: shaderc::CompileOptions<'a>,
70    include_dirs: Vec<PathBuf>,
71    has_macros: bool,
72}
73
74impl Default for CompilerBuilder<'_> {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl<'a> CompilerBuilder<'a> {
81    pub fn new() -> CompilerBuilder<'a> {
82        CompilerBuilder {
83            options: shaderc::CompileOptions::new().unwrap(),
84            include_dirs: Vec::new(),
85            has_macros: false,
86        }
87    }
88
89    pub fn with_target_spirv(mut self, version: SpirvVersion) -> Self {
90        self.options.set_target_spirv(version);
91        self
92    }
93
94    pub fn with_macro(mut self, name: &str, value: Option<&str>) -> Self {
95        self.options.add_macro_definition(name, value);
96        self.has_macros = true;
97        self
98    }
99
100    pub fn with_auto_bind_uniforms(mut self, auto_bind: bool) -> Self {
101        self.options.set_auto_bind_uniforms(auto_bind);
102        self
103    }
104
105    pub fn with_binding_base(mut self, kind: ResourceKind, base: u32) -> Self {
106        self.options.set_binding_base(kind, base);
107        self
108    }
109
110    pub fn generate_debug_info(mut self) -> Self {
111        self.options.set_generate_debug_info();
112        self
113    }
114
115    pub fn force_version_profile(mut self, version: u32, profile: shaderc::GlslProfile) -> Self {
116        self.options.set_forced_version_profile(version, profile);
117        self
118    }
119
120    pub fn with_target_env(mut self, env: shaderc::TargetEnv, version: u32) -> Self {
121        self.options.set_target_env(env, version);
122        self
123    }
124
125    pub fn with_hlsl_io_mapping(mut self, iomap: bool) -> Self {
126        self.options.set_hlsl_io_mapping(iomap);
127        self
128    }
129
130    pub fn with_hlsl_register_set_and_binding(
131        mut self,
132        register: &str,
133        set: &str,
134        binding: &str,
135    ) -> Self {
136        self.options
137            .set_hlsl_register_set_and_binding(register, set, binding);
138        self
139    }
140
141    pub fn with_hlsl_offsets(mut self, offsets: bool) -> Self {
142        self.options.set_hlsl_offsets(offsets);
143        self
144    }
145
146    pub fn with_source_language(mut self, lang: SourceLanguage) -> Self {
147        self.options.set_source_language(lang);
148        self
149    }
150
151    pub fn with_binding_base_for_stage(
152        mut self,
153        kind: shaderc::ShaderKind,
154        resource_kind: shaderc::ResourceKind,
155        base: u32,
156    ) -> Self {
157        self.options
158            .set_binding_base_for_stage(kind, resource_kind, base);
159        self
160    }
161
162    pub fn with_opt_level(mut self, level: OptimizationLevel) -> Self {
163        self.options.set_optimization_level(level);
164        self
165    }
166
167    pub fn supress_warnings(mut self) -> Self {
168        self.options.set_suppress_warnings();
169        self
170    }
171
172    pub fn with_warnings_as_errors(mut self) -> Self {
173        self.options.set_warnings_as_errors();
174        self
175    }
176
177    pub fn with_limit(mut self, limit: shaderc::Limit, value: i32) -> Self {
178        self.options.set_limit(limit, value);
179        self
180    }
181
182    pub fn with_include_dir<T: AsRef<Path>>(mut self, path: T) -> Self {
183        debug_assert!(path.as_ref().exists());
184        self.include_dirs.push(path.as_ref().to_path_buf());
185        self
186    }
187
188    pub fn build(self) -> Option<Compiler<'a>> {
189        if let Some(compiler) = shaderc::Compiler::new() {
190            let mut compiler = Compiler {
191                compiler,
192                options: self.options,
193                compile_cache: HashMap::new(),
194                include_dirs: Arc::new(Mutex::new(self.include_dirs)),
195                has_macros: self.has_macros,
196            };
197
198            let include_dirs = compiler.include_dirs.clone();
199            compiler.options.set_include_callback(
200                move |requested_source, include_type, requesting_source, include_depth| {
201                    Compiler::include_callback(
202                        include_dirs.lock().unwrap().as_slice(),
203                        requested_source,
204                        include_type,
205                        requesting_source,
206                        include_depth,
207                    )
208                },
209            );
210
211            Some(compiler)
212        } else {
213            None
214        }
215    }
216}
217
218pub struct Compiler<'a> {
219    compiler: shaderc::Compiler,
220    options: shaderc::CompileOptions<'a>,
221    compile_cache: HashMap<PathBuf, Vec<u32>>,
222    include_dirs: Arc<Mutex<Vec<PathBuf>>>,
223    has_macros: bool,
224}
225
226impl Debug for Compiler<'_> {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("Compiler")
229            .field("compile_cache", &self.compile_cache)
230            .field("include_dirs", &self.include_dirs)
231            .field("has_macros", &self.has_macros)
232            .finish()
233    }
234}
235
236impl<'a> Compiler<'a> {
237    pub fn new() -> Option<Compiler<'a>> {
238        if let Some(compiler) = shaderc::Compiler::new() {
239            return Some(Compiler {
240                compiler,
241                options: shaderc::CompileOptions::new().unwrap(),
242                compile_cache: HashMap::new(),
243                include_dirs: Arc::new(Mutex::new(Vec::new())),
244                has_macros: false,
245            });
246        }
247        None
248    }
249
250    pub fn add_macro_definition(&mut self, name: &str, value: Option<&str>) {
251        self.options.add_macro_definition(name, value);
252        self.has_macros = true;
253    }
254
255    pub(crate) fn include_callback(
256        include_dirs: &[PathBuf],
257        requested_source: &str,
258        include_type: shaderc::IncludeType,
259        requesting_source: &str,
260        include_depth: usize,
261    ) -> Result<shaderc::ResolvedInclude, String> {
262        use shaderc::{IncludeType, ResolvedInclude};
263        if include_depth >= 32 {
264            return Err(format!(
265                "Include depth {} too high!",
266                include_depth
267            ));
268        }
269
270        let requested_path = PathBuf::from(String::from(requested_source));
271        let requesting_path = PathBuf::from(String::from(requesting_source));
272
273        if include_type == IncludeType::Standard {
274            for path in include_dirs {
275                let final_path = path.join(requested_path.as_path());
276                if final_path.exists() {
277                    if let Ok(mut file) = File::open(final_path.clone()) {
278                        let mut source = String::new();
279                        file.read_to_string(&mut source).unwrap();
280                        return Ok(ResolvedInclude {
281                            resolved_name: String::from(final_path.to_str().unwrap()),
282                            content: source,
283                        });
284                    }
285                }
286            }
287
288            return Err(format!(
289                "Could not find file: {}",
290                requested_source
291            ));
292        } else if include_type == IncludeType::Relative {
293            // #include ""
294            let base_folder = requesting_path.as_path().parent().unwrap();
295            let final_path = base_folder.join(requested_path.clone());
296            if final_path.exists() {
297                if let Ok(mut file) = File::open(final_path.clone()) {
298                    let mut source = String::new();
299                    file.read_to_string(&mut source).unwrap();
300                    return Ok(ResolvedInclude {
301                        resolved_name: String::from(final_path.to_str().unwrap()),
302                        content: source,
303                    });
304                }
305            }
306
307            for path in include_dirs {
308                let final_path = path.join(requested_path.as_path());
309                if final_path.exists() {
310                    if let Ok(mut file) = File::open(final_path.clone()) {
311                        let mut source = String::new();
312                        file.read_to_string(&mut source).unwrap();
313                        return Ok(ResolvedInclude {
314                            resolved_name: String::from(final_path.to_str().unwrap()),
315                            content: source,
316                        });
317                    }
318                }
319            }
320
321            return Err(format!(
322                "Could not find file: {}",
323                requested_source
324            ));
325        }
326
327        Err(format!(
328            "Unkown error resolving file: {}",
329            requested_source
330        ))
331    }
332
333    pub fn compile_from_string(
334        &mut self,
335        source: &str,
336        kind: shaderc::ShaderKind,
337    ) -> Result<Vec<u32>, CompilerError> {
338        let binary_result =
339            self.compiler
340                .compile_into_spirv(source, kind, "memory", "main", Some(&self.options));
341
342        match binary_result {
343            Err(e) => Err(CompilationError {
344                file: None,
345                description: e.to_string(),
346            }
347            .into()),
348            Ok(result) => Ok(result.as_binary().to_vec()),
349        }
350    }
351
352    pub fn compile_from_file<T: AsRef<Path>>(
353        &mut self,
354        path: T,
355        kind: shaderc::ShaderKind,
356        cache: bool,
357    ) -> Result<Vec<u32>, CompilerError> {
358        let mut precompiled = OsString::from(path.as_ref().as_os_str());
359        precompiled.push(".spv");
360        let precompiled = PathBuf::from(precompiled);
361
362        if cache {
363            if let Some(binary) = self.compile_cache.get(&path.as_ref().to_path_buf()) {
364                return Ok(binary.clone());
365            }
366
367            if precompiled.exists() && !self.has_macros {
368                let should_recompile: bool = if let (Ok(meta_data), Ok(pre_meta_data)) =
369                    (path.as_ref().metadata(), precompiled.metadata())
370                {
371                    let source_last_modified = meta_data.modified();
372                    let last_modified = pre_meta_data.modified();
373                    if let (Ok(source_last_modified), Ok(last_modified)) =
374                        (source_last_modified, last_modified)
375                    {
376                        source_last_modified.cmp(&last_modified) == Ordering::Less
377                    } else {
378                        true
379                    }
380                } else {
381                    true
382                };
383
384                // Only load pre-compiled files if they are up to date
385                if should_recompile {
386                    if let Ok(mut file) = File::open(&precompiled) {
387                        let mut bytes = Vec::new();
388                        file.read_to_end(&mut bytes).unwrap();
389                        let bytes: Vec<u32> = Vec::from(unsafe {
390                            std::slice::from_raw_parts(
391                                bytes.as_ptr() as *const u32,
392                                bytes.len() / 4,
393                            )
394                        });
395
396                        self.compile_cache
397                            .insert(path.as_ref().to_path_buf(), bytes.clone());
398                        return Ok(bytes);
399                    }
400                }
401            }
402        }
403
404        let file = File::open(&path);
405        if let Err(e) = file {
406            return Err(CompilerError::LoadError(e.to_string()));
407        }
408
409        let mut file = file.unwrap();
410        let mut source = String::new();
411        file.read_to_string(&mut source).unwrap();
412
413        let binary_result = self.compiler.compile_into_spirv(
414            source.as_str(),
415            kind,
416            path.as_ref().to_str().unwrap(),
417            "main",
418            Some(&self.options),
419        );
420
421        if let Err(e) = binary_result {
422            return Err(CompilationError {
423                file: Some(path.as_ref().to_path_buf()),
424                description: e.to_string(),
425            }
426            .into());
427        }
428
429        let binary_result = binary_result.unwrap();
430        if binary_result.get_num_warnings() > 0 {
431            eprintln!(
432                "File {} produced {} warnings: {}",
433                path.as_ref().display(),
434                binary_result.get_num_warnings(),
435                binary_result.get_warning_messages()
436            );
437        }
438        let bytes = binary_result.as_binary().to_vec();
439
440        if cache {
441            let file = File::create(&precompiled);
442            if let Err(e) = file {
443                return Err(CompilerError::WriteError(e.to_string()));
444            }
445
446            let mut file = file.unwrap();
447
448            if let Err(e) = file.write_all(unsafe {
449                std::slice::from_raw_parts(bytes.as_ptr() as *const u8, bytes.len() * 4)
450            }) {
451                return Err(CompilerError::WriteError(e.to_string()));
452            }
453        }
454
455        self.compile_cache
456            .insert(path.as_ref().to_path_buf(), bytes.clone());
457        Ok(bytes)
458    }
459}
460
#[cfg(test)]
mod tests {
    use crate::*;

    /// Include directory shared by every fixture compiler.
    const FIXTURE_DIR: &str = "test-spirv";

    /// Builds a compiler that searches the fixture directory for includes and, when
    /// `define_macro` is set, also defines `MY_MACRO=1`.
    fn fixture_compiler(define_macro: bool) -> Compiler<'static> {
        let builder = CompilerBuilder::new().with_include_dir(FIXTURE_DIR);
        let builder = if define_macro {
            builder.with_macro("MY_MACRO", Some("1"))
        } else {
            builder
        };
        builder.build().unwrap()
    }

    /// Compiles `file` as a vertex shader with the given compiler.
    fn compile_vertex(
        compiler: &mut Compiler<'_>,
        file: &str,
        cache: bool,
    ) -> Result<Vec<u32>, CompilerError> {
        compiler.compile_from_file(file, ShaderKind::Vertex, cache)
    }

    #[test]
    fn test_include() {
        let mut compiler = fixture_compiler(false);
        let result = compile_vertex(&mut compiler, "test-spirv/test-include.vert", false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_include_rel() {
        let mut compiler = fixture_compiler(false);
        let result = compile_vertex(&mut compiler, "test-spirv/test-include-rel.vert", false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_with_macro() {
        let mut compiler = fixture_compiler(true);
        let result = compile_vertex(&mut compiler, "test-spirv/test-macro.vert", false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_without_macro() {
        let mut compiler = fixture_compiler(false);
        let result = compile_vertex(&mut compiler, "test-spirv/test-macro.vert", false);
        assert!(result.is_err());
    }

    #[test]
    fn test_cache() {
        // Start from a clean slate so the assertion below proves a fresh write.
        let cached = PathBuf::from("test-spirv/test-macro.vert.spv");
        if cached.exists() {
            std::fs::remove_file(&cached).unwrap();
        }

        let mut compiler = fixture_compiler(true);
        let result = compile_vertex(&mut compiler, "test-spirv/test-macro.vert", true);
        assert!(result.is_ok());
        assert!(cached.exists());

        // Cleanup
        std::fs::remove_file(cached).unwrap();
    }
}