use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};

use crate::ray_tracing::AsAction;
use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{
            fixup_discarded_surfaces, CommandBufferTextureMemoryActions, SurfacesInDiscardState,
        },
        validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
        BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::{BufferInitTrackerAction, MemoryInitKind},
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice,
    },
    snatch::SnatchGuard,
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
    Label,
};

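/// A recorded, but not yet ended, compute pass.
///
/// Created with [`Global::command_encoder_begin_compute_pass`]; commands are
/// recorded through the `Global::compute_pass_*` methods and replayed into the
/// parent encoder by [`Global::compute_pass_end`]. A minimal sketch of the
/// lifecycle (IDs and error handling elided):
///
/// ```ignore
/// let (mut pass, err) = global.command_encoder_begin_compute_pass(encoder_id, &desc);
/// global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
/// global.compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[])?;
/// global.compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1)?;
/// global.compute_pass_end(&mut pass)?;
/// ```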
pub struct ComputePass {
    /// All pass data and records are stored here.
    ///
    /// If this is `None`, the pass has already ended; any attempt to record
    /// further commands returns [`ComputePassErrorInner::PassEnded`].
    base: Option<BasePass<ArcComputeCommand>>,

    /// Parent command buffer that this pass records commands into.
    ///
    /// If this is `None`, the pass is invalid and ending it returns
    /// [`ComputePassErrorInner::InvalidParentEncoder`].
    parent: Option<Arc<CommandBuffer>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // State-change deduplication for redundant bind group and pipeline sets.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: Some(BasePass::new(&label)),
            parent,
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.as_ref().and_then(|base| base.label.as_deref())
    }

    fn base_mut<'a>(
        &'a mut self,
        scope: PassErrorScope,
    ) -> Result<&'a mut BasePass<ArcComputeCommand>, ComputePassError> {
        self.base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered while recording or replaying a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}

/// A [`ComputePassErrorInner`] together with the scope in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

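/// Transient state used while replaying the pass's recorded commands into the
/// HAL encoder in [`Global::compute_pass_end`].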
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    binder: Binder,
    pipeline: Option<Arc<ComputePipeline>>,
    scope: UsageScope<'scope>,
    debug_scope_depth: u32,

    snatch_guard: SnatchGuard<'snatch_guard>,

    device: &'cmd_buf Arc<Device>,

    raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    tracker: &'cmd_buf mut Tracker,
    buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
    as_actions: &'cmd_buf mut Vec<AsAction>,

    temp_offsets: Vec<u32>,
    dynamic_offset_count: usize,
    string_offset: usize,
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,

    /// Immediate texture inits required because of prior discards. These are
    /// inserted into the internal "Pre Pass" that runs before this pass.
    pending_discard_init_fixups: SurfacesInDiscardState,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.binder.check_compatibility(pipeline.as_ref())?;
            self.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline)
        }
    }

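    /// Merge the resource usages of all currently bound bind groups (plus the
    /// optional indirect buffer) into the pass's usage scope, commit them to
    /// the intermediate trackers, and record the resulting transitions as
    /// barriers on the raw encoder.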
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group.used)? };
        }

        // All usages are now merged into the scope; commit them to the trackers.
        for bind_group in self.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer, if present.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        CommandBuffer::drain_barriers(
            self.raw_encoder,
            &mut self.intermediate_trackers,
            &self.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
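    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned alongside the error.
    /// Recording into an invalid pass is allowed, but ending it will fail with
    /// [`ComputePassErrorInner::InvalidParentEncoder`].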
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        let hub = &self.hub;

        let mut arc_desc = ArcComputePassDescriptor {
            label: desc.label.as_deref().map(Cow::Borrowed),
            timestamp_writes: None, // Filled in below, once the encoder is locked.
        };

        let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

        match cmd_buf.data.lock().lock_encoder() {
            Ok(_) => {}
            Err(e) => return make_err(e, arc_desc),
        };

        arc_desc.timestamp_writes = match desc
            .timestamp_writes
            .as_ref()
            .map(|tw| {
                Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
            })
            .transpose()
        {
            Ok(ok) => ok,
            Err(e) => return make_err(e, arc_desc),
        };

        (ComputePass::new(Some(cmd_buf), arc_desc), None)
    }

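    /// Resolves the ID-based commands in `base` into `Arc`-based commands and
    /// ends the resulting pass on `encoder_id`. Only available with the
    /// `serde`/`replay` features; used when replaying recorded traces.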
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let pass_scope = PassErrorScope::Pass;

        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner().map_pass_err(pass_scope)?;

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            return Err(ComputePassError {
                scope: pass_scope,
                inner: err.into(),
            });
        };

        compute_pass.base = Some(BasePass {
            label,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)?,
            dynamic_offsets,
            string_data,
            push_constant_data,
        });

        self.compute_pass_end(&mut compute_pass)
    }

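    /// Replays the pass's recorded commands into its parent command encoder.
    /// This takes the pass's `base`, so a pass cannot be ended twice.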
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass;

        let cmd_buf = pass
            .parent
            .as_ref()
            .ok_or(ComputePassErrorInner::InvalidParentEncoder)
            .map_pass_err(pass_scope)?;

        let base = pass
            .base
            .take()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(pass_scope)?;

        let device = &cmd_buf.device;
        device.check_is_valid().map_pass_err(pass_scope)?;

        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        let encoder = &mut cmd_buf_data.encoder;

        // Since a barrier-fixup command buffer is inserted _before_ the pass
        // later on, close whatever command buffer is currently open first.
        encoder.close_if_open().map_pass_err(pass_scope)?;
        let raw_encoder = encoder
            .open_pass(base.label.as_deref())
            .map_pass_err(pass_scope)?;

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,

            snatch_guard: device.snatchable_lock.read(),

            device,
            raw_encoder,
            tracker: &mut cmd_buf_data.trackers,
            buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
            texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
            as_actions: &mut cmd_buf_data.as_actions,

            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            string_offset: 0,
            active_query: None,

            push_constants: Vec::new(),

            intermediate_trackers: Tracker::new(),

            pending_discard_init_fixups: SurfacesInDiscardState::new(),
        };

        let indices = &state.device.tracker_indices;
        state.tracker.buffers.set_size(indices.buffers.size());
        state.tracker.textures.set_size(indices.textures.size());

        let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
            if let Some(tw) = pass.timestamp_writes.take() {
                tw.query_set
                    .same_device_as(cmd_buf.as_ref())
                    .map_pass_err(pass_scope)?;

                let query_set = state.tracker.query_sets.insert_single(tw.query_set);

                // Compute the range of queries this pass will write.
                let range = if let (Some(index_a), Some(index_b)) =
                    (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                {
                    Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                } else {
                    tw.beginning_of_pass_write_index
                        .or(tw.end_of_pass_write_index)
                        .map(|i| i..i + 1)
                };
                // Reset the queries before the pass writes to them.
                if let Some(range) = range {
                    unsafe {
                        state.raw_encoder.reset_queries(query_set.raw(), range);
                    }
                }

                Some(hal::PassTimestampWrites {
                    query_set: query_set.raw(),
                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                    end_of_pass_write_index: tw.end_of_pass_write_index,
                })
            } else {
                None
            };

        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label.as_deref(), device.instance_flags),
            timestamp_writes,
        };

        unsafe {
            state.raw_encoder.begin_compute_pass(&hal_desc);
        }

        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup;
                    set_bind_group(
                        &mut state,
                        cmd_buf,
                        &base.dynamic_offsets,
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let scope = PassErrorScope::SetPipelineCompute;
                    set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;
                    set_push_constant(
                        &mut state,
                        &base.push_constant_data,
                        offset,
                        size_bytes,
                        values_offset,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch { indirect: false };
                    dispatch(&mut state, groups).map_pass_err(scope)?;
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let scope = PassErrorScope::Dispatch { indirect: true };
                    dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    push_debug_group(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;
                    pop_debug_group(&mut state).map_pass_err(scope)?;
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    insert_debug_marker(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;
                    write_timestamp(&mut state, cmd_buf, query_set, query_index)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                    validate_and_begin_pipeline_statistics_query(
                        query_set,
                        state.raw_encoder,
                        &mut state.tracker.query_sets,
                        cmd_buf,
                        query_index,
                        None,
                        &mut state.active_query,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;
                    end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            state.raw_encoder.end_compute_pass();
        }

        let State {
            snatch_guard,
            tracker,
            intermediate_trackers,
            pending_discard_init_fixups,
            ..
        } = state;

        // Stop the current command buffer.
        encoder.close().map_pass_err(pass_scope)?;

        // Create a new command buffer, which will be inserted _before_ the pass
        // we just recorded, to fix up discarded surfaces and to emit the
        // barriers the pass requires.
        let transit = encoder
            .open_pass(Some("(wgpu internal) Pre Pass"))
            .map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        // Close the pre-pass command buffer and swap it in before the pass.
        encoder.close_and_swap().map_pass_err(pass_scope)?;
        cmd_buf_data_guard.mark_successful();

        Ok(())
    }
}

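/// Replays a `SetBindGroup` command: validates the bind group against the
/// device and its dynamic offsets, registers its memory-init actions, and (if
/// a pipeline layout is already known) binds it on the raw encoder.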
fn set_bind_group(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
) -> Result<(), ComputePassErrorInner> {
    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        });
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        // A `None` bind group is a no-op here beyond consuming its dynamic offsets.
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));

    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let used_resource = bind_group
        .used
        .acceleration_structures
        .into_iter()
        .map(|tlas| AsAction::UseTlas(tlas.clone()));

    state.as_actions.extend(used_resource);

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

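/// Replays a `SetPipeline` command: binds the pipeline on the raw encoder and,
/// if the pipeline layout changed, rebinds still-compatible bind groups and
/// zeroes the push constant ranges of the new layout.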
fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);

    unsafe {
        state.raw_encoder.set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources if the pipeline layout changed.
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(&pipeline.layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(&state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline.layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        // Resize the shadow copy of the push constant data to span the new
        // layout's compute-stage range (there is at most one range per stage).
        state.push_constants.clear();
        if let Some(push_constant_range) =
            pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                pcr.stages
                    .contains(wgt::ShaderStages::COMPUTE)
                    .then_some(pcr.range.clone())
            })
        {
            // The range is in bytes; the shadow copy stores whole u32 words.
            let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
            state.push_constants.extend(core::iter::repeat_n(0, len));
        }

        // Clear the push constant ranges on the encoder as well.
        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline.layout.raw(),
                    wgt::ShaderStages::COMPUTE,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

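/// Replays a `SetPushConstant` command. `offset` and `size_bytes` are in bytes
/// and must be 4-byte aligned; `values_offset` indexes into the pass's
/// `push_constant_data` in u32 words.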
fn set_push_constant(
    state: &mut State,
    push_constant_data: &[u32],
    offset: u32,
    size_bytes: u32,
    values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        // Push constants cannot be set before a pipeline is bound.
        .ok_or(ComputePassErrorInner::Dispatch(
            DispatchError::MissingPipeline,
        ))?;

    pipeline_layout.validate_push_constant_ranges(
        wgt::ShaderStages::COMPUTE,
        offset,
        end_offset_bytes,
    )?;

    let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);

    unsafe {
        state.raw_encoder.set_push_constants(
            pipeline_layout.raw(),
            wgt::ShaderStages::COMPUTE,
            offset,
            data_slice,
        );
    }
    Ok(())
}

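/// Replays a direct `Dispatch`: checks that a compatible pipeline is bound,
/// flushes pending state, and validates each dimension against
/// `max_compute_workgroups_per_dimension` before dispatching.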
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.raw_encoder.dispatch(groups);
    }
    Ok(())
}

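/// Replays a `DispatchIndirect`. The dispatch arguments are read from `buffer`
/// at `offset`, which must be 4-byte aligned and leave room for the three u32s
/// of `wgt::DispatchIndirectArgs`. When the device has indirect validation
/// enabled, an internal compute dispatch first writes validated arguments to
/// an internal buffer, and the real indirect dispatch reads from that buffer
/// instead of the user's.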
fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // three u32 workgroup counts: x, y, z
    state
        .buffer_memory_init_actions
        .extend(buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ));

    if let Some(ref indirect_validation) = state.device.indirect_validation {
        let params = indirect_validation
            .dispatch
            .params(&state.device.limits, offset, buffer.size);

        unsafe {
            state.raw_encoder.set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(&state.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier =
            src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
        unsafe {
            state.raw_encoder.transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: wgt::BufferUses::INDIRECT,
                    to: wgt::BufferUses::STORAGE_READ_WRITE,
                },
            }]);
        }

        unsafe {
            state.raw_encoder.dispatch([1, 1, 1]);
        }

        {
            // Restore the pipeline and bind state the user had set before the
            // validation dispatch clobbered them.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state.raw_encoder.set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: wgt::BufferUses::STORAGE_READ_WRITE,
                    to: wgt::BufferUses::INDIRECT,
                },
            }]);
        }

        state.flush_states(None)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(&state.snatch_guard)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> {
    if state.debug_scope_depth == 0 {
        return Err(ComputePassErrorInner::InvalidPopDebugGroup);
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe { state.raw_encoder.insert_debug_marker(label) }
    }
    state.string_offset += len;
}

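/// Replays a `WriteTimestamp` command; requires the
/// `TIMESTAMP_QUERY_INSIDE_PASSES` feature.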
fn write_timestamp(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    query_set: Arc<resource::QuerySet>,
    query_index: u32,
) -> Result<(), ComputePassErrorInner> {
    query_set.same_device_as(cmd_buf)?;

    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
    Ok(())
}

// Recording interface: these methods append commands to a live `ComputePass`;
// the state-setting methods early-out when the change is redundant.
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetBindGroup;
        // Can't use the `base_mut()` utility here because of the borrow checker.
        let base = pass
            .base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)?;

        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            let bg = hub
                .bind_groups
                .get(bind_group_id)
                .get()
                .map_pass_err(scope)?;
            bind_group = Some(bg);
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), ComputePassError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass.base_mut(scope)?;
        if redundant {
            // Take the redundant early-out only after checking the pass hasn't ended.
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = hub
            .compute_pipelines
            .get(pipeline_id)
            .get()
            .map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass.base_mut(scope)?;

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
        }
        let value_offset = base
            .push_constant_data
            .len()
            .try_into()
            .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
            .map_pass_err(scope)?;

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), ComputePassError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass.base_mut(scope)?;

        let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::EndPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}