1use crate::error::{SpirvCrossError, ToContextError};
2use crate::handle::{ConstantId, Handle, Id, TypeId, VariableId};
3use crate::reflect::StructMember;
4use crate::sealed::Sealed;
5use crate::string::CompilerStr;
6use crate::Compiler;
7use crate::{error, ToStatic};
8use spirv::Decoration;
9use spirv_cross_sys as sys;
10use spirv_cross_sys::{SpvDecoration, SpvId};
11
/// A decoration value as read from, or written to, a [`Compiler`].
///
/// Which variant may be paired with a given [`Decoration`] is decided by
/// [`DecorationValue::type_is_valid_for_decoration`].
#[derive(Debug, Eq, PartialEq)]
pub enum DecorationValue<'a> {
    /// A plain `u32` literal; accepted for the decorations listed in
    /// `decoration_is_literal` (e.g. `Location`, `Binding`, `DescriptorSet`).
    Literal(u32),
    /// Accepted only for [`Decoration::BuiltIn`].
    BuiltIn(spirv::BuiltIn),
    /// Accepted only for [`Decoration::FPRoundingMode`].
    RoundingMode(spirv::FPRoundingMode),
    /// Accepted only for [`Decoration::SpecId`].
    Constant(Handle<ConstantId>),
    /// Accepted only for string decorations (`HlslSemanticGOOGLE`, `UserTypeGOOGLE`).
    String(CompilerStr<'a>),
    /// The decoration is present without a payload; accepted for every
    /// decoration not claimed by one of the variants above.
    Present,
}
41
42impl DecorationValue<'_> {
43 pub const fn unset() -> Option<Self> {
46 None
47 }
48
49 pub fn as_literal(&self) -> Option<u32> {
51 match self {
52 Self::Literal(l) => Some(*l),
53 _ => None,
54 }
55 }
56}
57
58impl From<u32> for DecorationValue<'_> {
59 fn from(value: u32) -> Self {
60 DecorationValue::Literal(value)
61 }
62}
63
64impl From<()> for DecorationValue<'_> {
65 fn from(_value: ()) -> Self {
66 DecorationValue::Present
67 }
68}
69
70impl From<Handle<ConstantId>> for DecorationValue<'_> {
71 fn from(value: Handle<ConstantId>) -> Self {
72 DecorationValue::Constant(value)
73 }
74}
75
76impl<'a> From<&'a str> for DecorationValue<'a> {
77 fn from(value: &'a str) -> Self {
78 DecorationValue::String(CompilerStr::from_str(value))
79 }
80}
81
82impl From<String> for DecorationValue<'_> {
83 fn from(value: String) -> Self {
84 DecorationValue::String(CompilerStr::from_string(value))
85 }
86}
87
88impl<'a> From<CompilerStr<'a>> for DecorationValue<'a> {
89 fn from(value: CompilerStr<'a>) -> Self {
90 DecorationValue::String(value)
91 }
92}
93
// Crate-private marker trait from `crate::sealed`; presumably required as a
// supertrait so `DecorationValue` can implement `ToStatic` — confirm.
impl Sealed for DecorationValue<'_> {}
95impl ToStatic for DecorationValue<'_> {
96 type Static<'a>
97 = DecorationValue<'static>
98 where
99 'a: 'static;
100
101 fn to_static(&self) -> Self::Static<'static> {
102 match self {
103 DecorationValue::Literal(a) => DecorationValue::Literal(*a),
104 DecorationValue::BuiltIn(a) => DecorationValue::BuiltIn(*a),
105 DecorationValue::RoundingMode(a) => DecorationValue::RoundingMode(*a),
106 DecorationValue::Constant(a) => DecorationValue::Constant(*a),
107 DecorationValue::String(c) => {
108 let owned = c.to_string();
109 DecorationValue::String(CompilerStr::from_string(owned))
110 }
111 DecorationValue::Present => DecorationValue::Present,
112 }
113 }
114}
115
116impl<'a> Clone for DecorationValue<'a> {
117 fn clone(&self) -> DecorationValue<'static> {
118 self.to_static()
119 }
120}
121
122impl DecorationValue<'_> {
123 pub fn type_is_valid_for_decoration(&self, decoration: spirv::Decoration) -> bool {
125 match self {
126 DecorationValue::Literal(_) => decoration_is_literal(decoration),
127 DecorationValue::BuiltIn(_) => decoration == Decoration::BuiltIn,
128 DecorationValue::RoundingMode(_) => decoration == Decoration::FPRoundingMode,
129 DecorationValue::Constant(_) => decoration == Decoration::SpecId,
130 DecorationValue::String(_) => decoration_is_string(decoration),
131 DecorationValue::Present => {
132 !decoration_is_literal(decoration)
133 && !decoration_is_string(decoration)
134 && decoration != Decoration::BuiltIn
135 && decoration != Decoration::FPRoundingMode
136 && decoration != Decoration::SpecId
137 }
138 }
139 }
140}
141fn decoration_is_literal(decoration: spirv::Decoration) -> bool {
142 matches!(
143 decoration,
144 Decoration::Location
145 | Decoration::Component
146 | Decoration::Offset
147 | Decoration::XfbBuffer
148 | Decoration::XfbStride
149 | Decoration::Stream
150 | Decoration::Binding
151 | Decoration::DescriptorSet
152 | Decoration::InputAttachmentIndex
153 | Decoration::ArrayStride
154 | Decoration::MatrixStride
155 | Decoration::Index
156 )
157}
158
159fn decoration_is_string(decoration: Decoration) -> bool {
160 matches!(
161 decoration,
162 Decoration::HlslSemanticGOOGLE | Decoration::UserTypeGOOGLE
163 )
164}
165
166impl<T> Compiler<T> {
167 pub fn decoration<I: Id>(
169 &self,
170 id: Handle<I>,
171 decoration: Decoration,
172 ) -> error::Result<Option<DecorationValue<'_>>> {
173 let id = SpvId(self.yield_id(id)?.id());
178 unsafe {
179 let has_decoration = sys::spvc_compiler_has_decoration(
180 self.ptr.as_ptr(),
181 id,
182 SpvDecoration(decoration as u32 as i32),
183 );
184 if !has_decoration {
185 return Ok(None);
186 };
187
188 if decoration_is_string(decoration) {
189 let str = sys::spvc_compiler_get_decoration_string(
190 self.ptr.as_ptr(),
191 id,
192 SpvDecoration(decoration as u32 as i32),
193 );
194 return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
195 str,
196 self.ctx.drop_guard(),
197 ))));
198 }
199
200 let value = sys::spvc_compiler_get_decoration(
201 self.ptr.as_ptr(),
202 id,
203 SpvDecoration(decoration as u32 as i32),
204 );
205 self.parse_decoration_value(decoration, value)
206 }
207 }
208
209 pub fn member_decoration_by_handle(
211 &self,
212 struct_type_id: Handle<TypeId>,
213 index: u32,
214 decoration: Decoration,
215 ) -> error::Result<Option<DecorationValue<'_>>> {
216 let struct_type = self.yield_id(struct_type_id)?;
218
219 unsafe {
220 let has_decoration = sys::spvc_compiler_has_member_decoration(
221 self.ptr.as_ptr(),
222 struct_type,
223 index,
224 SpvDecoration(decoration as u32 as i32),
225 );
226 if !has_decoration {
227 return Ok(None);
228 };
229
230 if decoration_is_string(decoration) {
231 let str = sys::spvc_compiler_get_member_decoration_string(
232 self.ptr.as_ptr(),
233 struct_type,
234 index,
235 SpvDecoration(decoration as u32 as i32),
236 );
237 return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
238 str,
239 self.ctx.drop_guard(),
240 ))));
241 }
242
243 let value = sys::spvc_compiler_get_member_decoration(
244 self.ptr.as_ptr(),
245 struct_type,
246 index,
247 SpvDecoration(decoration as u32 as i32),
248 );
249 self.parse_decoration_value(decoration, value)
250 }
251 }
252
253 pub fn member_decoration(
255 &self,
256 member: &StructMember,
257 decoration: Decoration,
258 ) -> error::Result<Option<DecorationValue<'_>>> {
259 self.member_decoration_by_handle(member.struct_type, member.index as u32, decoration)
260 }
261
262 pub fn set_decoration<'value, I: Id>(
264 &mut self,
265 id: Handle<I>,
266 decoration: spirv::Decoration,
267 value: Option<impl Into<DecorationValue<'value>>>,
268 ) -> error::Result<()> {
269 let id = SpvId(self.yield_id(id)?.id());
271 unsafe {
272 let Some(value) = value else {
273 sys::spvc_compiler_unset_decoration(
274 self.ptr.as_ptr(),
275 id,
276 SpvDecoration(decoration as u32 as i32),
277 );
278 return Ok(());
279 };
280
281 let value = value.into();
282
283 if !value.type_is_valid_for_decoration(decoration) {
284 return Err(SpirvCrossError::InvalidDecorationInput(
285 decoration,
286 DecorationValue::to_static(&value),
287 ));
288 }
289
290 match value {
291 DecorationValue::Literal(literal) => {
292 sys::spvc_compiler_set_decoration(
293 self.ptr.as_ptr(),
294 id,
295 SpvDecoration(decoration as u32 as i32),
296 literal,
297 );
298 }
299 DecorationValue::BuiltIn(builtin) => {
300 sys::spvc_compiler_set_decoration(
301 self.ptr.as_ptr(),
302 id,
303 SpvDecoration(decoration as u32 as i32),
304 builtin as u32,
305 );
306 }
307 DecorationValue::RoundingMode(rounding_mode) => {
308 sys::spvc_compiler_set_decoration(
309 self.ptr.as_ptr(),
310 id,
311 SpvDecoration(decoration as u32 as i32),
312 rounding_mode as u32,
313 );
314 }
315 DecorationValue::Constant(constant) => {
316 let constant = self.yield_id(constant)?;
317 sys::spvc_compiler_set_decoration(
318 self.ptr.as_ptr(),
319 id,
320 SpvDecoration(decoration as u32 as i32),
321 constant.id(),
322 );
323 }
324 DecorationValue::Present => {
325 sys::spvc_compiler_set_decoration(
326 self.ptr.as_ptr(),
327 id,
328 SpvDecoration(decoration as u32 as i32),
329 1,
330 );
331 }
332 DecorationValue::String(string) => {
333 let cstring = string.into_cstring_ptr().map_err(|e| {
334 let SpirvCrossError::InvalidString(string) = e else {
335 unreachable!("into_cstring_ptr only errors InvalidString")
336 };
337 SpirvCrossError::InvalidDecorationInput(
338 decoration,
339 DecorationValue::String(string.into()),
340 )
341 })?;
342
343 sys::spvc_compiler_set_decoration_string(
344 self.ptr.as_ptr(),
345 id,
346 SpvDecoration(decoration as u32 as i32),
347 cstring.as_ptr(),
348 );
349
350 drop(cstring);
354 }
355 }
356 }
357 Ok(())
358 }
359
360 pub fn set_member_decoration<'value>(
362 &mut self,
363 member: &StructMember,
364 decoration: Decoration,
365 value: Option<impl Into<DecorationValue<'value>>>,
366 ) -> error::Result<()> {
367 self.set_member_decoration_by_handle(
368 member.struct_type,
369 member.index as u32,
370 decoration,
371 value,
372 )
373 }
374
375 pub fn set_member_decoration_by_handle<'value>(
378 &mut self,
379 struct_type: Handle<TypeId>,
380 index: u32,
381 decoration: Decoration,
382 value: Option<impl Into<DecorationValue<'value>>>,
383 ) -> error::Result<()> {
384 let struct_type = self.yield_id(struct_type)?;
386
387 unsafe {
388 let Some(value) = value else {
389 sys::spvc_compiler_unset_member_decoration(
390 self.ptr.as_ptr(),
391 struct_type,
392 index,
393 SpvDecoration(decoration as u32 as i32),
394 );
395 return Ok(());
396 };
397
398 let value = value.into();
399
400 if !value.type_is_valid_for_decoration(decoration) {
401 return Err(SpirvCrossError::InvalidDecorationInput(
402 decoration,
403 DecorationValue::to_static(&value),
404 ));
405 }
406
407 match value {
408 DecorationValue::Literal(literal) => {
409 sys::spvc_compiler_set_member_decoration(
410 self.ptr.as_ptr(),
411 struct_type,
412 index,
413 SpvDecoration(decoration as u32 as i32),
414 literal,
415 );
416 }
417 DecorationValue::BuiltIn(builtin) => {
418 sys::spvc_compiler_set_member_decoration(
419 self.ptr.as_ptr(),
420 struct_type,
421 index,
422 SpvDecoration(decoration as u32 as i32),
423 builtin as u32,
424 );
425 }
426 DecorationValue::RoundingMode(rounding_mode) => {
427 sys::spvc_compiler_set_member_decoration(
428 self.ptr.as_ptr(),
429 struct_type,
430 index,
431 SpvDecoration(decoration as u32 as i32),
432 rounding_mode as u32,
433 );
434 }
435 DecorationValue::Constant(constant) => {
436 let constant = self.yield_id(constant)?;
437 sys::spvc_compiler_set_member_decoration(
438 self.ptr.as_ptr(),
439 struct_type,
440 index,
441 SpvDecoration(decoration as u32 as i32),
442 constant.id(),
443 );
444 }
445 DecorationValue::Present => {
446 sys::spvc_compiler_set_member_decoration(
447 self.ptr.as_ptr(),
448 struct_type,
449 index,
450 SpvDecoration(decoration as u32 as i32),
451 1,
452 );
453 }
454 DecorationValue::String(string) => {
455 let cstring = string.into_cstring_ptr().map_err(|e| {
456 let SpirvCrossError::InvalidString(string) = e else {
457 unreachable!("into_cstring_ptr only errors InvalidString")
458 };
459 SpirvCrossError::InvalidDecorationInput(
460 decoration,
461 DecorationValue::String(string.into()),
462 )
463 })?;
464
465 sys::spvc_compiler_set_member_decoration_string(
466 self.ptr.as_ptr(),
467 struct_type,
468 index,
469 SpvDecoration(decoration as u32 as i32),
470 cstring.as_ptr(),
471 );
472
473 drop(cstring);
477 }
478 }
479 }
480 Ok(())
481 }
482
483 pub fn binary_offset_for_decoration(
492 &self,
493 variable: impl Into<Handle<VariableId>>,
494 decoration: Decoration,
495 ) -> error::Result<Option<u32>> {
496 let id = self.yield_id(variable.into())?;
497
498 unsafe {
499 let mut offset = 0;
500 if !sys::spvc_compiler_get_binary_offset_for_decoration(
501 self.ptr.as_ptr(),
502 id,
503 SpvDecoration(decoration as u32 as i32),
504 &mut offset,
505 ) {
506 Ok(None)
507 } else {
508 Ok(Some(offset))
509 }
510 }
511 }
512
513 fn parse_decoration_value(
514 &self,
515 decoration: Decoration,
516 value: u32,
517 ) -> error::Result<Option<DecorationValue<'_>>> {
518 if decoration_is_literal(decoration) {
519 return Ok(Some(DecorationValue::Literal(value)));
520 }
521
522 match decoration {
524 Decoration::BuiltIn => {
525 let Some(builtin) = spirv::BuiltIn::from_u32(value) else {
526 return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
527 };
528 Ok(Some(DecorationValue::BuiltIn(builtin)))
529 }
530 Decoration::FPRoundingMode => {
531 if value as i32 == i32::MAX {
533 return Ok(None);
534 }
535
536 let Some(rounding_mode) = spirv::FPRoundingMode::from_u32(value) else {
537 return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
538 };
539 Ok(Some(DecorationValue::RoundingMode(rounding_mode)))
540 }
541 Decoration::SpecId => unsafe {
542 Ok(Some(DecorationValue::Constant(
543 self.create_handle(ConstantId(SpvId(value))),
544 )))
545 },
546 _ => {
547 if value == 1 {
548 Ok(Some(DecorationValue::Present))
549 } else {
550 Ok(None)
551 }
552 }
553 }
554 }
555
556 pub fn buffer_block_decorations(
561 &self,
562 variable: impl Into<Handle<VariableId>>,
563 ) -> error::Result<Option<&[Decoration]>> {
564 let variable = variable.into();
565 let id = self.yield_id(variable)?;
566
567 unsafe {
568 let mut size = 0;
569 let mut buffer = std::ptr::null();
570 sys::spvc_compiler_get_buffer_block_decorations(
571 self.ptr.as_ptr(),
572 id,
573 &mut buffer,
574 &mut size,
575 )
576 .ok(self)?;
577
578 let slice = super::try_valid_slice::<Decoration>(buffer.cast(), size)?;
581 if slice.is_empty() {
582 Ok(None)
583 } else {
584 Ok(Some(slice))
585 }
586 }
587 }
588}
589
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;

    use crate::{targets, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decodes the embedded SPIR-V binary into native-endian 32-bit words.
    ///
    /// The words are decoded explicitly rather than reinterpreting a `Vec<u8>`
    /// with `bytemuck::cast_slice`. That cast panics unless the buffer is 4-byte
    /// aligned, and a `Vec<u8>` allocation only guarantees alignment 1.
    fn basic_words() -> Vec<u32> {
        assert_eq!(
            BASIC_SPV.len() % 4,
            0,
            "SPIR-V binary length must be a multiple of 4"
        );
        BASIC_SPV
            .chunks_exact(4)
            .map(|bytes| u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
            .collect()
    }

    #[test]
    pub fn set_decoration_test() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let module = Module::from_words(&words);

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        // Smoke test: resource reflection must succeed on the basic shader.
        let _resources = compiler.shader_resources()?.all_resources()?;

        Ok(())
    }
}