spirv_cross2/reflect/
entry_points.rs1use crate::cell::AllocationDropGuard;
2use crate::error;
3use crate::error::{SpirvCrossError, ToContextError};
4use crate::handle::Handle;
5use crate::iter::impl_iterator;
6use crate::reflect::try_valid_slice;
7use crate::string::CompilerStr;
8use crate::Compiler;
9use core::slice;
10use spirv_cross_sys as sys;
11use spirv_cross_sys::{spvc_entry_point, SpvBuiltIn, SpvExecutionModel, SpvStorageClass};
12use std::ffi::c_char;
13
/// Iterator over the extension names declared by a module, as returned by
/// [`Compiler::declared_extensions`].
///
/// Field `0` iterates the raw C-string pointers reported by SPIRV-Cross; field `1` is an
/// [`AllocationDropGuard`] for the compiler's context, a clone of which is attached to every
/// yielded [`CompilerStr`].
pub struct ExtensionsIter<'a>(slice::Iter<'a, *const c_char>, AllocationDropGuard);

// Each item is built from the next raw pointer plus a clone of the drop guard (`s.1`).
// NOTE(review): `[0]` presumably names the field that is iterated — confirm against the
// `impl_iterator!` macro definition.
impl_iterator!(ExtensionsIter<'c>: CompilerStr<'c> as map |s, ptr: &*const c_char| {
    unsafe {
        CompilerStr::from_ptr(*ptr, s.1.clone())
    }
} for <'c> [0]);
22
23impl<T> Compiler<T> {
25 pub fn declared_capabilities(&self) -> error::Result<&[spirv::Capability]> {
27 unsafe {
28 let mut caps = std::ptr::null();
29 let mut size = 0;
30
31 sys::spvc_compiler_get_declared_capabilities(self.ptr.as_ptr(), &mut caps, &mut size)
32 .ok(self)?;
33
34 const _: () =
35 assert!(std::mem::size_of::<spirv::Capability>() == std::mem::size_of::<i32>());
36 try_valid_slice(caps.cast(), size)
37 }
38 }
39
40 pub fn declared_extensions(&self) -> error::Result<ExtensionsIter<'static>> {
42 unsafe {
45 let mut caps = std::ptr::null_mut();
46 let mut size = 0;
47
48 sys::spvc_compiler_get_declared_extensions(self.ptr.as_ptr(), &mut caps, &mut size)
49 .ok(self)?;
50
51 let ptr_slice = slice::from_raw_parts(caps, size);
52
53 Ok(ExtensionsIter(ptr_slice.iter(), self.ctx.drop_guard()))
54 }
55 }
56
57 pub fn execution_model(&self) -> error::Result<spirv::ExecutionModel> {
59 unsafe {
60 let exec_model = sys::spvc_compiler_get_execution_model(self.ptr.as_ptr());
61
62 let Some(exec_model) = spirv::ExecutionModel::from_u32(exec_model.0 as u32) else {
63 return Err(SpirvCrossError::InvalidEnum);
64 };
65
66 Ok(exec_model)
67 }
68 }
69}
70
/// Token returned by [`Compiler::update_active_builtins`] and required by
/// [`Compiler::has_active_builtin`].
///
/// It wraps a handle created by the compiler. `has_active_builtin` returns an error when the
/// wrapped handle is not valid for the compiler it is passed to.
#[derive(Debug, Copy, Clone)]
pub struct ActiveBuiltinsUpdatedProof(Handle<()>);
74
75impl<T> Compiler<T> {
77 pub fn update_active_builtins(&mut self) -> ActiveBuiltinsUpdatedProof {
79 unsafe {
80 sys::spvc_compiler_update_active_builtins(self.ptr.as_ptr());
81 ActiveBuiltinsUpdatedProof(self.create_handle(()))
82 }
83 }
84
85 pub fn has_active_builtin(
90 &self,
91 builtin: spirv::BuiltIn,
92 storage_class: spirv::StorageClass,
93 proof: ActiveBuiltinsUpdatedProof,
94 ) -> error::Result<bool> {
95 if !self.handle_is_valid(&proof.0) {
96 return Err(SpirvCrossError::InvalidOperation(String::from(
97 "The provided proof of building active builtins is invalid",
98 )));
99 }
100
101 unsafe {
102 Ok(sys::spvc_compiler_has_active_builtin(
103 self.ptr.as_ptr(),
104 SpvBuiltIn(builtin as i32),
105 SpvStorageClass(storage_class as i32),
106 ))
107 }
108 }
109}
110
/// Iterator over the entry points of a module, as returned by [`Compiler::entry_points`].
///
/// Field `0` iterates the raw `spvc_entry_point` records; field `1` is an
/// [`AllocationDropGuard`] for the compiler's context, a clone of which is attached to the
/// name of every yielded [`EntryPoint`].
pub struct EntryPointIter<'a>(slice::Iter<'a, spvc_entry_point>, AllocationDropGuard);
113
/// An entry point declared by the module, as yielded by [`EntryPointIter`].
#[derive(Debug)]
pub struct EntryPoint<'a> {
    /// The execution model of this entry point.
    pub execution_model: spirv::ExecutionModel,
    /// The entry point's name; created with a clone of the iterator's
    /// [`AllocationDropGuard`].
    pub name: CompilerStr<'a>,
}
122
// Converts each raw `spvc_entry_point` into an `EntryPoint`. An execution model value that
// `spirv::ExecutionModel::from_u32` does not recognise panics in debug builds and yields
// `None` in release builds.
impl_iterator!(EntryPointIter<'a>: EntryPoint<'a> as and_then|s, entry: &spvc_entry_point| {
    unsafe {
        let Some(execution_model) = spirv::ExecutionModel::from_u32(entry.execution_model.0 as u32) else {
            if cfg!(debug_assertions) {
                panic!("Unexpected SpvExecutionModelMax in valid entry point!")
            } else {
                return None;
            }
        };

        // The name shares the iterator's drop guard (`s.1`).
        let name = CompilerStr::from_ptr(entry.name, s.1.clone());
        Some(EntryPoint {
            name,
            execution_model,
        })
    }
} for <'a> [0]);
140
141impl<T> Compiler<T> {
143 pub fn entry_points(&self) -> error::Result<EntryPointIter<'static>> {
166 unsafe {
167 let mut entry_points = std::ptr::null();
170 let mut size = 0;
171 sys::spvc_compiler_get_entry_points(self.ptr.as_ptr(), &mut entry_points, &mut size)
172 .ok(self)?;
173
174 Ok(EntryPointIter(
175 slice::from_raw_parts(entry_points.cast(), size).iter(),
176 self.ctx.drop_guard(),
177 ))
178 }
179 }
180
181 pub fn cleansed_entry_point_name<'str>(
183 &self,
184 name: impl Into<CompilerStr<'str>>,
185 model: spirv::ExecutionModel,
186 ) -> error::Result<Option<CompilerStr<'static>>> {
187 let name = name.into();
190 let name = name.into_cstring_ptr()?;
191
192 unsafe {
193 let name = sys::spvc_compiler_get_cleansed_entry_point_name(
194 self.ptr.as_ptr(),
195 name.as_ptr(),
196 SpvExecutionModel(model as u32 as i32),
197 );
198
199 if name.is_null() {
200 return Ok(None);
201 }
202 Ok(Some(CompilerStr::from_ptr(name, self.ctx.drop_guard())))
203 }
204 }
205
206 pub fn set_entry_point<'str>(
225 &mut self,
226 name: impl Into<CompilerStr<'str>>,
227 model: spirv::ExecutionModel,
228 ) -> error::Result<()> {
229 let name = name.into();
230 unsafe {
231 let name = name.into_cstring_ptr()?;
232
233 sys::spvc_compiler_set_entry_point(
234 self.ptr.as_ptr(),
235 name.as_ptr(),
236 SpvExecutionModel(model as u32 as i32),
237 )
238 .ok(&*self)
239 }
240 }
241
242 pub fn rename_entry_point<'str>(
249 &mut self,
250 from: impl Into<CompilerStr<'str>>,
251 to: impl Into<CompilerStr<'str>>,
252 model: spirv::ExecutionModel,
253 ) -> error::Result<()> {
254 let from = from.into();
255 let to = to.into();
256
257 unsafe {
258 let from = from.into_cstring_ptr()?;
259 let to = to.into_cstring_ptr()?;
260
261 sys::spvc_compiler_rename_entry_point(
262 self.ptr.as_ptr(),
263 from.as_ptr(),
264 to.as_ptr(),
265 SpvExecutionModel(model as u32 as i32),
266 )
267 .ok(&*self)
268 }
269 }
270}
271
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};
    use spirv::ExecutionModel;

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decodes `BASIC_SPV` into native-endian SPIR-V words.
    ///
    /// Byte buffers (including a `Vec<u8>` copy) are only guaranteed 1-byte alignment, so
    /// reinterpreting them as `&[u32]` with `bytemuck::cast_slice` can panic on its alignment
    /// check. Decoding word by word has no alignment requirement.
    fn basic_words() -> Vec<u32> {
        assert_eq!(
            BASIC_SPV.len() % 4,
            0,
            "SPIR-V binary length must be a multiple of 4"
        );
        BASIC_SPV
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    #[test]
    pub fn get_entry_points() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let module = Module::from_words(words.as_slice());

        let mut compiler: Compiler<targets::None> = Compiler::new(module)?;
        let old_entry_points: Vec<_> = compiler.entry_points()?.collect();
        let main = &old_entry_points[0];

        eprintln!("{:?}", main);

        assert_eq!("main", main.name.as_ref());
        compiler.rename_entry_point("main", "new_main", ExecutionModel::Fragment)?;

        let no_name = compiler.cleansed_entry_point_name("main", ExecutionModel::Fragment)?;

        assert!(no_name.is_none());

        // Names collected before the rename must still be readable afterwards.
        assert_eq!("main", main.name.as_ref());
        let new_name = compiler.cleansed_entry_point_name("new_main", ExecutionModel::Fragment)?;

        assert_eq!(Some("new_main"), new_name.as_deref());

        Ok(())
    }

    #[test]
    pub fn entry_point_soundness() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let module = Module::from_words(words.as_slice());

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        let entry_points = compiler.entry_points()?;
        let name = compiler
            .cleansed_entry_point_name("main", ExecutionModel::Fragment)?
            .unwrap();

        assert_eq!("main", name.as_ref());

        // Strings and iterators handed out by the compiler must outlive it.
        drop(compiler);

        assert_eq!("main", name.as_ref());
        let entries: Vec<_> = entry_points.collect();

        eprintln!("{:?}", entries);
        Ok(())
    }

    #[test]
    pub fn capabilities() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let module = Module::from_words(words.as_slice());

        let mut compiler: Compiler<targets::None> = Compiler::new(module)?;
        // Exercised for its side effects only; the result is not inspected.
        let _resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.declared_capabilities()?;

        assert_eq!([spirv::Capability::Shader], ty);

        Ok(())
    }
}