1use crate::error::{SpirvCrossError, ToContextError};
2use crate::handle::{ConstantId, Handle, Id, TypeId, VariableId};
3use crate::reflect::StructMember;
4use crate::sealed::Sealed;
5use crate::string::CompilerStr;
6use crate::Compiler;
7use crate::{error, ToStatic};
8use spirv::Decoration;
9use spirv_cross_sys as sys;
10use spirv_cross_sys::{SpvDecoration, SpvId};
11
/// The value attached to a SPIR-V decoration.
///
/// Which variant is acceptable for a given decoration is decided by
/// [`DecorationValue::type_is_valid_for_decoration`].
#[derive(Debug, Eq, PartialEq)]
pub enum DecorationValue<'a> {
    /// A plain 32-bit literal.
    ///
    /// Accepted for the decorations listed in `decoration_is_literal`:
    /// `Location`, `Component`, `Offset`, `XfbBuffer`, `XfbStride`, `Stream`,
    /// `Binding`, `DescriptorSet`, `InputAttachmentIndex`, `ArrayStride`,
    /// `MatrixStride` and `Index`.
    Literal(u32),
    /// A built-in; accepted only for `Decoration::BuiltIn`.
    BuiltIn(spirv::BuiltIn),
    /// A floating-point rounding mode; accepted only for `Decoration::FPRoundingMode`.
    RoundingMode(spirv::FPRoundingMode),
    /// A handle to a constant; accepted only for `Decoration::SpecId`.
    Constant(Handle<ConstantId>),
    /// A string; accepted for the decorations listed in `decoration_is_string`
    /// (`HlslSemanticGOOGLE` and `UserTypeGOOGLE`).
    String(CompilerStr<'a>),
    /// Marks a decoration as present without carrying a payload.
    ///
    /// Accepted for every decoration that none of the other variants covers.
    Present,
}
41
42impl DecorationValue<'_> {
43 pub const fn unset() -> Option<Self> {
46 None
47 }
48
49 pub fn as_literal(&self) -> Option<u32> {
51 match self {
52 Self::Literal(l) => Some(*l),
53 _ => None,
54 }
55 }
56}
57
58impl From<u32> for DecorationValue<'_> {
59 fn from(value: u32) -> Self {
60 DecorationValue::Literal(value)
61 }
62}
63
64impl From<()> for DecorationValue<'_> {
65 fn from(_value: ()) -> Self {
66 DecorationValue::Present
67 }
68}
69
70impl From<Handle<ConstantId>> for DecorationValue<'_> {
71 fn from(value: Handle<ConstantId>) -> Self {
72 DecorationValue::Constant(value)
73 }
74}
75
76impl<'a> From<&'a str> for DecorationValue<'a> {
77 fn from(value: &'a str) -> Self {
78 DecorationValue::String(CompilerStr::from_str(value))
79 }
80}
81
82impl From<String> for DecorationValue<'_> {
83 fn from(value: String) -> Self {
84 DecorationValue::String(CompilerStr::from_string(value))
85 }
86}
87
88impl<'a> From<CompilerStr<'a>> for DecorationValue<'a> {
89 fn from(value: CompilerStr<'a>) -> Self {
90 DecorationValue::String(value)
91 }
92}
93
94impl Sealed for DecorationValue<'_> {}
95impl ToStatic for DecorationValue<'_> {
96 type Static<'a>
97 = DecorationValue<'static>
98 where
99 'a: 'static;
100
101 fn to_static(&self) -> Self::Static<'static> {
102 match self {
103 DecorationValue::Literal(a) => DecorationValue::Literal(*a),
104 DecorationValue::BuiltIn(a) => DecorationValue::BuiltIn(*a),
105 DecorationValue::RoundingMode(a) => DecorationValue::RoundingMode(*a),
106 DecorationValue::Constant(a) => DecorationValue::Constant(*a),
107 DecorationValue::String(c) => {
108 let owned = c.to_string();
109 DecorationValue::String(CompilerStr::from_string(owned))
110 }
111 DecorationValue::Present => DecorationValue::Present,
112 }
113 }
114}
115
116impl<'a> Clone for DecorationValue<'a> {
117 fn clone(&self) -> DecorationValue<'static> {
118 self.to_static()
119 }
120}
121
122impl DecorationValue<'_> {
123 pub fn type_is_valid_for_decoration(&self, decoration: spirv::Decoration) -> bool {
125 match self {
126 DecorationValue::Literal(_) => decoration_is_literal(decoration),
127 DecorationValue::BuiltIn(_) => decoration == Decoration::BuiltIn,
128 DecorationValue::RoundingMode(_) => decoration == Decoration::FPRoundingMode,
129 DecorationValue::Constant(_) => decoration == Decoration::SpecId,
130 DecorationValue::String(_) => decoration_is_string(decoration),
131 DecorationValue::Present => {
132 !decoration_is_literal(decoration)
133 && !decoration_is_string(decoration)
134 && decoration != Decoration::BuiltIn
135 && decoration != Decoration::FPRoundingMode
136 && decoration != Decoration::SpecId
137 }
138 }
139 }
140}
141fn decoration_is_literal(decoration: spirv::Decoration) -> bool {
142 match decoration {
143 Decoration::Location
144 | Decoration::Component
145 | Decoration::Offset
146 | Decoration::XfbBuffer
147 | Decoration::XfbStride
148 | Decoration::Stream
149 | Decoration::Binding
150 | Decoration::DescriptorSet
151 | Decoration::InputAttachmentIndex
152 | Decoration::ArrayStride
153 | Decoration::MatrixStride
154 | Decoration::Index => true,
155 _ => false,
156 }
157}
158
159fn decoration_is_string(decoration: Decoration) -> bool {
160 match decoration {
161 Decoration::HlslSemanticGOOGLE | Decoration::UserTypeGOOGLE => true,
162 _ => false,
163 }
164}
165
166impl<T> Compiler<T> {
167 pub fn decoration<I: Id>(
169 &self,
170 id: Handle<I>,
171 decoration: Decoration,
172 ) -> error::Result<Option<DecorationValue>> {
173 let id = SpvId(self.yield_id(id)?.id());
178 unsafe {
179 let has_decoration = sys::spvc_compiler_has_decoration(
180 self.ptr.as_ptr(),
181 id,
182 SpvDecoration(decoration as u32 as i32),
183 );
184 if !has_decoration {
185 return Ok(None);
186 };
187
188 if decoration_is_string(decoration) {
189 let str = sys::spvc_compiler_get_decoration_string(
190 self.ptr.as_ptr(),
191 id,
192 SpvDecoration(decoration as u32 as i32),
193 );
194 return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
195 str,
196 self.ctx.drop_guard(),
197 ))));
198 }
199
200 let value = sys::spvc_compiler_get_decoration(
201 self.ptr.as_ptr(),
202 id,
203 SpvDecoration(decoration as u32 as i32),
204 );
205 self.parse_decoration_value(decoration, value)
206 }
207 }
208
209 pub fn member_decoration_by_handle(
211 &self,
212 struct_type_id: Handle<TypeId>,
213 index: u32,
214 decoration: Decoration,
215 ) -> error::Result<Option<DecorationValue>> {
216 let struct_type = self.yield_id(struct_type_id)?;
218 let index = index;
219
220 unsafe {
221 let has_decoration = sys::spvc_compiler_has_member_decoration(
222 self.ptr.as_ptr(),
223 struct_type,
224 index,
225 SpvDecoration(decoration as u32 as i32),
226 );
227 if !has_decoration {
228 return Ok(None);
229 };
230
231 if decoration_is_string(decoration) {
232 let str = sys::spvc_compiler_get_member_decoration_string(
233 self.ptr.as_ptr(),
234 struct_type,
235 index,
236 SpvDecoration(decoration as u32 as i32),
237 );
238 return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
239 str,
240 self.ctx.drop_guard(),
241 ))));
242 }
243
244 let value = sys::spvc_compiler_get_member_decoration(
245 self.ptr.as_ptr(),
246 struct_type,
247 index,
248 SpvDecoration(decoration as u32 as i32),
249 );
250 self.parse_decoration_value(decoration, value)
251 }
252 }
253
254 pub fn member_decoration(
256 &self,
257 member: &StructMember,
258 decoration: Decoration,
259 ) -> error::Result<Option<DecorationValue>> {
260 self.member_decoration_by_handle(member.struct_type, member.index as u32, decoration)
261 }
262
263 pub fn set_decoration<'value, I: Id>(
265 &mut self,
266 id: Handle<I>,
267 decoration: spirv::Decoration,
268 value: Option<impl Into<DecorationValue<'value>>>,
269 ) -> error::Result<()> {
270 let id = SpvId(self.yield_id(id)?.id());
272 unsafe {
273 let Some(value) = value else {
274 sys::spvc_compiler_unset_decoration(
275 self.ptr.as_ptr(),
276 id,
277 SpvDecoration(decoration as u32 as i32),
278 );
279 return Ok(());
280 };
281
282 let value = value.into();
283
284 if !value.type_is_valid_for_decoration(decoration) {
285 return Err(SpirvCrossError::InvalidDecorationInput(
286 decoration,
287 DecorationValue::to_static(&value),
288 ));
289 }
290
291 match value {
292 DecorationValue::Literal(literal) => {
293 sys::spvc_compiler_set_decoration(
294 self.ptr.as_ptr(),
295 id,
296 SpvDecoration(decoration as u32 as i32),
297 literal,
298 );
299 }
300 DecorationValue::BuiltIn(builtin) => {
301 sys::spvc_compiler_set_decoration(
302 self.ptr.as_ptr(),
303 id,
304 SpvDecoration(decoration as u32 as i32),
305 builtin as u32,
306 );
307 }
308 DecorationValue::RoundingMode(rounding_mode) => {
309 sys::spvc_compiler_set_decoration(
310 self.ptr.as_ptr(),
311 id,
312 SpvDecoration(decoration as u32 as i32),
313 rounding_mode as u32,
314 );
315 }
316 DecorationValue::Constant(constant) => {
317 let constant = self.yield_id(constant)?;
318 sys::spvc_compiler_set_decoration(
319 self.ptr.as_ptr(),
320 id,
321 SpvDecoration(decoration as u32 as i32),
322 constant.id(),
323 );
324 }
325 DecorationValue::Present => {
326 sys::spvc_compiler_set_decoration(
327 self.ptr.as_ptr(),
328 id,
329 SpvDecoration(decoration as u32 as i32),
330 1,
331 );
332 }
333 DecorationValue::String(string) => {
334 let cstring = string.into_cstring_ptr().map_err(|e| {
335 let SpirvCrossError::InvalidString(string) = e else {
336 unreachable!("into_cstring_ptr only errors InvalidString")
337 };
338 SpirvCrossError::InvalidDecorationInput(
339 decoration,
340 DecorationValue::String(string.into()),
341 )
342 })?;
343
344 sys::spvc_compiler_set_decoration_string(
345 self.ptr.as_ptr(),
346 id,
347 SpvDecoration(decoration as u32 as i32),
348 cstring.as_ptr(),
349 );
350
351 drop(cstring);
355 }
356 }
357 }
358 Ok(())
359 }
360
361 pub fn set_member_decoration<'value>(
363 &mut self,
364 member: &StructMember,
365 decoration: Decoration,
366 value: Option<impl Into<DecorationValue<'value>>>,
367 ) -> error::Result<()> {
368 self.set_member_decoration_by_handle(
369 member.struct_type,
370 member.index as u32,
371 decoration,
372 value,
373 )
374 }
375
376 pub fn set_member_decoration_by_handle<'value>(
379 &mut self,
380 struct_type: Handle<TypeId>,
381 index: u32,
382 decoration: Decoration,
383 value: Option<impl Into<DecorationValue<'value>>>,
384 ) -> error::Result<()> {
385 let struct_type = self.yield_id(struct_type)?;
387
388 unsafe {
389 let Some(value) = value else {
390 sys::spvc_compiler_unset_member_decoration(
391 self.ptr.as_ptr(),
392 struct_type,
393 index,
394 SpvDecoration(decoration as u32 as i32),
395 );
396 return Ok(());
397 };
398
399 let value = value.into();
400
401 if !value.type_is_valid_for_decoration(decoration) {
402 return Err(SpirvCrossError::InvalidDecorationInput(
403 decoration,
404 DecorationValue::to_static(&value),
405 ));
406 }
407
408 match value {
409 DecorationValue::Literal(literal) => {
410 sys::spvc_compiler_set_member_decoration(
411 self.ptr.as_ptr(),
412 struct_type,
413 index,
414 SpvDecoration(decoration as u32 as i32),
415 literal,
416 );
417 }
418 DecorationValue::BuiltIn(builtin) => {
419 sys::spvc_compiler_set_member_decoration(
420 self.ptr.as_ptr(),
421 struct_type,
422 index,
423 SpvDecoration(decoration as u32 as i32),
424 builtin as u32,
425 );
426 }
427 DecorationValue::RoundingMode(rounding_mode) => {
428 sys::spvc_compiler_set_member_decoration(
429 self.ptr.as_ptr(),
430 struct_type,
431 index,
432 SpvDecoration(decoration as u32 as i32),
433 rounding_mode as u32,
434 );
435 }
436 DecorationValue::Constant(constant) => {
437 let constant = self.yield_id(constant)?;
438 sys::spvc_compiler_set_member_decoration(
439 self.ptr.as_ptr(),
440 struct_type,
441 index,
442 SpvDecoration(decoration as u32 as i32),
443 constant.id(),
444 );
445 }
446 DecorationValue::Present => {
447 sys::spvc_compiler_set_member_decoration(
448 self.ptr.as_ptr(),
449 struct_type,
450 index,
451 SpvDecoration(decoration as u32 as i32),
452 1,
453 );
454 }
455 DecorationValue::String(string) => {
456 let cstring = string.into_cstring_ptr().map_err(|e| {
457 let SpirvCrossError::InvalidString(string) = e else {
458 unreachable!("into_cstring_ptr only errors InvalidString")
459 };
460 SpirvCrossError::InvalidDecorationInput(
461 decoration,
462 DecorationValue::String(string.into()),
463 )
464 })?;
465
466 sys::spvc_compiler_set_member_decoration_string(
467 self.ptr.as_ptr(),
468 struct_type,
469 index,
470 SpvDecoration(decoration as u32 as i32),
471 cstring.as_ptr(),
472 );
473
474 drop(cstring);
478 }
479 }
480 }
481 Ok(())
482 }
483
484 pub fn binary_offset_for_decoration(
493 &self,
494 variable: impl Into<Handle<VariableId>>,
495 decoration: Decoration,
496 ) -> error::Result<Option<u32>> {
497 let id = self.yield_id(variable.into())?;
498
499 unsafe {
500 let mut offset = 0;
501 if !sys::spvc_compiler_get_binary_offset_for_decoration(
502 self.ptr.as_ptr(),
503 id,
504 SpvDecoration(decoration as u32 as i32),
505 &mut offset,
506 ) {
507 Ok(None)
508 } else {
509 Ok(Some(offset))
510 }
511 }
512 }
513
514 fn parse_decoration_value(
515 &self,
516 decoration: Decoration,
517 value: u32,
518 ) -> error::Result<Option<DecorationValue>> {
519 if decoration_is_literal(decoration) {
520 return Ok(Some(DecorationValue::Literal(value)));
521 }
522
523 match decoration {
525 Decoration::BuiltIn => {
526 let Some(builtin) = spirv::BuiltIn::from_u32(value) else {
527 return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
528 };
529 Ok(Some(DecorationValue::BuiltIn(builtin)))
530 }
531 Decoration::FPRoundingMode => {
532 if value as i32 == i32::MAX {
534 return Ok(None);
535 }
536
537 let Some(rounding_mode) = spirv::FPRoundingMode::from_u32(value) else {
538 return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
539 };
540 Ok(Some(DecorationValue::RoundingMode(rounding_mode)))
541 }
542 Decoration::SpecId => unsafe {
543 Ok(Some(DecorationValue::Constant(
544 self.create_handle(ConstantId(SpvId(value))),
545 )))
546 },
547 _ => {
548 if value == 1 {
549 Ok(Some(DecorationValue::Present))
550 } else {
551 Ok(None)
552 }
553 }
554 }
555 }
556
557 pub fn buffer_block_decorations(
562 &self,
563 variable: impl Into<Handle<VariableId>>,
564 ) -> error::Result<Option<&[Decoration]>> {
565 let variable = variable.into();
566 let id = self.yield_id(variable)?;
567
568 unsafe {
569 let mut size = 0;
570 let mut buffer = std::ptr::null();
571 sys::spvc_compiler_get_buffer_block_decorations(
572 self.ptr.as_ptr(),
573 id,
574 &mut buffer,
575 &mut size,
576 )
577 .ok(self)?;
578
579 let slice = super::try_valid_slice::<Decoration>(buffer.cast(), size)?;
582 if slice.is_empty() {
583 Ok(None)
584 } else {
585 Ok(Some(slice))
586 }
587 }
588 }
589}
590
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;

    use crate::{targets, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Smoke test: builds a compiler from `basic.spv` and enumerates its
    /// shader resources.
    ///
    /// NOTE(review): despite its name this test does not set or read any
    /// decoration yet — it only checks that the setup succeeds.
    #[test]
    pub fn set_decoration_test() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        // Underscore-prefixed: the value is fetched only to prove the calls
        // succeed and is otherwise unused.
        let _resources = compiler.shader_resources()?.all_resources()?;

        Ok(())
    }
}