1use alloc::sync::Arc;
2
3use crate::{
4 binding_model::BindGroup,
5 id,
6 pipeline::ComputePipeline,
7 resource::{Buffer, QuerySet},
8};
9
/// A single compute pass command that refers to resources by id.
///
/// The ids carried here are turned into resource references (see
/// [`ArcComputeCommand`]) by `ComputeCommand::resolve_compute_command_ids`,
/// which is only compiled with the `serde` or `replay` feature.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
    /// Set the bind group at slot `index`.
    ///
    /// `bind_group_id` may be `None`; it is carried through id resolution as
    /// `None` without any lookup.
    SetBindGroup {
        index: u32,
        /// Number of dynamic offsets that go with this bind group.
        /// NOTE(review): the offsets themselves are not stored in this command;
        /// presumably they live in separate per-pass storage — confirm.
        num_dynamic_offsets: usize,
        bind_group_id: Option<id::BindGroupId>,
    },

    /// Bind the compute pipeline with the given id.
    SetPipeline(id::ComputePipelineId),

    /// Write push-constant data.
    SetPushConstant {
        /// Destination offset of the write.
        /// NOTE(review): presumably a byte offset into push-constant storage — confirm.
        offset: u32,

        /// Size of the write, in bytes.
        size_bytes: u32,

        /// Start of the source values.
        /// NOTE(review): presumably an index into separately stored push-constant
        /// data (not part of this command); confirm whether the unit is bytes or
        /// `u32` elements.
        values_offset: u32,
    },

    /// Dispatch with the given `[x, y, z]` size (presumably workgroup counts).
    Dispatch([u32; 3]),

    /// Dispatch using arguments read from buffer `buffer_id` at `offset`.
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },

    /// Open a debug group.
    PushDebugGroup {
        color: u32,
        /// NOTE(review): the label text is not stored here; `len` is presumably
        /// its length within separately stored string data — confirm.
        len: usize,
    },

    /// Close a debug group; carries no operands.
    PopDebugGroup,

    /// Insert a debug marker.
    InsertDebugMarker {
        color: u32,
        /// NOTE(review): the marker text is not stored here; `len` is presumably
        /// its length within separately stored string data — confirm.
        len: usize,
    },

    /// Write a timestamp into query `query_index` of query set `query_set_id`.
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// Begin a pipeline-statistics query at `query_index` of `query_set_id`.
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// End a pipeline-statistics query; carries no operands.
    EndPipelineStatisticsQuery,
}
69
70impl ComputeCommand {
71 #[cfg(any(feature = "serde", feature = "replay"))]
73 pub fn resolve_compute_command_ids(
74 hub: &crate::hub::Hub,
75 commands: &[ComputeCommand],
76 ) -> Result<alloc::vec::Vec<ArcComputeCommand>, super::ComputePassError> {
77 use super::{ComputePassError, PassErrorScope};
78 use alloc::vec::Vec;
79
80 let buffers_guard = hub.buffers.read();
81 let bind_group_guard = hub.bind_groups.read();
82 let query_set_guard = hub.query_sets.read();
83 let pipelines_guard = hub.compute_pipelines.read();
84
85 let resolved_commands: Vec<ArcComputeCommand> = commands
86 .iter()
87 .map(|c| -> Result<ArcComputeCommand, ComputePassError> {
88 Ok(match *c {
89 ComputeCommand::SetBindGroup {
90 index,
91 num_dynamic_offsets,
92 bind_group_id,
93 } => {
94 if bind_group_id.is_none() {
95 return Ok(ArcComputeCommand::SetBindGroup {
96 index,
97 num_dynamic_offsets,
98 bind_group: None,
99 });
100 }
101
102 let bind_group_id = bind_group_id.unwrap();
103 let bg = bind_group_guard.get(bind_group_id).get().map_err(|e| {
104 ComputePassError {
105 scope: PassErrorScope::SetBindGroup,
106 inner: e.into(),
107 }
108 })?;
109
110 ArcComputeCommand::SetBindGroup {
111 index,
112 num_dynamic_offsets,
113 bind_group: Some(bg),
114 }
115 }
116 ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
117 pipelines_guard
118 .get(pipeline_id)
119 .get()
120 .map_err(|e| ComputePassError {
121 scope: PassErrorScope::SetPipelineCompute,
122 inner: e.into(),
123 })?,
124 ),
125
126 ComputeCommand::SetPushConstant {
127 offset,
128 size_bytes,
129 values_offset,
130 } => ArcComputeCommand::SetPushConstant {
131 offset,
132 size_bytes,
133 values_offset,
134 },
135
136 ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
137
138 ComputeCommand::DispatchIndirect { buffer_id, offset } => {
139 ArcComputeCommand::DispatchIndirect {
140 buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
141 ComputePassError {
142 scope: PassErrorScope::Dispatch { indirect: true },
143 inner: e.into(),
144 }
145 })?,
146 offset,
147 }
148 }
149
150 ComputeCommand::PushDebugGroup { color, len } => {
151 ArcComputeCommand::PushDebugGroup { color, len }
152 }
153
154 ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
155
156 ComputeCommand::InsertDebugMarker { color, len } => {
157 ArcComputeCommand::InsertDebugMarker { color, len }
158 }
159
160 ComputeCommand::WriteTimestamp {
161 query_set_id,
162 query_index,
163 } => ArcComputeCommand::WriteTimestamp {
164 query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
165 ComputePassError {
166 scope: PassErrorScope::WriteTimestamp,
167 inner: e.into(),
168 }
169 })?,
170 query_index,
171 },
172
173 ComputeCommand::BeginPipelineStatisticsQuery {
174 query_set_id,
175 query_index,
176 } => ArcComputeCommand::BeginPipelineStatisticsQuery {
177 query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
178 ComputePassError {
179 scope: PassErrorScope::BeginPipelineStatisticsQuery,
180 inner: e.into(),
181 }
182 })?,
183 query_index,
184 },
185
186 ComputeCommand::EndPipelineStatisticsQuery => {
187 ArcComputeCommand::EndPipelineStatisticsQuery
188 }
189 })
190 })
191 .collect::<Result<Vec<_>, ComputePassError>>()?;
192 Ok(resolved_commands)
193 }
194}
195
/// A compute pass command whose resources are held as [`Arc`] references
/// instead of ids.
///
/// Variants mirror [`ComputeCommand`] one-to-one. The id-carrying fields are
/// replaced by the resolved resources: bind group, pipeline, buffer and query set.
#[derive(Clone, Debug)]
pub enum ArcComputeCommand {
    /// Set the bind group at slot `index`; `None` is passed through as-is.
    SetBindGroup {
        index: u32,
        /// Number of dynamic offsets that go with this bind group.
        /// NOTE(review): the offsets themselves are not stored in this command;
        /// presumably they live in separate per-pass storage — confirm.
        num_dynamic_offsets: usize,
        bind_group: Option<Arc<BindGroup>>,
    },

