spirv_cross2/reflect/execution_modes.rs

1use crate::error;
2use crate::error::ToContextError;
3use crate::handle::Handle;
4use crate::reflect::try_valid_slice;
5use crate::Compiler;
6use spirv_cross_sys as sys;
7use spirv_cross_sys::{ConstantId, SpvExecutionMode};
8
9/// Arguments to an `OpExecutionMode`.
10#[derive(Debug)]
11pub enum ExecutionModeArguments {
12    /// No arguments.
13    ///
14    /// This is also used to set execution modes for modes that don't have arguments.
15    None,
16    /// A single literal argument.
17    Literal(u32),
18    /// Arguments to `LocalSize` execution mode.
19    LocalSize {
20        /// Workgroup size x.
21        x: u32,
22        /// Workgroup size y.
23        y: u32,
24        /// Workgroup size z.
25        z: u32,
26    },
27    /// Arguments to `LocalSizeId` execution mode.
28    LocalSizeId {
29        /// Workgroup size x ID.
30        x: Handle<ConstantId>,
31        /// Workgroup size y ID.
32        y: Handle<ConstantId>,
33        /// Workgroup size z ID.
34        z: Handle<ConstantId>,
35    },
36}
37
38impl ExecutionModeArguments {
39    fn expand(self) -> [u32; 3] {
40        match self {
41            ExecutionModeArguments::None => [0, 0, 0],
42            ExecutionModeArguments::Literal(a) => [a, 0, 0],
43            ExecutionModeArguments::LocalSize { x, y, z } => [x, y, z],
44            ExecutionModeArguments::LocalSizeId { x, y, z } => [x.id(), y.id(), z.id()],
45        }
46    }
47}
48
49impl<T> Compiler<T> {
50    /// Set or unset execution modes and arguments.
51    ///
52    /// If arguments is `None`, unsets the execution mode. To set an execution mode that does not
53    /// take arguments, pass `Some(ExecutionModeArguments::None)`.
54    pub fn set_execution_mode(
55        &mut self,
56        mode: spirv::ExecutionMode,
57        arguments: Option<ExecutionModeArguments>,
58    ) {
59        unsafe {
60            let Some(arguments) = arguments else {
61                return sys::spvc_compiler_unset_execution_mode(
62                    self.ptr.as_ptr(),
63                    SpvExecutionMode(mode as u32 as i32),
64                );
65            };
66
67            let [x, y, z] = arguments.expand();
68
69            sys::spvc_compiler_set_execution_mode_with_arguments(
70                self.ptr.as_ptr(),
71                SpvExecutionMode(mode as u32 as i32),
72                x,
73                y,
74                z,
75            );
76        }
77    }
78
79    /// Query `OpExecutionMode`.
80    pub fn execution_modes(&self) -> error::Result<&[spirv::ExecutionMode]> {
81        unsafe {
82            let mut size = 0;
83            let mut modes = std::ptr::null();
84
85            sys::spvc_compiler_get_execution_modes(self.ptr.as_ptr(), &mut modes, &mut size)
86                .ok(self)?;
87
88            // SAFETY: 'ctx is sound here.
89            // https://github.com/KhronosGroup/SPIRV-Cross/blob/main/spirv_cross_c.cpp#L2250
90
91            const _: () =
92                assert!(std::mem::size_of::<spirv::ExecutionMode>() == std::mem::size_of::<u32>());
93            try_valid_slice(modes.cast(), size)
94        }
95    }
96
97    /// Get arguments used by the execution mode.
98    ///
99    /// If the execution mode is unused, returns `None`.
100    ///
101    /// LocalSizeId query returns an ID. If LocalSizeId execution mode is not used, it returns None.
102    /// LocalSize always returns a literal. If execution mode is LocalSizeId, the literal (spec constant or not) is still returned.
103    pub fn execution_mode_arguments(
104        &self,
105        mode: spirv::ExecutionMode,
106    ) -> error::Result<Option<ExecutionModeArguments>> {
107        Ok(match mode {
108            spirv::ExecutionMode::LocalSize => unsafe {
109                let x = sys::spvc_compiler_get_execution_mode_argument_by_index(
110                    self.ptr.as_ptr(),
111                    SpvExecutionMode(mode as u32 as i32),
112                    0,
113                );
114                let y = sys::spvc_compiler_get_execution_mode_argument_by_index(
115                    self.ptr.as_ptr(),
116                    SpvExecutionMode(mode as u32 as i32),
117                    1,
118                );
119                let z = sys::spvc_compiler_get_execution_mode_argument_by_index(
120                    self.ptr.as_ptr(),
121                    SpvExecutionMode(mode as u32 as i32),
122                    2,
123                );
124
125                if x * y * z == 0 {
126                    None
127                } else {
128                    Some(ExecutionModeArguments::LocalSize { x, y, z })
129                }
130            },
131            spirv::ExecutionMode::LocalSizeId => unsafe {
132                let x = sys::spvc_compiler_get_execution_mode_argument_by_index(
133                    self.ptr.as_ptr(),
134                    SpvExecutionMode(mode as u32 as i32),
135                    0,
136                );
137                let y = sys::spvc_compiler_get_execution_mode_argument_by_index(
138                    self.ptr.as_ptr(),
139                    SpvExecutionMode(mode as u32 as i32),
140                    1,
141                );
142                let z = sys::spvc_compiler_get_execution_mode_argument_by_index(
143                    self.ptr.as_ptr(),
144                    SpvExecutionMode(mode as u32 as i32),
145                    2,
146                );
147
148                if x * y * z == 0 {
149                    // If one is zero, then all are zero.
150                    None
151                } else {
152                    Some(ExecutionModeArguments::LocalSizeId {
153                        x: self.create_handle(ConstantId::from(x)),
154                        y: self.create_handle(ConstantId::from(y)),
155                        z: self.create_handle(ConstantId::from(z)),
156                    })
157                }
158            },
159            spirv::ExecutionMode::Invocations
160            | spirv::ExecutionMode::OutputVertices
161            | spirv::ExecutionMode::OutputPrimitivesEXT => unsafe {
162                if !self.execution_modes()?.contains(&mode) {
163                    return Ok(None);
164                };
165
166                let x = sys::spvc_compiler_get_execution_mode_argument_by_index(
167                    self.ptr.as_ptr(),
168                    SpvExecutionMode(mode as u32 as i32),
169                    0,
170                );
171                Some(ExecutionModeArguments::Literal(x))
172            },
173            _ => {
174                if !self.execution_modes()?.contains(&mode) {
175                    return Ok(None);
176                };
177
178                Some(ExecutionModeArguments::None)
179            }
180        })
181    }
182}
183
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::{targets, Compiler, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decode the embedded SPIR-V bytes into native-endian 32-bit words.
    ///
    /// Copying into a `Vec<u32>` guarantees 4-byte alignment, which reinterpreting a `Vec<u8>`
    /// in place (e.g. with `bytemuck::cast_slice`) does not.
    fn basic_words() -> Vec<u32> {
        assert_eq!(
            BASIC_SPV.len() % 4,
            0,
            "SPIR-V must be a whole number of 32-bit words"
        );
        BASIC_SPV
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    #[test]
    fn execution_modes() -> Result<(), SpirvCrossError> {
        let words = basic_words();
        let compiler: Compiler<targets::None> = Compiler::new(Module::from_words(&words))?;

        // Also exercise resource reflection; only its success matters here.
        let _resources = compiler.shader_resources()?.all_resources()?;

        let modes = compiler.execution_modes()?;
        assert_eq!([spirv::ExecutionMode::OriginUpperLeft], modes);

        Ok(())
    }
}