Skip to main content

spirv_cross2/reflect/
entry_points.rs

1use crate::cell::AllocationDropGuard;
2use crate::error;
3use crate::error::{SpirvCrossError, ToContextError};
4use crate::handle::Handle;
5use crate::iter::impl_iterator;
6use crate::reflect::try_valid_slice;
7use crate::string::CompilerStr;
8use crate::Compiler;
9use core::slice;
10use spirv_cross_sys as sys;
11use spirv_cross_sys::{spvc_entry_point, SpvBuiltIn, SpvExecutionModel, SpvStorageClass};
12use std::ffi::c_char;
13
14/// Iterator for declared extensions, created by [`Compiler::declared_extensions`].
15pub struct ExtensionsIter<'a>(slice::Iter<'a, *const c_char>, AllocationDropGuard);
16
// Implements `Iterator` for `ExtensionsIter`, mapping each raw extension-name
// pointer to a `CompilerStr` that carries its own clone of the iterator's drop
// guard (`s.1`). `[0]` presumably names the tuple field holding the inner
// slice iterator — NOTE(review): confirm against `impl_iterator!`.
impl_iterator!(ExtensionsIter<'c>: CompilerStr<'c> as map |s, ptr: &*const c_char| {
    // SAFETY: `declared_extensions` validates that every pointer in the slice
    // is non-null before constructing the iterator.
    unsafe {
        CompilerStr::from_ptr(*ptr, s.1.clone())
    }
} for <'c> [0]);
24
25/// Querying declared properties of the SPIR-V module.
26impl<T> Compiler<T> {
27    /// Gets the list of all SPIR-V Capabilities which were declared in the SPIR-V module.
28    pub fn declared_capabilities(&self) -> error::Result<&[spirv::Capability]> {
29        unsafe {
30            let mut caps = std::ptr::null();
31            let mut size = 0;
32
33            sys::spvc_compiler_get_declared_capabilities(self.ptr.as_ptr(), &mut caps, &mut size)
34                .ok(self)?;
35
36            const _: () =
37                assert!(std::mem::size_of::<spirv::Capability>() == std::mem::size_of::<i32>());
38            try_valid_slice(caps.cast(), size)
39        }
40    }
41
42    /// Gets the list of all SPIR-V extensions which were declared in the SPIR-V module.
43    pub fn declared_extensions(&self) -> error::Result<ExtensionsIter<'static>> {
44        // SAFETY: 'a is OK to return here
45        // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2756
46        unsafe {
47            let mut caps = std::ptr::null_mut();
48            let mut size = 0;
49
50            sys::spvc_compiler_get_declared_extensions(self.ptr.as_ptr(), &mut caps, &mut size)
51                .ok(self)?;
52
53            let ptr_slice = slice::from_raw_parts(caps, size);
54
55            // The C wrapper populates this buffer via `allocate_name`, which
56            // returns nullptr on OOM without propagating an error code.
57            // `CompilerStr::from_ptr` requires non-null pointers, so guard up front.
58            if ptr_slice.iter().any(|p| p.is_null()) {
59                return Err(SpirvCrossError::OutOfMemory(String::from(
60                    "Out of memory allocating declared extension names",
61                )));
62            }
63
64            Ok(ExtensionsIter(ptr_slice.iter(), self.ctx.drop_guard()))
65        }
66    }
67
68    /// Get the execution model of the module.
69    pub fn execution_model(&self) -> error::Result<spirv::ExecutionModel> {
70        unsafe {
71            let exec_model = sys::spvc_compiler_get_execution_model(self.ptr.as_ptr());
72
73            let Some(exec_model) = spirv::ExecutionModel::from_u32(exec_model.0 as u32) else {
74                return Err(SpirvCrossError::InvalidEnum);
75            };
76
77            Ok(exec_model)
78        }
79    }
80}
81
82/// Proof that [`Compiler::update_active_builtins`] was called.
83#[derive(Debug, Copy, Clone)]
84pub struct ActiveBuiltinsUpdatedProof(Handle<()>);
85
86/// Querying builtins in the SPIR-V module
87impl<T> Compiler<T> {
88    /// Update active built-ins in the SPIR-V module.
89    pub fn update_active_builtins(&mut self) -> ActiveBuiltinsUpdatedProof {
90        unsafe {
91            sys::spvc_compiler_update_active_builtins(self.ptr.as_ptr());
92            ActiveBuiltinsUpdatedProof(self.create_handle(()))
93        }
94    }
95
96    /// Return whether the builtin is used or not.
97    ///
98    /// Requires [`Compiler::update_active_builtins`] to be called first,
99    /// proof of which is required to call this function.
100    pub fn has_active_builtin(
101        &self,
102        builtin: spirv::BuiltIn,
103        storage_class: spirv::StorageClass,
104        proof: ActiveBuiltinsUpdatedProof,
105    ) -> error::Result<bool> {
106        if !self.handle_is_valid(&proof.0) {
107            return Err(SpirvCrossError::InvalidOperation(String::from(
108                "The provided proof of building active builtins is invalid",
109            )));
110        }
111
112        unsafe {
113            Ok(sys::spvc_compiler_has_active_builtin(
114                self.ptr.as_ptr(),
115                SpvBuiltIn(builtin as u32),
116                SpvStorageClass(storage_class as u32),
117            ))
118        }
119    }
120}
121
122/// Iterator type created by [`Compiler::entry_points`].
123pub struct EntryPointIter<'a>(slice::Iter<'a, spvc_entry_point>, AllocationDropGuard);
124
125/// A SPIR-V entry point.
126#[derive(Debug)]
127pub struct EntryPoint<'a> {
128    /// The execution model for the entry point.
129    pub execution_model: spirv::ExecutionModel,
130    /// The name of the entry point.
131    pub name: CompilerStr<'a>,
132}
133
// Implements `Iterator` for `EntryPointIter`, converting each raw
// `spvc_entry_point` into an `EntryPoint` whose name carries a clone of the
// iterator's drop guard (`s.1`).
//
// An execution model that `spirv::ExecutionModel::from_u32` does not recognise
// panics in debug builds and yields `None` from the closure in release builds.
// NOTE(review): with an `and_then` adapter, `None` likely ends iteration early
// rather than skipping just this entry — confirm against `impl_iterator!`.
impl_iterator!(EntryPointIter<'a>: EntryPoint<'a> as and_then|s, entry: &spvc_entry_point| {
    unsafe {
        let Some(execution_model) = spirv::ExecutionModel::from_u32(entry.execution_model.0 as u32) else {
            if cfg!(debug_assertions) {
                panic!("Unexpected SpvExecutionModelMax in valid entry point!")
            } else {
                return None;
            }
        };

        // NOTE(review): `from_ptr` needs a non-null `entry.name`; unlike the
        // extensions path there is no null check here, so this relies on the C
        // wrapper rejecting null names itself — confirm.
        let name = CompilerStr::from_ptr(entry.name, s.1.clone());
        Some(EntryPoint {
            name,
            execution_model,
        })
    }
} for <'a> [0]);
151
152/// Reflection of entry points.
153impl<T> Compiler<T> {
154    /// All operations work on the current entry point.
155    ///
156    /// Entry points can be swapped out with [`Compiler::set_entry_point`].
157    ///
158    /// Entry points should be set right after creating the compiler as some reflection
159    /// functions traverse the graph from the entry point.
160    ///
161    /// Resource reflection also depends on the entry point.
162    /// By default, the current entry point is set to the first `OpEntryPoint` which appears in the SPIR-V module.
163    //
164    /// Some shader languages restrict the names that can be given to entry points, and the
165    /// corresponding backend will automatically rename an entry point name when compiling,
166    /// if it is illegal.
167    ///
168    /// For example, the common entry point name `main()` is illegal in MSL, and is renamed to an
169    /// alternate name by the MSL backend.
170    ///
171    /// Given the original entry point name contained in the SPIR-V, this function returns
172    /// the name, as updated by the backend, if called after compilation.
173    ///
174    /// If the name is not illegal, and has not been renamed this function will simply return the
175    /// original name.
176    pub fn entry_points(&self) -> error::Result<EntryPointIter<'static>> {
177        unsafe {
178            // SAFETY: 'ctx is sound here
179            // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2170
180            let mut entry_points = std::ptr::null();
181            let mut size = 0;
182            sys::spvc_compiler_get_entry_points(self.ptr.as_ptr(), &mut entry_points, &mut size)
183                .ok(self)?;
184
185            Ok(EntryPointIter(
186                slice::from_raw_parts(entry_points.cast(), size).iter(),
187                self.ctx.drop_guard(),
188            ))
189        }
190    }
191
192    /// Get the cleansed name of the entry point for the given original name.
193    pub fn cleansed_entry_point_name<'str>(
194        &self,
195        name: impl Into<CompilerStr<'str>>,
196        model: spirv::ExecutionModel,
197    ) -> error::Result<Option<CompilerStr<'static>>> {
198        // SAFETY: 'ctx is sound here
199        // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2217
200        let name = name.into();
201        let name = name.into_cstring_ptr()?;
202
203        unsafe {
204            let name = sys::spvc_compiler_get_cleansed_entry_point_name(
205                self.ptr.as_ptr(),
206                name.as_ptr(),
207                SpvExecutionModel(model as u32),
208            );
209
210            if name.is_null() {
211                return Ok(None);
212            }
213            Ok(Some(CompilerStr::from_ptr(name, self.ctx.drop_guard())))
214        }
215    }
216
217    /// Set the current entry point by name.
218    ///
219    /// All operations work on the current entry point.
220    ///
221    /// Entry points should be set right after the constructor completes as some reflection functions traverse the graph from the entry point.
222    /// Resource reflection also depends on the entry point.
223    ///
224    /// By default, the current entry point is set to the first OpEntryPoint which appears in the SPIR-V module.
225    ///
226    /// Names for entry points in the SPIR-V module may alias if they belong to different execution models.
227    /// To disambiguate, we must pass along with the entry point names the execution model.
228    ///
229    /// ## Shader language restrictions
230    /// Some shader languages restrict the names that can be given to entry points, and the
231    /// corresponding backend will automatically rename an entry point name, on compilation if it is illegal.
232    ///
233    /// For example, the common entry point name `main()` is illegal in MSL, and is renamed to an
234    /// alternate name by the MSL backend.
235    pub fn set_entry_point<'str>(
236        &mut self,
237        name: impl Into<CompilerStr<'str>>,
238        model: spirv::ExecutionModel,
239    ) -> error::Result<()> {
240        let name = name.into();
241        unsafe {
242            let name = name.into_cstring_ptr()?;
243
244            sys::spvc_compiler_set_entry_point(
245                self.ptr.as_ptr(),
246                name.as_ptr(),
247                SpvExecutionModel(model as u32),
248            )
249            .ok(&*self)
250        }
251    }
252
253    /// Renames an entry point from `from` to `to`.
254    ///
255    /// If old_name is currently selected as the current entry point, it will continue to be the current entry point,
256    /// albeit with a new name.
257    ///
258    /// Values returned from [`Compiler::entry_points`] before this call will be outdated.
259    pub fn rename_entry_point<'str>(
260        &mut self,
261        from: impl Into<CompilerStr<'str>>,
262        to: impl Into<CompilerStr<'str>>,
263        model: spirv::ExecutionModel,
264    ) -> error::Result<()> {
265        let from = from.into();
266        let to = to.into();
267
268        unsafe {
269            let from = from.into_cstring_ptr()?;
270            let to = to.into_cstring_ptr()?;
271
272            sys::spvc_compiler_rename_entry_point(
273                self.ptr.as_ptr(),
274                from.as_ptr(),
275                to.as_ptr(),
276                SpvExecutionModel(model as u32),
277            )
278            .ok(&*self)
279        }
280    }
281}
282
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};
    use spirv::ExecutionModel;

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    #[test]
    pub fn get_entry_points() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let old_entry_points: Vec<_> = compiler.entry_points()?.collect();
        let main = &old_entry_points[0];

        eprintln!("{:?}", main);

        assert_eq!("main", main.name.as_ref());
        compiler.rename_entry_point("main", "new_main", ExecutionModel::Fragment)?;

        // After the rename, the original name no longer resolves to an entry point.
        let no_name = compiler.cleansed_entry_point_name("main", ExecutionModel::Fragment)?;
        assert!(no_name.is_none());

        // Names collected before the rename remain valid snapshots.
        assert_eq!("main", main.name.as_ref());
        let new_name = compiler.cleansed_entry_point_name("new_main", ExecutionModel::Fragment)?;

        assert_eq!(Some("new_main"), new_name.as_deref());

        Ok(())
    }

    #[test]
    pub fn entry_point_soundness() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let entry_points = compiler.entry_points()?;
        let name = compiler
            .cleansed_entry_point_name("main", ExecutionModel::Fragment)?
            .unwrap();

        assert_eq!("main", name.as_ref());

        drop(compiler);

        // Both the cleansed name and the entry point iterator must stay readable
        // after the compiler that produced them is gone.
        assert_eq!("main", name.as_ref());
        let entries: Vec<_> = entry_points.collect();

        eprintln!("{:?}", entries);
        assert_eq!("main", entries[0].name.as_ref());
        Ok(())
    }

    #[test]
    pub fn capabilities() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        // Reflect resources first to check capabilities are unaffected by it.
        let _resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.declared_capabilities()?;

        assert_eq!([spirv::Capability::Shader], ty);

        Ok(())
    }
}