spirv_cross2/reflect/decorations.rs

1use crate::error::{SpirvCrossError, ToContextError};
2use crate::handle::{ConstantId, Handle, Id, TypeId, VariableId};
3use crate::reflect::StructMember;
4use crate::sealed::Sealed;
5use crate::string::CompilerStr;
6use crate::Compiler;
7use crate::{error, ToStatic};
8use spirv::Decoration;
9use spirv_cross_sys as sys;
10use spirv_cross_sys::{SpvDecoration, SpvId};
11
12/// A value accompanying an `OpDecoration`
13#[derive(Debug, Eq, PartialEq)]
14pub enum DecorationValue<'a> {
15    /// Returned by the following decorations.
16    ///
17    /// - [`Location`](Decoration::Location).
18    /// - [`Component`](Decoration::Component).
19    /// - [`Offset`](Decoration::Offset).
20    /// - [`XfbBuffer`](Decoration::XfbBuffer).
21    /// - [`XfbStride`](Decoration::XfbStride).
22    /// - [`Stream`](Decoration::Stream).
23    /// - [`Binding`](Decoration::Binding).
24    /// - [`DescriptorSet`](Decoration::DescriptorSet).
25    /// - [`InputAttachmentIndex`](Decoration::InputAttachmentIndex).
26    /// - [`ArrayStride`](Decoration::ArrayStride).
27    /// - [`MatrixStride`](Decoration::MatrixStride).
28    /// - [`Index`](Decoration::Index).
29    Literal(u32),
30    /// Only for decoration [`BuiltIn`](Decoration::BuiltIn).
31    BuiltIn(spirv::BuiltIn),
32    /// Only for decoration [`FPRoundingMode`](Decoration::FPRoundingMode).
33    RoundingMode(spirv::FPRoundingMode),
34    /// Only for decoration [`SpecId`](Decoration::SpecId).
35    Constant(Handle<ConstantId>),
36    /// Only for decoration [`HlslSemanticGOOGLE`](Decoration::HlslSemanticGOOGLE) and [`UserTypeGOOGLE`](Decoration::HlslSemanticGOOGLE).
37    String(CompilerStr<'a>),
38    /// All other decorations to indicate the presence of a decoration.
39    Present,
40}
41
42impl DecorationValue<'_> {
43    /// Helper function to unset a decoration value, to be passed to
44    /// [`Compiler::set_decoration`].
45    pub const fn unset() -> Option<Self> {
46        None
47    }
48
49    /// Get the value if it is a literal `u32`.
50    pub fn as_literal(&self) -> Option<u32> {
51        match self {
52            Self::Literal(l) => Some(*l),
53            _ => None,
54        }
55    }
56}
57
58impl From<u32> for DecorationValue<'_> {
59    fn from(value: u32) -> Self {
60        DecorationValue::Literal(value)
61    }
62}
63
64impl From<()> for DecorationValue<'_> {
65    fn from(_value: ()) -> Self {
66        DecorationValue::Present
67    }
68}
69
70impl From<Handle<ConstantId>> for DecorationValue<'_> {
71    fn from(value: Handle<ConstantId>) -> Self {
72        DecorationValue::Constant(value)
73    }
74}
75
76impl<'a> From<&'a str> for DecorationValue<'a> {
77    fn from(value: &'a str) -> Self {
78        DecorationValue::String(CompilerStr::from_str(value))
79    }
80}
81
82impl From<String> for DecorationValue<'_> {
83    fn from(value: String) -> Self {
84        DecorationValue::String(CompilerStr::from_string(value))
85    }
86}
87
88impl<'a> From<CompilerStr<'a>> for DecorationValue<'a> {
89    fn from(value: CompilerStr<'a>) -> Self {
90        DecorationValue::String(value)
91    }
92}
93
94impl Sealed for DecorationValue<'_> {}
95impl ToStatic for DecorationValue<'_> {
96    type Static<'a>
97    = DecorationValue<'static>
98    where
99        'a: 'static;
100
101    fn to_static(&self) -> Self::Static<'static> {
102        match self {
103            DecorationValue::Literal(a) => DecorationValue::Literal(*a),
104            DecorationValue::BuiltIn(a) => DecorationValue::BuiltIn(*a),
105            DecorationValue::RoundingMode(a) => DecorationValue::RoundingMode(*a),
106            DecorationValue::Constant(a) => DecorationValue::Constant(*a),
107            DecorationValue::String(c) => {
108                let owned = c.to_string();
109                DecorationValue::String(CompilerStr::from_string(owned))
110            }
111            DecorationValue::Present => DecorationValue::Present,
112        }
113    }
114}
115
116impl<'a> Clone for DecorationValue<'a> {
117    fn clone(&self) -> DecorationValue<'static> {
118        self.to_static()
119    }
120}
121
122impl DecorationValue<'_> {
123    /// Check that the value is valid for the decoration type.
124    pub fn type_is_valid_for_decoration(&self, decoration: spirv::Decoration) -> bool {
125        match self {
126            DecorationValue::Literal(_) => decoration_is_literal(decoration),
127            DecorationValue::BuiltIn(_) => decoration == Decoration::BuiltIn,
128            DecorationValue::RoundingMode(_) => decoration == Decoration::FPRoundingMode,
129            DecorationValue::Constant(_) => decoration == Decoration::SpecId,
130            DecorationValue::String(_) => decoration_is_string(decoration),
131            DecorationValue::Present => {
132                !decoration_is_literal(decoration)
133                    && !decoration_is_string(decoration)
134                    && decoration != Decoration::BuiltIn
135                    && decoration != Decoration::FPRoundingMode
136                    && decoration != Decoration::SpecId
137            }
138        }
139    }
140}
141fn decoration_is_literal(decoration: spirv::Decoration) -> bool {
142    match decoration {
143        Decoration::Location
144        | Decoration::Component
145        | Decoration::Offset
146        | Decoration::XfbBuffer
147        | Decoration::XfbStride
148        | Decoration::Stream
149        | Decoration::Binding
150        | Decoration::DescriptorSet
151        | Decoration::InputAttachmentIndex
152        | Decoration::ArrayStride
153        | Decoration::MatrixStride
154        | Decoration::Index => true,
155        _ => false,
156    }
157}
158
159fn decoration_is_string(decoration: Decoration) -> bool {
160    match decoration {
161        Decoration::HlslSemanticGOOGLE | Decoration::UserTypeGOOGLE => true,
162        _ => false,
163    }
164}
165
166impl<T> Compiler<T> {
167    /// Gets the value for decorations which take arguments.
168    pub fn decoration<I: Id>(
169        &self,
170        id: Handle<I>,
171        decoration: Decoration,
172    ) -> error::Result<Option<DecorationValue>> {
173        // SAFETY: 'ctx is not sound to return here!
174        //  https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2154
175
176        // SAFETY: id is yielded by the instance so it's safe to use.
177        let id = SpvId(self.yield_id(id)?.id());
178        unsafe {
179            let has_decoration = sys::spvc_compiler_has_decoration(
180                self.ptr.as_ptr(),
181                id,
182                SpvDecoration(decoration as u32 as i32),
183            );
184            if !has_decoration {
185                return Ok(None);
186            };
187
188            if decoration_is_string(decoration) {
189                let str = sys::spvc_compiler_get_decoration_string(
190                    self.ptr.as_ptr(),
191                    id,
192                    SpvDecoration(decoration as u32 as i32),
193                );
194                return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
195                    str,
196                    self.ctx.drop_guard(),
197                ))));
198            }
199
200            let value = sys::spvc_compiler_get_decoration(
201                self.ptr.as_ptr(),
202                id,
203                SpvDecoration(decoration as u32 as i32),
204            );
205            self.parse_decoration_value(decoration, value)
206        }
207    }
208
209    /// Gets the value for member decorations which take arguments.
210    pub fn member_decoration_by_handle(
211        &self,
212        struct_type_id: Handle<TypeId>,
213        index: u32,
214        decoration: Decoration,
215    ) -> error::Result<Option<DecorationValue>> {
216        // SAFETY: id is yielded by the instance so it's safe to use.
217        let struct_type = self.yield_id(struct_type_id)?;
218        let index = index;
219
220        unsafe {
221            let has_decoration = sys::spvc_compiler_has_member_decoration(
222                self.ptr.as_ptr(),
223                struct_type,
224                index,
225                SpvDecoration(decoration as u32 as i32),
226            );
227            if !has_decoration {
228                return Ok(None);
229            };
230
231            if decoration_is_string(decoration) {
232                let str = sys::spvc_compiler_get_member_decoration_string(
233                    self.ptr.as_ptr(),
234                    struct_type,
235                    index,
236                    SpvDecoration(decoration as u32 as i32),
237                );
238                return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
239                    str,
240                    self.ctx.drop_guard(),
241                ))));
242            }
243
244            let value = sys::spvc_compiler_get_member_decoration(
245                self.ptr.as_ptr(),
246                struct_type,
247                index,
248                SpvDecoration(decoration as u32 as i32),
249            );
250            self.parse_decoration_value(decoration, value)
251        }
252    }
253
254    /// Gets the value for member decorations which take arguments.
255    pub fn member_decoration(
256        &self,
257        member: &StructMember,
258        decoration: Decoration,
259    ) -> error::Result<Option<DecorationValue>> {
260        self.member_decoration_by_handle(member.struct_type, member.index as u32, decoration)
261    }
262
263    /// Set the value of a decoration for an ID.
264    pub fn set_decoration<'value, I: Id>(
265        &mut self,
266        id: Handle<I>,
267        decoration: spirv::Decoration,
268        value: Option<impl Into<DecorationValue<'value>>>,
269    ) -> error::Result<()> {
270        // SAFETY: id is yielded by the instance so it's safe to use.
271        let id = SpvId(self.yield_id(id)?.id());
272        unsafe {
273            let Some(value) = value else {
274                sys::spvc_compiler_unset_decoration(
275                    self.ptr.as_ptr(),
276                    id,
277                    SpvDecoration(decoration as u32 as i32),
278                );
279                return Ok(());
280            };
281
282            let value = value.into();
283
284            if !value.type_is_valid_for_decoration(decoration) {
285                return Err(SpirvCrossError::InvalidDecorationInput(
286                    decoration,
287                    DecorationValue::to_static(&value),
288                ));
289            }
290
291            match value {
292                DecorationValue::Literal(literal) => {
293                    sys::spvc_compiler_set_decoration(
294                        self.ptr.as_ptr(),
295                        id,
296                        SpvDecoration(decoration as u32 as i32),
297                        literal,
298                    );
299                }
300                DecorationValue::BuiltIn(builtin) => {
301                    sys::spvc_compiler_set_decoration(
302                        self.ptr.as_ptr(),
303                        id,
304                        SpvDecoration(decoration as u32 as i32),
305                        builtin as u32,
306                    );
307                }
308                DecorationValue::RoundingMode(rounding_mode) => {
309                    sys::spvc_compiler_set_decoration(
310                        self.ptr.as_ptr(),
311                        id,
312                        SpvDecoration(decoration as u32 as i32),
313                        rounding_mode as u32,
314                    );
315                }
316                DecorationValue::Constant(constant) => {
317                    let constant = self.yield_id(constant)?;
318                    sys::spvc_compiler_set_decoration(
319                        self.ptr.as_ptr(),
320                        id,
321                        SpvDecoration(decoration as u32 as i32),
322                        constant.id(),
323                    );
324                }
325                DecorationValue::Present => {
326                    sys::spvc_compiler_set_decoration(
327                        self.ptr.as_ptr(),
328                        id,
329                        SpvDecoration(decoration as u32 as i32),
330                        1,
331                    );
332                }
333                DecorationValue::String(string) => {
334                    let cstring = string.into_cstring_ptr().map_err(|e| {
335                        let SpirvCrossError::InvalidString(string) = e else {
336                            unreachable!("into_cstring_ptr only errors InvalidString")
337                        };
338                        SpirvCrossError::InvalidDecorationInput(
339                            decoration,
340                            DecorationValue::String(string.into()),
341                        )
342                    })?;
343
344                    sys::spvc_compiler_set_decoration_string(
345                        self.ptr.as_ptr(),
346                        id,
347                        SpvDecoration(decoration as u32 as i32),
348                        cstring.as_ptr(),
349                    );
350
351                    // Sanity drop to show that the lifetime of the cstring is only up until
352                    // we have returned. AFAIK, SPIRV-Cross will do a string copy.
353                    // If it does not, then we'll have to keep this string alive for a while.
354                    drop(cstring);
355                }
356            }
357        }
358        Ok(())
359    }
360
361    /// Set the value of a decoration for a struct member.
362    pub fn set_member_decoration<'value>(
363        &mut self,
364        member: &StructMember,
365        decoration: Decoration,
366        value: Option<impl Into<DecorationValue<'value>>>,
367    ) -> error::Result<()> {
368        self.set_member_decoration_by_handle(
369            member.struct_type,
370            member.index as u32,
371            decoration,
372            value,
373        )
374    }
375
376    /// Set the value of a decoration for a struct member by the handle of its parent struct
377    /// and the index.
378    pub fn set_member_decoration_by_handle<'value>(
379        &mut self,
380        struct_type: Handle<TypeId>,
381        index: u32,
382        decoration: Decoration,
383        value: Option<impl Into<DecorationValue<'value>>>,
384    ) -> error::Result<()> {
385        // SAFETY: id is yielded by the instance so it's safe to use.
386        let struct_type = self.yield_id(struct_type)?;
387
388        unsafe {
389            let Some(value) = value else {
390                sys::spvc_compiler_unset_member_decoration(
391                    self.ptr.as_ptr(),
392                    struct_type,
393                    index,
394                    SpvDecoration(decoration as u32 as i32),
395                );
396                return Ok(());
397            };
398
399            let value = value.into();
400
401            if !value.type_is_valid_for_decoration(decoration) {
402                return Err(SpirvCrossError::InvalidDecorationInput(
403                    decoration,
404                    DecorationValue::to_static(&value),
405                ));
406            }
407
408            match value {
409                DecorationValue::Literal(literal) => {
410                    sys::spvc_compiler_set_member_decoration(
411                        self.ptr.as_ptr(),
412                        struct_type,
413                        index,
414                        SpvDecoration(decoration as u32 as i32),
415                        literal,
416                    );
417                }
418                DecorationValue::BuiltIn(builtin) => {
419                    sys::spvc_compiler_set_member_decoration(
420                        self.ptr.as_ptr(),
421                        struct_type,
422                        index,
423                        SpvDecoration(decoration as u32 as i32),
424                        builtin as u32,
425                    );
426                }
427                DecorationValue::RoundingMode(rounding_mode) => {
428                    sys::spvc_compiler_set_member_decoration(
429                        self.ptr.as_ptr(),
430                        struct_type,
431                        index,
432                        SpvDecoration(decoration as u32 as i32),
433                        rounding_mode as u32,
434                    );
435                }
436                DecorationValue::Constant(constant) => {
437                    let constant = self.yield_id(constant)?;
438                    sys::spvc_compiler_set_member_decoration(
439                        self.ptr.as_ptr(),
440                        struct_type,
441                        index,
442                        SpvDecoration(decoration as u32 as i32),
443                        constant.id(),
444                    );
445                }
446                DecorationValue::Present => {
447                    sys::spvc_compiler_set_member_decoration(
448                        self.ptr.as_ptr(),
449                        struct_type,
450                        index,
451                        SpvDecoration(decoration as u32 as i32),
452                        1,
453                    );
454                }
455                DecorationValue::String(string) => {
456                    let cstring = string.into_cstring_ptr().map_err(|e| {
457                        let SpirvCrossError::InvalidString(string) = e else {
458                            unreachable!("into_cstring_ptr only errors InvalidString")
459                        };
460                        SpirvCrossError::InvalidDecorationInput(
461                            decoration,
462                            DecorationValue::String(string.into()),
463                        )
464                    })?;
465
466                    sys::spvc_compiler_set_member_decoration_string(
467                        self.ptr.as_ptr(),
468                        struct_type,
469                        index,
470                        SpvDecoration(decoration as u32 as i32),
471                        cstring.as_ptr(),
472                    );
473
474                    // Sanity drop to show that the lifetime of the cstring is only up until
475                    // we have returned. AFAIK, SPIRV-Cross will do a string copy.
476                    // If it does not, then we'll have to keep this string alive for a while.
477                    drop(cstring);
478                }
479            }
480        }
481        Ok(())
482    }
483
484    /// Gets the offset in SPIR-V words (uint32_t) for a decoration which was originally declared in the SPIR-V binary.
485    /// The offset will point to one or more uint32_t literals which can be modified in-place before using the SPIR-V binary.
486    ///
487    /// Note that adding or removing decorations using the reflection API will not change the behavior of this function.
488    /// If the decoration was declared, returns an offset into the provided SPIR-V binary buffer,
489    /// otherwise returns None.
490    ///
491    /// If the decoration does not have any value attached to it (e.g. DecorationRelaxedPrecision), this function will also return None.
492    pub fn binary_offset_for_decoration(
493        &self,
494        variable: impl Into<Handle<VariableId>>,
495        decoration: Decoration,
496    ) -> error::Result<Option<u32>> {
497        let id = self.yield_id(variable.into())?;
498
499        unsafe {
500            let mut offset = 0;
501            if !sys::spvc_compiler_get_binary_offset_for_decoration(
502                self.ptr.as_ptr(),
503                id,
504                SpvDecoration(decoration as u32 as i32),
505                &mut offset,
506            ) {
507                Ok(None)
508            } else {
509                Ok(Some(offset))
510            }
511        }
512    }
513
514    fn parse_decoration_value(
515        &self,
516        decoration: Decoration,
517        value: u32,
518    ) -> error::Result<Option<DecorationValue>> {
519        if decoration_is_literal(decoration) {
520            return Ok(Some(DecorationValue::Literal(value)));
521        }
522
523        // String is handled.
524        match decoration {
525            Decoration::BuiltIn => {
526                let Some(builtin) = spirv::BuiltIn::from_u32(value) else {
527                    return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
528                };
529                Ok(Some(DecorationValue::BuiltIn(builtin)))
530            }
531            Decoration::FPRoundingMode => {
532                // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_parsed_ir.cpp#L730
533                if value as i32 == i32::MAX {
534                    return Ok(None);
535                }
536
537                let Some(rounding_mode) = spirv::FPRoundingMode::from_u32(value) else {
538                    return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
539                };
540                Ok(Some(DecorationValue::RoundingMode(rounding_mode)))
541            }
542            Decoration::SpecId => unsafe {
543                Ok(Some(DecorationValue::Constant(
544                    self.create_handle(ConstantId(SpvId(value))),
545                )))
546            },
547            _ => {
548                if value == 1 {
549                    Ok(Some(DecorationValue::Present))
550                } else {
551                    Ok(None)
552                }
553            }
554        }
555    }
556
557    /// Get the decorations for a buffer block resource.
558    ///
559    /// If the variable handle is not a handle to with struct
560    /// base type, returns [`SpirvCrossError::InvalidArgument`].
561    pub fn buffer_block_decorations(
562        &self,
563        variable: impl Into<Handle<VariableId>>,
564    ) -> error::Result<Option<&[Decoration]>> {
565        let variable = variable.into();
566        let id = self.yield_id(variable)?;
567
568        unsafe {
569            let mut size = 0;
570            let mut buffer = std::ptr::null();
571            sys::spvc_compiler_get_buffer_block_decorations(
572                self.ptr.as_ptr(),
573                id,
574                &mut buffer,
575                &mut size,
576            )
577            .ok(self)?;
578
579            // SAFETY: 'ctx is sound here.
580            // https://github.com/KhronosGroup/SPIRV-Cross/blob/main/spirv_cross_c.cpp#L2790
581            let slice = super::try_valid_slice::<Decoration>(buffer.cast(), size)?;
582            if slice.is_empty() {
583                Ok(None)
584            } else {
585                Ok(Some(slice))
586            }
587        }
588    }
589}
590
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;

    use crate::{targets, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decodes the embedded module into native-endian SPIR-V words.
    ///
    /// `bytemuck::cast_slice` over a `Vec<u8>` panics unless the allocation happens to be
    /// 4-byte aligned, which `Vec<u8>` does not guarantee. Building a `Vec<u32>` avoids that.
    fn basic_spv_words() -> Vec<u32> {
        BASIC_SPV
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    #[test]
    pub fn set_decoration_test() -> Result<(), SpirvCrossError> {
        let words = basic_spv_words();
        let module = Module::from_words(&words);

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        // Reflection over the module must succeed. The result is not inspected yet.
        let _resources = compiler.shader_resources()?.all_resources()?;

        // TODO: set a string decoration such as `Decoration::HlslSemanticGOOGLE` on one of
        // the reflected resources and assert that `Compiler::decoration` reads it back.
        Ok(())
    }
}