spirv_cross2/reflect/entry_points.rs

1use crate::cell::AllocationDropGuard;
2use crate::error;
3use crate::error::{SpirvCrossError, ToContextError};
4use crate::handle::Handle;
5use crate::iter::impl_iterator;
6use crate::reflect::try_valid_slice;
7use crate::string::CompilerStr;
8use crate::Compiler;
9use core::slice;
10use spirv_cross_sys as sys;
11use spirv_cross_sys::{spvc_entry_point, SpvBuiltIn, SpvExecutionModel, SpvStorageClass};
12use std::ffi::c_char;
13
14/// Iterator for declared extensions, created by [`Compiler::declared_extensions`].
15pub struct ExtensionsIter<'a>(slice::Iter<'a, *const c_char>, AllocationDropGuard);
16
17impl_iterator!(ExtensionsIter<'c>: CompilerStr<'c> as map |s, ptr: &*const c_char| {
18    unsafe {
19        CompilerStr::from_ptr(*ptr, s.1.clone())
20    }
21} for <'c> [0]);
22
23/// Querying declared properties of the SPIR-V module.
24impl<T> Compiler<T> {
25    /// Gets the list of all SPIR-V Capabilities which were declared in the SPIR-V module.
26    pub fn declared_capabilities(&self) -> error::Result<&[spirv::Capability]> {
27        unsafe {
28            let mut caps = std::ptr::null();
29            let mut size = 0;
30
31            sys::spvc_compiler_get_declared_capabilities(self.ptr.as_ptr(), &mut caps, &mut size)
32                .ok(self)?;
33
34            const _: () =
35                assert!(std::mem::size_of::<spirv::Capability>() == std::mem::size_of::<i32>());
36            try_valid_slice(caps.cast(), size)
37        }
38    }
39
40    /// Gets the list of all SPIR-V extensions which were declared in the SPIR-V module.
41    pub fn declared_extensions(&self) -> error::Result<ExtensionsIter<'static>> {
42        // SAFETY: 'a is OK to return here
43        // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2756
44        unsafe {
45            let mut caps = std::ptr::null_mut();
46            let mut size = 0;
47
48            sys::spvc_compiler_get_declared_extensions(self.ptr.as_ptr(), &mut caps, &mut size)
49                .ok(self)?;
50
51            let ptr_slice = slice::from_raw_parts(caps, size);
52
53            Ok(ExtensionsIter(ptr_slice.iter(), self.ctx.drop_guard()))
54        }
55    }
56
57    /// Get the execution model of the module.
58    pub fn execution_model(&self) -> error::Result<spirv::ExecutionModel> {
59        unsafe {
60            let exec_model = sys::spvc_compiler_get_execution_model(self.ptr.as_ptr());
61
62            let Some(exec_model) = spirv::ExecutionModel::from_u32(exec_model.0 as u32) else {
63                return Err(SpirvCrossError::InvalidEnum);
64            };
65
66            Ok(exec_model)
67        }
68    }
69}
70
71/// Proof that [`Compiler::update_active_builtins`] was called.
72#[derive(Debug, Copy, Clone)]
73pub struct ActiveBuiltinsUpdatedProof(Handle<()>);
74
75/// Querying builtins in the SPIR-V module
76impl<T> Compiler<T> {
77    /// Update active built-ins in the SPIR-V module.
78    pub fn update_active_builtins(&mut self) -> ActiveBuiltinsUpdatedProof {
79        unsafe {
80            sys::spvc_compiler_update_active_builtins(self.ptr.as_ptr());
81            ActiveBuiltinsUpdatedProof(self.create_handle(()))
82        }
83    }
84
85    /// Return whether the builtin is used or not.
86    ///
87    /// Requires [`Compiler::update_active_builtins`] to be called first,
88    /// proof of which is required to call this function.
89    pub fn has_active_builtin(
90        &self,
91        builtin: spirv::BuiltIn,
92        storage_class: spirv::StorageClass,
93        proof: ActiveBuiltinsUpdatedProof,
94    ) -> error::Result<bool> {
95        if !self.handle_is_valid(&proof.0) {
96            return Err(SpirvCrossError::InvalidOperation(String::from(
97                "The provided proof of building active builtins is invalid",
98            )));
99        }
100
101        unsafe {
102            Ok(sys::spvc_compiler_has_active_builtin(
103                self.ptr.as_ptr(),
104                SpvBuiltIn(builtin as i32),
105                SpvStorageClass(storage_class as i32),
106            ))
107        }
108    }
109}
110
111/// Iterator type created by [`Compiler::entry_points`].
112pub struct EntryPointIter<'a>(slice::Iter<'a, spvc_entry_point>, AllocationDropGuard);
113
114/// A SPIR-V entry point.
115#[derive(Debug)]
116pub struct EntryPoint<'a> {
117    /// The execution model for the entry point.
118    pub execution_model: spirv::ExecutionModel,
119    /// The name of the entry point.
120    pub name: CompilerStr<'a>,
121}
122
123impl_iterator!(EntryPointIter<'a>: EntryPoint<'a> as and_then|s, entry: &spvc_entry_point| {
124    unsafe {
125        let Some(execution_model) = spirv::ExecutionModel::from_u32(entry.execution_model.0 as u32) else {
126            if cfg!(debug_assertions) {
127                panic!("Unexpected SpvExecutionModelMax in valid entry point!")
128            } else {
129                return None;
130            }
131        };
132
133        let name = CompilerStr::from_ptr(entry.name, s.1.clone());
134        Some(EntryPoint {
135            name,
136            execution_model,
137        })
138    }
139} for <'a> [0]);
140
141/// Reflection of entry points.
142impl<T> Compiler<T> {
143    /// All operations work on the current entry point.
144    ///
145    /// Entry points can be swapped out with [`Compiler::set_entry_point`].
146    ///
147    /// Entry points should be set right after creating the compiler as some reflection
148    /// functions traverse the graph from the entry point.
149    ///
150    /// Resource reflection also depends on the entry point.
151    /// By default, the current entry point is set to the first `OpEntryPoint` which appears in the SPIR-V module.
152    //
153    /// Some shader languages restrict the names that can be given to entry points, and the
154    /// corresponding backend will automatically rename an entry point name when compiling,
155    /// if it is illegal.
156    ///
157    /// For example, the common entry point name `main()` is illegal in MSL, and is renamed to an
158    /// alternate name by the MSL backend.
159    ///
160    /// Given the original entry point name contained in the SPIR-V, this function returns
161    /// the name, as updated by the backend, if called after compilation.
162    ///
163    /// If the name is not illegal, and has not been renamed this function will simply return the
164    /// original name.
165    pub fn entry_points(&self) -> error::Result<EntryPointIter<'static>> {
166        unsafe {
167            // SAFETY: 'ctx is sound here
168            // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2170
169            let mut entry_points = std::ptr::null();
170            let mut size = 0;
171            sys::spvc_compiler_get_entry_points(self.ptr.as_ptr(), &mut entry_points, &mut size)
172                .ok(self)?;
173
174            Ok(EntryPointIter(
175                slice::from_raw_parts(entry_points.cast(), size).iter(),
176                self.ctx.drop_guard(),
177            ))
178        }
179    }
180
181    /// Get the cleansed name of the entry point for the given original name.
182    pub fn cleansed_entry_point_name<'str>(
183        &self,
184        name: impl Into<CompilerStr<'str>>,
185        model: spirv::ExecutionModel,
186    ) -> error::Result<Option<CompilerStr<'static>>> {
187        // SAFETY: 'ctx is sound here
188        // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2217
189        let name = name.into();
190        let name = name.into_cstring_ptr()?;
191
192        unsafe {
193            let name = sys::spvc_compiler_get_cleansed_entry_point_name(
194                self.ptr.as_ptr(),
195                name.as_ptr(),
196                SpvExecutionModel(model as u32 as i32),
197            );
198
199            if name.is_null() {
200                return Ok(None);
201            }
202            Ok(Some(CompilerStr::from_ptr(name, self.ctx.drop_guard())))
203        }
204    }
205
206    /// Set the current entry point by name.
207    ///
208    /// All operations work on the current entry point.
209    ///
210    /// Entry points should be set right after the constructor completes as some reflection functions traverse the graph from the entry point.
211    /// Resource reflection also depends on the entry point.
212    ///
213    /// By default, the current entry point is set to the first OpEntryPoint which appears in the SPIR-V module.
214    ///
215    /// Names for entry points in the SPIR-V module may alias if they belong to different execution models.
216    /// To disambiguate, we must pass along with the entry point names the execution model.
217    ///
218    /// ## Shader language restrictions
219    /// Some shader languages restrict the names that can be given to entry points, and the
220    /// corresponding backend will automatically rename an entry point name, on compilation if it is illegal.
221    ///
222    /// For example, the common entry point name `main()` is illegal in MSL, and is renamed to an
223    /// alternate name by the MSL backend.
224    pub fn set_entry_point<'str>(
225        &mut self,
226        name: impl Into<CompilerStr<'str>>,
227        model: spirv::ExecutionModel,
228    ) -> error::Result<()> {
229        let name = name.into();
230        unsafe {
231            let name = name.into_cstring_ptr()?;
232
233            sys::spvc_compiler_set_entry_point(
234                self.ptr.as_ptr(),
235                name.as_ptr(),
236                SpvExecutionModel(model as u32 as i32),
237            )
238            .ok(&*self)
239        }
240    }
241
242    /// Renames an entry point from `from` to `to`.
243    ///
244    /// If old_name is currently selected as the current entry point, it will continue to be the current entry point,
245    /// albeit with a new name.
246    ///
247    /// Values returned from [`Compiler::entry_points`] before this call will be outdated.
248    pub fn rename_entry_point<'str>(
249        &mut self,
250        from: impl Into<CompilerStr<'str>>,
251        to: impl Into<CompilerStr<'str>>,
252        model: spirv::ExecutionModel,
253    ) -> error::Result<()> {
254        let from = from.into();
255        let to = to.into();
256
257        unsafe {
258            let from = from.into_cstring_ptr()?;
259            let to = to.into_cstring_ptr()?;
260
261            sys::spvc_compiler_rename_entry_point(
262                self.ptr.as_ptr(),
263                from.as_ptr(),
264                to.as_ptr(),
265                SpvExecutionModel(model as u32 as i32),
266            )
267            .ok(&*self)
268        }
269    }
270}
271
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};
    use spirv::ExecutionModel;

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    #[test]
    fn get_entry_points() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let old_entry_points: Vec<_> = compiler.entry_points()?.collect();
        let main = &old_entry_points[0];

        eprintln!("{:?}", main);

        assert_eq!("main", main.name.as_ref());
        compiler.rename_entry_point("main", "new_main", ExecutionModel::Fragment)?;

        // After the rename, the old name no longer resolves to an entry point.
        let no_name = compiler.cleansed_entry_point_name("main", ExecutionModel::Fragment)?;
        assert!(no_name.is_none());

        // Entry points collected before the rename keep their (now stale) names.
        assert_eq!("main", main.name.as_ref());

        let new_name = compiler.cleansed_entry_point_name("new_main", ExecutionModel::Fragment)?;
        assert_eq!(Some("new_main"), new_name.as_deref());

        Ok(())
    }

    #[test]
    fn entry_point_soundness() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let entry_points = compiler.entry_points()?;
        let name = compiler
            .cleansed_entry_point_name("main", ExecutionModel::Fragment)?
            .unwrap();

        assert_eq!("main", name.as_ref());

        // Values obtained from the compiler must remain usable after it is dropped.
        drop(compiler);

        assert_eq!("main", name.as_ref());
        let entries: Vec<_> = entry_points.collect();

        eprintln!("{:?}", entries);
        Ok(())
    }

    #[test]
    fn extensions_soundness() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let extensions = compiler.declared_extensions()?;

        // The iterator holds its own drop guard, so it must outlive the compiler.
        drop(compiler);

        let extensions: Vec<_> = extensions.collect();
        eprintln!("{:?}", extensions);
        Ok(())
    }

    #[test]
    fn capabilities() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        // Reflect resources first to check that capability queries still work afterwards.
        let _resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.declared_capabilities()?;

        assert_eq!([spirv::Capability::Shader], ty);

        Ok(())
    }
}