Skip to main content

spirv_cross2/reflect/
decorations.rs

1use crate::error::{SpirvCrossError, ToContextError};
2use crate::handle::{ConstantId, Handle, Id, TypeId, VariableId};
3use crate::reflect::StructMember;
4use crate::sealed::Sealed;
5use crate::string::CompilerStr;
6use crate::Compiler;
7use crate::{error, ToStatic};
8use spirv::Decoration;
9use spirv_cross_sys as sys;
10use spirv_cross_sys::{SpvDecoration, SpvId};
11
12/// A value accompanying an `OpDecoration`
13#[derive(Debug, Eq, PartialEq)]
14pub enum DecorationValue<'a> {
15    /// Returned by the following decorations.
16    ///
17    /// - [`Location`](Decoration::Location).
18    /// - [`Component`](Decoration::Component).
19    /// - [`Offset`](Decoration::Offset).
20    /// - [`XfbBuffer`](Decoration::XfbBuffer).
21    /// - [`XfbStride`](Decoration::XfbStride).
22    /// - [`Stream`](Decoration::Stream).
23    /// - [`Binding`](Decoration::Binding).
24    /// - [`DescriptorSet`](Decoration::DescriptorSet).
25    /// - [`InputAttachmentIndex`](Decoration::InputAttachmentIndex).
26    /// - [`ArrayStride`](Decoration::ArrayStride).
27    /// - [`MatrixStride`](Decoration::MatrixStride).
28    /// - [`Index`](Decoration::Index).
29    Literal(u32),
30    /// Only for decoration [`BuiltIn`](Decoration::BuiltIn).
31    BuiltIn(spirv::BuiltIn),
32    /// Only for decoration [`FPRoundingMode`](Decoration::FPRoundingMode).
33    RoundingMode(spirv::FPRoundingMode),
34    /// Only for decoration [`SpecId`](Decoration::SpecId).
35    Constant(Handle<ConstantId>),
36    /// Only for decoration [`HlslSemanticGOOGLE`](Decoration::HlslSemanticGOOGLE) and [`UserTypeGOOGLE`](Decoration::HlslSemanticGOOGLE).
37    String(CompilerStr<'a>),
38    /// All other decorations to indicate the presence of a decoration.
39    Present,
40}
41
42impl DecorationValue<'_> {
43    /// Helper function to unset a decoration value, to be passed to
44    /// [`Compiler::set_decoration`].
45    pub const fn unset() -> Option<Self> {
46        None
47    }
48
49    /// Get the value if it is a literal `u32`.
50    pub fn as_literal(&self) -> Option<u32> {
51        match self {
52            Self::Literal(l) => Some(*l),
53            _ => None,
54        }
55    }
56}
57
58impl From<u32> for DecorationValue<'_> {
59    fn from(value: u32) -> Self {
60        DecorationValue::Literal(value)
61    }
62}
63
64impl From<()> for DecorationValue<'_> {
65    fn from(_value: ()) -> Self {
66        DecorationValue::Present
67    }
68}
69
70impl From<Handle<ConstantId>> for DecorationValue<'_> {
71    fn from(value: Handle<ConstantId>) -> Self {
72        DecorationValue::Constant(value)
73    }
74}
75
76impl<'a> From<&'a str> for DecorationValue<'a> {
77    fn from(value: &'a str) -> Self {
78        DecorationValue::String(CompilerStr::from_str(value))
79    }
80}
81
82impl From<String> for DecorationValue<'_> {
83    fn from(value: String) -> Self {
84        DecorationValue::String(CompilerStr::from_string(value))
85    }
86}
87
88impl<'a> From<CompilerStr<'a>> for DecorationValue<'a> {
89    fn from(value: CompilerStr<'a>) -> Self {
90        DecorationValue::String(value)
91    }
92}
93
94impl Sealed for DecorationValue<'_> {}
95impl ToStatic for DecorationValue<'_> {
96    type Static<'a>
97        = DecorationValue<'static>
98    where
99        'a: 'static;
100
101    fn to_static(&self) -> Self::Static<'static> {
102        match self {
103            DecorationValue::Literal(a) => DecorationValue::Literal(*a),
104            DecorationValue::BuiltIn(a) => DecorationValue::BuiltIn(*a),
105            DecorationValue::RoundingMode(a) => DecorationValue::RoundingMode(*a),
106            DecorationValue::Constant(a) => DecorationValue::Constant(*a),
107            DecorationValue::String(c) => {
108                let owned = c.to_string();
109                DecorationValue::String(CompilerStr::from_string(owned))
110            }
111            DecorationValue::Present => DecorationValue::Present,
112        }
113    }
114}
115
116impl<'a> Clone for DecorationValue<'a> {
117    fn clone(&self) -> DecorationValue<'static> {
118        self.to_static()
119    }
120}
121
122impl DecorationValue<'_> {
123    /// Check that the value is valid for the decoration type.
124    pub fn type_is_valid_for_decoration(&self, decoration: spirv::Decoration) -> bool {
125        match self {
126            DecorationValue::Literal(_) => decoration_is_literal(decoration),
127            DecorationValue::BuiltIn(_) => decoration == Decoration::BuiltIn,
128            DecorationValue::RoundingMode(_) => decoration == Decoration::FPRoundingMode,
129            DecorationValue::Constant(_) => decoration == Decoration::SpecId,
130            DecorationValue::String(_) => decoration_is_string(decoration),
131            DecorationValue::Present => {
132                !decoration_is_literal(decoration)
133                    && !decoration_is_string(decoration)
134                    && decoration != Decoration::BuiltIn
135                    && decoration != Decoration::FPRoundingMode
136                    && decoration != Decoration::SpecId
137            }
138        }
139    }
140}
141fn decoration_is_literal(decoration: spirv::Decoration) -> bool {
142    matches!(
143        decoration,
144        Decoration::Location
145            | Decoration::Component
146            | Decoration::Offset
147            | Decoration::XfbBuffer
148            | Decoration::XfbStride
149            | Decoration::Stream
150            | Decoration::Binding
151            | Decoration::DescriptorSet
152            | Decoration::InputAttachmentIndex
153            | Decoration::ArrayStride
154            | Decoration::MatrixStride
155            | Decoration::Index
156    )
157}
158
159fn decoration_is_string(decoration: Decoration) -> bool {
160    matches!(
161        decoration,
162        Decoration::HlslSemanticGOOGLE | Decoration::UserTypeGOOGLE
163    )
164}
165
166impl<T> Compiler<T> {
167    /// Gets the value for decorations which take arguments.
168    pub fn decoration<I: Id>(
169        &self,
170        id: Handle<I>,
171        decoration: Decoration,
172    ) -> error::Result<Option<DecorationValue<'_>>> {
173        // SAFETY: 'ctx is not sound to return here!
174        //  https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_c.cpp#L2154
175
176        // SAFETY: id is yielded by the instance so it's safe to use.
177        let id = SpvId(self.yield_id(id)?.id());
178        unsafe {
179            let has_decoration = sys::spvc_compiler_has_decoration(
180                self.ptr.as_ptr(),
181                id,
182                SpvDecoration(decoration as u32 as i32),
183            );
184            if !has_decoration {
185                return Ok(None);
186            };
187
188            if decoration_is_string(decoration) {
189                let str = sys::spvc_compiler_get_decoration_string(
190                    self.ptr.as_ptr(),
191                    id,
192                    SpvDecoration(decoration as u32 as i32),
193                );
194                return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
195                    str,
196                    self.ctx.drop_guard(),
197                ))));
198            }
199
200            let value = sys::spvc_compiler_get_decoration(
201                self.ptr.as_ptr(),
202                id,
203                SpvDecoration(decoration as u32 as i32),
204            );
205            self.parse_decoration_value(decoration, value)
206        }
207    }
208
209    /// Gets the value for member decorations which take arguments.
210    pub fn member_decoration_by_handle(
211        &self,
212        struct_type_id: Handle<TypeId>,
213        index: u32,
214        decoration: Decoration,
215    ) -> error::Result<Option<DecorationValue<'_>>> {
216        // SAFETY: id is yielded by the instance so it's safe to use.
217        let struct_type = self.yield_id(struct_type_id)?;
218
219        unsafe {
220            let has_decoration = sys::spvc_compiler_has_member_decoration(
221                self.ptr.as_ptr(),
222                struct_type,
223                index,
224                SpvDecoration(decoration as u32 as i32),
225            );
226            if !has_decoration {
227                return Ok(None);
228            };
229
230            if decoration_is_string(decoration) {
231                let str = sys::spvc_compiler_get_member_decoration_string(
232                    self.ptr.as_ptr(),
233                    struct_type,
234                    index,
235                    SpvDecoration(decoration as u32 as i32),
236                );
237                return Ok(Some(DecorationValue::String(CompilerStr::from_ptr(
238                    str,
239                    self.ctx.drop_guard(),
240                ))));
241            }
242
243            let value = sys::spvc_compiler_get_member_decoration(
244                self.ptr.as_ptr(),
245                struct_type,
246                index,
247                SpvDecoration(decoration as u32 as i32),
248            );
249            self.parse_decoration_value(decoration, value)
250        }
251    }
252
253    /// Gets the value for member decorations which take arguments.
254    pub fn member_decoration(
255        &self,
256        member: &StructMember,
257        decoration: Decoration,
258    ) -> error::Result<Option<DecorationValue<'_>>> {
259        self.member_decoration_by_handle(member.struct_type, member.index as u32, decoration)
260    }
261
262    /// Set the value of a decoration for an ID.
263    pub fn set_decoration<'value, I: Id>(
264        &mut self,
265        id: Handle<I>,
266        decoration: spirv::Decoration,
267        value: Option<impl Into<DecorationValue<'value>>>,
268    ) -> error::Result<()> {
269        // SAFETY: id is yielded by the instance so it's safe to use.
270        let id = SpvId(self.yield_id(id)?.id());
271        unsafe {
272            let Some(value) = value else {
273                sys::spvc_compiler_unset_decoration(
274                    self.ptr.as_ptr(),
275                    id,
276                    SpvDecoration(decoration as u32 as i32),
277                );
278                return Ok(());
279            };
280
281            let value = value.into();
282
283            if !value.type_is_valid_for_decoration(decoration) {
284                return Err(SpirvCrossError::InvalidDecorationInput(
285                    decoration,
286                    DecorationValue::to_static(&value),
287                ));
288            }
289
290            match value {
291                DecorationValue::Literal(literal) => {
292                    sys::spvc_compiler_set_decoration(
293                        self.ptr.as_ptr(),
294                        id,
295                        SpvDecoration(decoration as u32 as i32),
296                        literal,
297                    );
298                }
299                DecorationValue::BuiltIn(builtin) => {
300                    sys::spvc_compiler_set_decoration(
301                        self.ptr.as_ptr(),
302                        id,
303                        SpvDecoration(decoration as u32 as i32),
304                        builtin as u32,
305                    );
306                }
307                DecorationValue::RoundingMode(rounding_mode) => {
308                    sys::spvc_compiler_set_decoration(
309                        self.ptr.as_ptr(),
310                        id,
311                        SpvDecoration(decoration as u32 as i32),
312                        rounding_mode as u32,
313                    );
314                }
315                DecorationValue::Constant(constant) => {
316                    let constant = self.yield_id(constant)?;
317                    sys::spvc_compiler_set_decoration(
318                        self.ptr.as_ptr(),
319                        id,
320                        SpvDecoration(decoration as u32 as i32),
321                        constant.id(),
322                    );
323                }
324                DecorationValue::Present => {
325                    sys::spvc_compiler_set_decoration(
326                        self.ptr.as_ptr(),
327                        id,
328                        SpvDecoration(decoration as u32 as i32),
329                        1,
330                    );
331                }
332                DecorationValue::String(string) => {
333                    let cstring = string.into_cstring_ptr().map_err(|e| {
334                        let SpirvCrossError::InvalidString(string) = e else {
335                            unreachable!("into_cstring_ptr only errors InvalidString")
336                        };
337                        SpirvCrossError::InvalidDecorationInput(
338                            decoration,
339                            DecorationValue::String(string.into()),
340                        )
341                    })?;
342
343                    sys::spvc_compiler_set_decoration_string(
344                        self.ptr.as_ptr(),
345                        id,
346                        SpvDecoration(decoration as u32 as i32),
347                        cstring.as_ptr(),
348                    );
349
350                    // Sanity drop to show that the lifetime of the cstring is only up until
351                    // we have returned. AFAIK, SPIRV-Cross will do a string copy.
352                    // If it does not, then we'll have to keep this string alive for a while.
353                    drop(cstring);
354                }
355            }
356        }
357        Ok(())
358    }
359
360    /// Set the value of a decoration for a struct member.
361    pub fn set_member_decoration<'value>(
362        &mut self,
363        member: &StructMember,
364        decoration: Decoration,
365        value: Option<impl Into<DecorationValue<'value>>>,
366    ) -> error::Result<()> {
367        self.set_member_decoration_by_handle(
368            member.struct_type,
369            member.index as u32,
370            decoration,
371            value,
372        )
373    }
374
375    /// Set the value of a decoration for a struct member by the handle of its parent struct
376    /// and the index.
377    pub fn set_member_decoration_by_handle<'value>(
378        &mut self,
379        struct_type: Handle<TypeId>,
380        index: u32,
381        decoration: Decoration,
382        value: Option<impl Into<DecorationValue<'value>>>,
383    ) -> error::Result<()> {
384        // SAFETY: id is yielded by the instance so it's safe to use.
385        let struct_type = self.yield_id(struct_type)?;
386
387        unsafe {
388            let Some(value) = value else {
389                sys::spvc_compiler_unset_member_decoration(
390                    self.ptr.as_ptr(),
391                    struct_type,
392                    index,
393                    SpvDecoration(decoration as u32 as i32),
394                );
395                return Ok(());
396            };
397
398            let value = value.into();
399
400            if !value.type_is_valid_for_decoration(decoration) {
401                return Err(SpirvCrossError::InvalidDecorationInput(
402                    decoration,
403                    DecorationValue::to_static(&value),
404                ));
405            }
406
407            match value {
408                DecorationValue::Literal(literal) => {
409                    sys::spvc_compiler_set_member_decoration(
410                        self.ptr.as_ptr(),
411                        struct_type,
412                        index,
413                        SpvDecoration(decoration as u32 as i32),
414                        literal,
415                    );
416                }
417                DecorationValue::BuiltIn(builtin) => {
418                    sys::spvc_compiler_set_member_decoration(
419                        self.ptr.as_ptr(),
420                        struct_type,
421                        index,
422                        SpvDecoration(decoration as u32 as i32),
423                        builtin as u32,
424                    );
425                }
426                DecorationValue::RoundingMode(rounding_mode) => {
427                    sys::spvc_compiler_set_member_decoration(
428                        self.ptr.as_ptr(),
429                        struct_type,
430                        index,
431                        SpvDecoration(decoration as u32 as i32),
432                        rounding_mode as u32,
433                    );
434                }
435                DecorationValue::Constant(constant) => {
436                    let constant = self.yield_id(constant)?;
437                    sys::spvc_compiler_set_member_decoration(
438                        self.ptr.as_ptr(),
439                        struct_type,
440                        index,
441                        SpvDecoration(decoration as u32 as i32),
442                        constant.id(),
443                    );
444                }
445                DecorationValue::Present => {
446                    sys::spvc_compiler_set_member_decoration(
447                        self.ptr.as_ptr(),
448                        struct_type,
449                        index,
450                        SpvDecoration(decoration as u32 as i32),
451                        1,
452                    );
453                }
454                DecorationValue::String(string) => {
455                    let cstring = string.into_cstring_ptr().map_err(|e| {
456                        let SpirvCrossError::InvalidString(string) = e else {
457                            unreachable!("into_cstring_ptr only errors InvalidString")
458                        };
459                        SpirvCrossError::InvalidDecorationInput(
460                            decoration,
461                            DecorationValue::String(string.into()),
462                        )
463                    })?;
464
465                    sys::spvc_compiler_set_member_decoration_string(
466                        self.ptr.as_ptr(),
467                        struct_type,
468                        index,
469                        SpvDecoration(decoration as u32 as i32),
470                        cstring.as_ptr(),
471                    );
472
473                    // Sanity drop to show that the lifetime of the cstring is only up until
474                    // we have returned. AFAIK, SPIRV-Cross will do a string copy.
475                    // If it does not, then we'll have to keep this string alive for a while.
476                    drop(cstring);
477                }
478            }
479        }
480        Ok(())
481    }
482
483    /// Gets the offset in SPIR-V words (uint32_t) for a decoration which was originally declared in the SPIR-V binary.
484    /// The offset will point to one or more uint32_t literals which can be modified in-place before using the SPIR-V binary.
485    ///
486    /// Note that adding or removing decorations using the reflection API will not change the behavior of this function.
487    /// If the decoration was declared, returns an offset into the provided SPIR-V binary buffer,
488    /// otherwise returns None.
489    ///
490    /// If the decoration does not have any value attached to it (e.g. DecorationRelaxedPrecision), this function will also return None.
491    pub fn binary_offset_for_decoration(
492        &self,
493        variable: impl Into<Handle<VariableId>>,
494        decoration: Decoration,
495    ) -> error::Result<Option<u32>> {
496        let id = self.yield_id(variable.into())?;
497
498        unsafe {
499            let mut offset = 0;
500            if !sys::spvc_compiler_get_binary_offset_for_decoration(
501                self.ptr.as_ptr(),
502                id,
503                SpvDecoration(decoration as u32 as i32),
504                &mut offset,
505            ) {
506                Ok(None)
507            } else {
508                Ok(Some(offset))
509            }
510        }
511    }
512
513    fn parse_decoration_value(
514        &self,
515        decoration: Decoration,
516        value: u32,
517    ) -> error::Result<Option<DecorationValue<'_>>> {
518        if decoration_is_literal(decoration) {
519            return Ok(Some(DecorationValue::Literal(value)));
520        }
521
522        // String is handled.
523        match decoration {
524            Decoration::BuiltIn => {
525                let Some(builtin) = spirv::BuiltIn::from_u32(value) else {
526                    return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
527                };
528                Ok(Some(DecorationValue::BuiltIn(builtin)))
529            }
530            Decoration::FPRoundingMode => {
531                // https://github.com/KhronosGroup/SPIRV-Cross/blob/6a1fb66eef1bdca14acf7d0a51a3f883499d79f0/spirv_cross_parsed_ir.cpp#L730
532                if value as i32 == i32::MAX {
533                    return Ok(None);
534                }
535
536                let Some(rounding_mode) = spirv::FPRoundingMode::from_u32(value) else {
537                    return Err(SpirvCrossError::InvalidDecorationOutput(decoration, value));
538                };
539                Ok(Some(DecorationValue::RoundingMode(rounding_mode)))
540            }
541            Decoration::SpecId => unsafe {
542                Ok(Some(DecorationValue::Constant(
543                    self.create_handle(ConstantId(SpvId(value))),
544                )))
545            },
546            _ => {
547                if value == 1 {
548                    Ok(Some(DecorationValue::Present))
549                } else {
550                    Ok(None)
551                }
552            }
553        }
554    }
555
556    /// Get the decorations for a buffer block resource.
557    ///
558    /// If the variable handle is not a handle to with struct
559    /// base type, returns [`SpirvCrossError::InvalidArgument`].
560    pub fn buffer_block_decorations(
561        &self,
562        variable: impl Into<Handle<VariableId>>,
563    ) -> error::Result<Option<&[Decoration]>> {
564        let variable = variable.into();
565        let id = self.yield_id(variable)?;
566
567        unsafe {
568            let mut size = 0;
569            let mut buffer = std::ptr::null();
570            sys::spvc_compiler_get_buffer_block_decorations(
571                self.ptr.as_ptr(),
572                id,
573                &mut buffer,
574                &mut size,
575            )
576            .ok(self)?;
577
578            // SAFETY: 'ctx is sound here.
579            // https://github.com/KhronosGroup/SPIRV-Cross/blob/main/spirv_cross_c.cpp#L2790
580            let slice = super::try_valid_slice::<Decoration>(buffer.cast(), size)?;
581            if slice.is_empty() {
582                Ok(None)
583            } else {
584                Ok(Some(slice))
585            }
586        }
587    }
588}
589
#[cfg(test)]
mod test {
    use super::DecorationValue;
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};
    use spirv::Decoration;

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decodes the fixture into words without assuming the byte buffer is 4-byte
    /// aligned (a `Vec<u8>` is not guaranteed to be, so `cast_slice` could panic).
    fn basic_words() -> Vec<u32> {
        BASIC_SPV
            .chunks_exact(4)
            .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Smoke test: a compiler can be built from the fixture and its resources reflected.
    #[test]
    pub fn set_decoration_test() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let module = Module::from_words(&words);

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        let _resources = compiler.shader_resources()?.all_resources()?;
        Ok(())
    }

    /// Checks that each `DecorationValue` variant is accepted only by the matching decorations.
    #[test]
    pub fn decoration_value_validity_test() {
        assert!(DecorationValue::Literal(0).type_is_valid_for_decoration(Decoration::Location));
        assert!(!DecorationValue::Literal(0).type_is_valid_for_decoration(Decoration::Block));

        assert!(DecorationValue::Present.type_is_valid_for_decoration(Decoration::Block));
        assert!(!DecorationValue::Present.type_is_valid_for_decoration(Decoration::Binding));
        assert!(!DecorationValue::Present.type_is_valid_for_decoration(Decoration::SpecId));

        assert!(DecorationValue::from("semantic")
            .type_is_valid_for_decoration(Decoration::HlslSemanticGOOGLE));
        assert!(!DecorationValue::from("semantic")
            .type_is_valid_for_decoration(Decoration::Location));

        assert_eq!(DecorationValue::from(7u32).as_literal(), Some(7));
        assert_eq!(DecorationValue::Present.as_literal(), None);
        assert_eq!(DecorationValue::from(()), DecorationValue::Present);
        assert!(DecorationValue::unset().is_none());
    }
}