spirv_cross2/reflect/
entry_points.rs1use crate::cell::AllocationDropGuard;
2use crate::error;
3use crate::error::{SpirvCrossError, ToContextError};
4use crate::handle::Handle;
5use crate::iter::impl_iterator;
6use crate::reflect::try_valid_slice;
7use crate::string::CompilerStr;
8use crate::Compiler;
9use core::slice;
10use spirv_cross_sys as sys;
11use spirv_cross_sys::{spvc_entry_point, SpvBuiltIn, SpvExecutionModel, SpvStorageClass};
12use std::ffi::c_char;
13
/// Iterator over the names of the SPIR-V extensions declared by a module.
///
/// Field `0` iterates the raw C string pointers returned by SPIRV-Cross.
/// Field `1` is a drop guard for the compiler context's allocations. It is
/// cloned into every yielded [`CompilerStr`].
pub struct ExtensionsIter<'a>(slice::Iter<'a, *const c_char>, AllocationDropGuard);
16
// Maps each raw extension-name pointer to a `CompilerStr`, handing it a clone of
// the iterator's drop guard (field `1`). The trailing `[0]` presumably names the
// field holding the inner slice iterator — TODO confirm against `impl_iterator!`.
impl_iterator!(ExtensionsIter<'c>: CompilerStr<'c> as map |s, ptr: &*const c_char| {
    unsafe {
        // SAFETY: `declared_extensions` refuses to build this iterator if any
        // pointer in the list is null.
        CompilerStr::from_ptr(*ptr, s.1.clone())
    }
} for <'c> [0]);
24
25impl<T> Compiler<T> {
27 pub fn declared_capabilities(&self) -> error::Result<&[spirv::Capability]> {
29 unsafe {
30 let mut caps = std::ptr::null();
31 let mut size = 0;
32
33 sys::spvc_compiler_get_declared_capabilities(self.ptr.as_ptr(), &mut caps, &mut size)
34 .ok(self)?;
35
36 const _: () =
37 assert!(std::mem::size_of::<spirv::Capability>() == std::mem::size_of::<i32>());
38 try_valid_slice(caps.cast(), size)
39 }
40 }
41
42 pub fn declared_extensions(&self) -> error::Result<ExtensionsIter<'static>> {
44 unsafe {
47 let mut caps = std::ptr::null_mut();
48 let mut size = 0;
49
50 sys::spvc_compiler_get_declared_extensions(self.ptr.as_ptr(), &mut caps, &mut size)
51 .ok(self)?;
52
53 let ptr_slice = slice::from_raw_parts(caps, size);
54
55 if ptr_slice.iter().any(|p| p.is_null()) {
59 return Err(SpirvCrossError::OutOfMemory(String::from(
60 "Out of memory allocating declared extension names",
61 )));
62 }
63
64 Ok(ExtensionsIter(ptr_slice.iter(), self.ctx.drop_guard()))
65 }
66 }
67
68 pub fn execution_model(&self) -> error::Result<spirv::ExecutionModel> {
70 unsafe {
71 let exec_model = sys::spvc_compiler_get_execution_model(self.ptr.as_ptr());
72
73 let Some(exec_model) = spirv::ExecutionModel::from_u32(exec_model.0 as u32) else {
74 return Err(SpirvCrossError::InvalidEnum);
75 };
76
77 Ok(exec_model)
78 }
79 }
80}
81
/// Proof token returned by `Compiler::update_active_builtins`.
///
/// `Compiler::has_active_builtin` requires this token and checks it with
/// `handle_is_valid` before querying SPIRV-Cross.
#[derive(Debug, Copy, Clone)]
pub struct ActiveBuiltinsUpdatedProof(Handle<()>);
85
86impl<T> Compiler<T> {
88 pub fn update_active_builtins(&mut self) -> ActiveBuiltinsUpdatedProof {
90 unsafe {
91 sys::spvc_compiler_update_active_builtins(self.ptr.as_ptr());
92 ActiveBuiltinsUpdatedProof(self.create_handle(()))
93 }
94 }
95
96 pub fn has_active_builtin(
101 &self,
102 builtin: spirv::BuiltIn,
103 storage_class: spirv::StorageClass,
104 proof: ActiveBuiltinsUpdatedProof,
105 ) -> error::Result<bool> {
106 if !self.handle_is_valid(&proof.0) {
107 return Err(SpirvCrossError::InvalidOperation(String::from(
108 "The provided proof of building active builtins is invalid",
109 )));
110 }
111
112 unsafe {
113 Ok(sys::spvc_compiler_has_active_builtin(
114 self.ptr.as_ptr(),
115 SpvBuiltIn(builtin as u32),
116 SpvStorageClass(storage_class as u32),
117 ))
118 }
119 }
120}
121
/// Iterator over the entry points declared in a module.
///
/// Field `0` iterates the raw `spvc_entry_point` records from SPIRV-Cross.
/// Field `1` is a drop guard for the compiler context's allocations. It is
/// cloned into each yielded entry's name.
pub struct EntryPointIter<'a>(slice::Iter<'a, spvc_entry_point>, AllocationDropGuard);
124
/// A single entry point of a module.
#[derive(Debug)]
pub struct EntryPoint<'a> {
    /// The execution model (shader stage) of the entry point.
    pub execution_model: spirv::ExecutionModel,
    /// The entry point's name, backed by memory owned by the compiler context.
    pub name: CompilerStr<'a>,
}
133
// Converts each raw `spvc_entry_point` into an `EntryPoint`, handing its name a
// clone of the iterator's drop guard (field `1`).
// An execution model with no `spirv::ExecutionModel` variant panics in debug
// builds. In release builds the closure yields `None` instead. Whether that ends
// iteration or skips the entry depends on `impl_iterator!`'s `and_then` mode —
// TODO confirm.
impl_iterator!(EntryPointIter<'a>: EntryPoint<'a> as and_then|s, entry: &spvc_entry_point| {
    unsafe {
        let Some(execution_model) = spirv::ExecutionModel::from_u32(entry.execution_model.0 as u32) else {
            if cfg!(debug_assertions) {
                panic!("Unexpected SpvExecutionModelMax in valid entry point!")
            } else {
                return None;
            }
        };

        // NOTE(review): assumes SPIRV-Cross never hands back a null entry-point
        // name on success — confirm against spvc_compiler_get_entry_points.
        let name = CompilerStr::from_ptr(entry.name, s.1.clone());
        Some(EntryPoint {
            name,
            execution_model,
        })
    }
} for <'a> [0]);
151
152impl<T> Compiler<T> {
154 pub fn entry_points(&self) -> error::Result<EntryPointIter<'static>> {
177 unsafe {
178 let mut entry_points = std::ptr::null();
181 let mut size = 0;
182 sys::spvc_compiler_get_entry_points(self.ptr.as_ptr(), &mut entry_points, &mut size)
183 .ok(self)?;
184
185 Ok(EntryPointIter(
186 slice::from_raw_parts(entry_points.cast(), size).iter(),
187 self.ctx.drop_guard(),
188 ))
189 }
190 }
191
192 pub fn cleansed_entry_point_name<'str>(
194 &self,
195 name: impl Into<CompilerStr<'str>>,
196 model: spirv::ExecutionModel,
197 ) -> error::Result<Option<CompilerStr<'static>>> {
198 let name = name.into();
201 let name = name.into_cstring_ptr()?;
202
203 unsafe {
204 let name = sys::spvc_compiler_get_cleansed_entry_point_name(
205 self.ptr.as_ptr(),
206 name.as_ptr(),
207 SpvExecutionModel(model as u32),
208 );
209
210 if name.is_null() {
211 return Ok(None);
212 }
213 Ok(Some(CompilerStr::from_ptr(name, self.ctx.drop_guard())))
214 }
215 }
216
217 pub fn set_entry_point<'str>(
236 &mut self,
237 name: impl Into<CompilerStr<'str>>,
238 model: spirv::ExecutionModel,
239 ) -> error::Result<()> {
240 let name = name.into();
241 unsafe {
242 let name = name.into_cstring_ptr()?;
243
244 sys::spvc_compiler_set_entry_point(
245 self.ptr.as_ptr(),
246 name.as_ptr(),
247 SpvExecutionModel(model as u32),
248 )
249 .ok(&*self)
250 }
251 }
252
253 pub fn rename_entry_point<'str>(
260 &mut self,
261 from: impl Into<CompilerStr<'str>>,
262 to: impl Into<CompilerStr<'str>>,
263 model: spirv::ExecutionModel,
264 ) -> error::Result<()> {
265 let from = from.into();
266 let to = to.into();
267
268 unsafe {
269 let from = from.into_cstring_ptr()?;
270 let to = to.into_cstring_ptr()?;
271
272 sys::spvc_compiler_rename_entry_point(
273 self.ptr.as_ptr(),
274 from.as_ptr(),
275 to.as_ptr(),
276 SpvExecutionModel(model as u32),
277 )
278 .ok(&*self)
279 }
280 }
281}
282
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};
    use spirv::ExecutionModel;

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    #[test]
    pub fn get_entry_points() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let old_entry_points: Vec<_> = compiler.entry_points()?.collect();
        let main = &old_entry_points[0];

        eprintln!("{:?}", main);

        assert_eq!("main", main.name.as_ref());
        compiler.rename_entry_point("main", "new_main", ExecutionModel::Fragment)?;

        // After the rename, the old name no longer resolves.
        let no_name = compiler.cleansed_entry_point_name("main", ExecutionModel::Fragment)?;
        assert!(no_name.is_none());

        // Names collected before the rename are unaffected by it.
        assert_eq!("main", main.name.as_ref());
        let new_name =
            compiler.cleansed_entry_point_name("new_main", ExecutionModel::Fragment)?;

        assert_eq!(Some("new_main"), new_name.as_deref());

        Ok(())
    }

    #[test]
    pub fn entry_point_soundness() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let entry_points = compiler.entry_points()?;
        let name = compiler
            .cleansed_entry_point_name("main", ExecutionModel::Fragment)?
            .unwrap();

        assert_eq!("main", name.as_ref());

        // Strings and iterators obtained from the compiler must stay valid after
        // the compiler itself is dropped.
        drop(compiler);

        assert_eq!("main", name.as_ref());
        let entries: Vec<_> = entry_points.collect();

        eprintln!("{:?}", entries);
        assert_eq!("main", entries[0].name.as_ref());
        assert_eq!(ExecutionModel::Fragment, entries[0].execution_model);
        Ok(())
    }

    #[test]
    pub fn capabilities() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let _resources = compiler.shader_resources()?.all_resources()?;

        let caps = compiler.declared_capabilities()?;

        assert_eq!([spirv::Capability::Shader], caps);

        Ok(())
    }
}