    /// Bind the given compute pipeline.
    SetPipeline(Arc<ComputePipeline>),

    /// Write push-constant data.
    SetPushConstant {
        /// Destination offset of the write.
        /// NOTE(review): presumably a byte offset into push-constant storage — confirm.
        offset: u32,

        /// Size of the write, in bytes.
        size_bytes: u32,

        /// Start of the source values.
        /// NOTE(review): presumably an index into separately stored push-constant
        /// data (not part of this command); confirm whether the unit is bytes or
        /// `u32` elements.
        values_offset: u32,
    },

    /// Dispatch with the given `[x, y, z]` size (presumably workgroup counts).
    Dispatch([u32; 3]),

    /// Dispatch using arguments read from `buffer` at `offset`.
    DispatchIndirect {
        buffer: Arc<Buffer>,
        offset: wgt::BufferAddress,
    },

    /// Open a debug group.
    PushDebugGroup {
        // `color` may go unread when neither `serde` nor `replay` is enabled,
        // so the dead-code lint is silenced for that configuration only.
        #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))]
        color: u32,
        /// NOTE(review): presumably the label's length within separately stored
        /// string data — confirm.
        len: usize,
    },

    /// Close a debug group; carries no operands.
    PopDebugGroup,

    /// Insert a debug marker.
    InsertDebugMarker {
        // `color` may go unread when neither `serde` nor `replay` is enabled,
        // so the dead-code lint is silenced for that configuration only.
        #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))]
        color: u32,
        /// NOTE(review): presumably the marker text's length within separately
        /// stored string data — confirm.
        len: usize,
    },

    /// Write a timestamp into query `query_index` of `query_set`.
    WriteTimestamp {
        query_set: Arc<QuerySet>,
        query_index: u32,
    },

    /// Begin a pipeline-statistics query at `query_index` of `query_set`.
    BeginPipelineStatisticsQuery {
        query_set: Arc<QuerySet>,
        query_index: u32,
    },

    /// End a pipeline-statistics query; carries no operands.
    EndPipelineStatisticsQuery,
}