use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        encoder::EncodingState,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass::{self, flush_bindings_helper},
        pass_base, pass_try,
        query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
        ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
        CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
        PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
        TimestampWritesError,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice, RawResourceAccess, Trackable,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

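/// A recorded compute pass.
///
/// Commands are buffered as `ArcComputeCommand`s in `base` while the pass is
/// open; validation and HAL encoding are deferred until the parent encoder
/// processes the pass (see `encode_compute_pass`).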
pub struct ComputePass {
    /// All recorded commands and associated pass data live here.
    base: ComputeBasePass,

    /// The parent command encoder; taken (set to `None`) when the pass ends.
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundant-state-change dedupe for bind groups and the pipeline.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataSizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4 GiB of immediate data per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        match inner {
            ComputePassErrorInner::Device(e) => e.webgpu_error_type(),
            ComputePassErrorInner::EncoderState(e) => e.webgpu_error_type(),
            ComputePassErrorInner::DebugGroupError(e) => e.webgpu_error_type(),
            ComputePassErrorInner::DestroyedResource(e) => e.webgpu_error_type(),
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingBufferUsage(e) => e.webgpu_error_type(),
            ComputePassErrorInner::Dispatch(e) => e.webgpu_error_type(),
            ComputePassErrorInner::Bind(e) => e.webgpu_error_type(),
            ComputePassErrorInner::ImmediateData(e) => e.webgpu_error_type(),
            ComputePassErrorInner::QueryUse(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingFeatures(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingDownlevelFlags(e) => e.webgpu_error_type(),
            ComputePassErrorInner::InvalidResource(e) => e.webgpu_error_type(),
            ComputePassErrorInner::TimestampWrites(e) => e.webgpu_error_type(),
            ComputePassErrorInner::InvalidValuesOffset(e) => e.webgpu_error_type(),

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataSizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => ErrorType::Validation,
        }
    }
}

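/// Mutable state carried through the encoding of a single compute pass.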
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    immediates: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

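    /// Flushes resource state ahead of a dispatch: merges the usages of all
    /// active bind groups (and the indirect buffer, if any) into the pass
    /// usage scope, moves them into `intermediate_trackers`, and records the
    /// resulting transition barriers on the raw encoder.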
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
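    /// Begins recording a compute pass on the given encoder.
    ///
    /// On failure, an invalid pass is returned alongside an optional
    /// immediately-reportable error; recording into an invalid pass is
    /// permitted, and the stored error is surfaced when the pass ends.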
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Another pass is still open on this encoder; invalidate the
                // encoder and return an invalid pass.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder was already finished; report the error
                // immediately as well as through the returned pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // The encoder is already invalid; the error is reported when
                // the pass ends.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

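    /// Ends the pass and hands its recorded commands to the parent encoder.
    ///
    /// Any error recorded while the pass was open is raised here or attached
    /// to the parent encoder.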
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // If the pass was invalidated by an encoder-state error (locked or
        // already ended), propagate that error directly instead of pushing it
        // onto the encoder.
        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = &base
        {
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

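/// Encodes a recorded compute pass into the parent encoder's HAL command
/// stream.
///
/// This is where the commands buffered by the `compute_pass_*` recording
/// methods are validated and translated into raw HAL commands, including
/// barrier generation and an internal "Pre Pass" that fixes up discarded
/// surfaces.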
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state
                    .indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(
            device.ordered_buffer_usages,
            device.ordered_texture_usages,
        ),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements =
                            (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass:
            pass::PassState {
                pending_discard_init_fixups,
                ..
            },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

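/// Sets the current compute pipeline on the raw encoder.
///
/// Also updates the binder for the new pipeline layout and resizes the
/// CPU-side shadow of immediate data (`State::immediates`) to match the
/// layout's `immediate_size`.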
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            state.immediates.clear();
            if pipeline.layout.immediate_size != 0 {
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

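/// Validates and encodes a direct dispatch: a pipeline must be bound, bindings
/// are flushed, and each workgroup-count dimension must not exceed
/// `max_compute_workgroups_per_dimension`.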
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups.iter().copied().any(|g| g > groups_size_limit) {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

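/// Validates and encodes an indirect dispatch.
///
/// When the device performs indirect validation, an internal dispatch first
/// copies and validates the user's dispatch arguments into a device-owned
/// buffer, and the actual `dispatch_indirect` reads from that buffer;
/// otherwise the user's buffer is used directly.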
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if !offset.is_multiple_of(4) {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    let stride = 3 * 4; // three u32s: the x, y, and z workgroup counts
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        // Indirect validation is enabled: run an internal dispatch that
        // copies and validates the user's dispatch arguments into
        // `params.dst_buffer`.
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                params.dst_bind_group,
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                buffer
                    .indirect_validation_bind_groups
                    .get(state.pass.base.snatch_guard)
                    .unwrap()
                    .dispatch
                    .as_ref(),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the state of the pass: the user's pipeline, immediate data,
        // and bind groups.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, group, dynamic_offsets) in state.pass.binder.list_valid() {
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        raw_bg,
                        dynamic_offsets,
                    );
                }
            }
        }

        // Transition the validated arguments back to INDIRECT for the real
        // dispatch.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

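    /// Records an upload of immediate data for the pass.
    ///
    /// Both `offset` and `data.len()` must be multiples of
    /// `wgt::IMMEDIATE_DATA_ALIGNMENT` (4 bytes).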
    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}