use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};

use crate::command::{pass, EncoderStateError, PassStateError, TimestampWritesError};
use crate::resource::DestroyedResourceError;
use crate::{binding_model::BindError, resource::RawResourceAccess};
use crate::{
    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr,
        PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
    Label,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

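/// A recorded, in-progress compute pass.
///
/// Commands are buffered in [`ComputeBasePass`] and are only validated and
/// encoded into the parent command encoder when the pass is ended (see
/// [`Global::compute_pass_end`]).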
pub struct ComputePass {
    /// The recorded commands, the pass label, and any error deferred during
    /// recording.
    base: ComputeBasePass,

    /// The command buffer this pass encodes into. `None` once the pass has
    /// been ended.
    parent: Option<Arc<CommandBuffer>>,

    /// Timestamp writes to perform, consumed when the pass is ended.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // State-change trackers used to skip redundant bind group and pipeline sets.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Arc<CommandBuffer>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    /// Creates a pass that records nothing and defers `err` until the pass is
    /// ended.
    fn new_invalid(parent: Arc<CommandBuffer>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

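/// Describes a compute pass: an optional label plus optional timestamp writes.
///
/// `PTW` is the timestamp-writes type: the id-based [`PassTimestampWrites`] at
/// the public API boundary, resolved to [`ArcPassTimestampWrites`] internally.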
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

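/// Error encountered when a dispatch cannot be issued in the pass's current
/// state.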
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

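/// Error encountered while recording or executing a compute pass.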
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    InvalidPopDebugGroup(#[from] pass::InvalidPopDebugGroup),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

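/// Error encountered when performing a compute pass, paired with the scope of
/// the command that raised it.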
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,
            ComputePassErrorInner::InvalidPopDebugGroup(e) => e,

            // The remaining variants are always validation errors.
            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

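/// Mutable state tracked while the commands of a single compute pass are
/// validated and encoded.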
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    /// The currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State shared with other pass types.
    general: pass::BaseState<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>,

    /// The pipeline statistics query currently in flight, if any.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side shadow of the push constant data, replayed after an
    /// indirect-validation dispatch clobbers the real push constants.
    push_constants: Vec<u32>,

    /// Tracker used to generate the barriers needed between dispatches.
    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    /// Checks that a pipeline is bound and is compatible with the currently
    /// bound bind groups.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.general.binder.check_compatibility(pipeline.as_ref())?;
            self.general.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the resource state of the active bind groups (plus the indirect
    /// buffer, if any) into the usage scope and emits the resulting barriers.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        // Merge every bind group's state into the usage scope first, so that
        // conflicting usages within a single dispatch are detected.
        for bind_group in self.general.binder.list_active() {
            unsafe { self.general.scope.merge_bind_group(&bind_group.used)? };
        }

        for bind_group in self.general.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(
                        &mut self.general.scope,
                        &bind_group.used,
                    )
            }
        }

        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.general.scope.buffers,
                    indirect_buffer,
                );
        }

        CommandBuffer::drain_barriers(
            self.general.raw_encoder,
            &mut self.intermediate_trackers,
            self.general.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
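    /// Creates a compute pass on `encoder_id`.
    ///
    /// If creation fails, an invalid pass is returned alongside any
    /// immediately reportable error; recording on an invalid pass is a no-op,
    /// and the stored error resurfaces when the pass is ended.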
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let mut cmd_buf_data = cmd_buf.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_buf.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_buf.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_buf, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a pass while the encoder is locked
                // invalidates the encoder as well as the pass.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder is finished: report the error immediately in
                // addition to storing it in the returned pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // The encoder is already invalid; defer the error to the end
                // of the pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

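    /// Replays a recorded pass whose commands still reference ids rather than
    /// resolved resources. Only used by trace serialization/replay; panics on
    /// encoder errors instead of reporting them.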
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        }

        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

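    /// Ends the pass, validating and encoding its recorded commands into the
    /// parent command encoder.
    ///
    /// Any error deferred during recording is applied to the encoder here.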
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        let pass_scope = PassErrorScope::Pass;
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking the parent also marks the pass as ended.
        let cmd_buf = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_buf.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // The encoder was already finished when the pass began;
                // report that immediately.
                return Err(EncoderStateError::Ended);
            } else {
                // Errors deferred during recording invalidate the encoder and
                // are reported when the encoder is finished.
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            let device = &cmd_buf.device;
            device.check_is_valid().map_pass_err(pass_scope)?;

            let base = &mut pass.base;

            let encoder = &mut cmd_buf_data.encoder;

            encoder.close_if_open().map_pass_err(pass_scope)?;
            let raw_encoder = encoder
                .open_pass(base.label.as_deref())
                .map_pass_err(pass_scope)?;

            let snatch_guard = device.snatchable_lock.read();

            let mut state = State {
                pipeline: None,

                general: pass::BaseState {
                    device,
                    raw_encoder,
                    tracker: &mut cmd_buf_data.trackers,
                    buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                    texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                    as_actions: &mut cmd_buf_data.as_actions,
                    binder: Binder::new(),
                    temp_offsets: Vec::new(),
                    dynamic_offset_count: 0,

                    pending_discard_init_fixups: SurfacesInDiscardState::new(),

                    snatch_guard: &snatch_guard,
                    scope: device.new_usage_scope(),

                    debug_scope_depth: 0,
                    string_offset: 0,
                },
                active_query: None,

                push_constants: Vec::new(),

                intermediate_trackers: Tracker::new(),
            };

            let indices = &state.general.device.tracker_indices;
            state
                .general
                .tracker
                .buffers
                .set_size(indices.buffers.size());
            state
                .general
                .tracker
                .textures
                .set_size(indices.textures.size());
            let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
                if let Some(tw) = pass.timestamp_writes.take() {
                    tw.query_set
                        .same_device_as(cmd_buf.as_ref())
                        .map_pass_err(pass_scope)?;

                    let query_set = state.general.tracker.query_sets.insert_single(tw.query_set);

                    // Queries must be reset before they are written; compute
                    // the smallest range covering both write indices, if any.
                    let range = if let (Some(index_a), Some(index_b)) =
                        (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                    {
                        Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                    } else {
                        tw.beginning_of_pass_write_index
                            .or(tw.end_of_pass_write_index)
                            .map(|i| i..i + 1)
                    };
                    if let Some(range) = range {
                        unsafe {
                            state
                                .general
                                .raw_encoder
                                .reset_queries(query_set.raw(), range);
                        }
                    }

                    Some(hal::PassTimestampWrites {
                        query_set: query_set.raw(),
                        beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                        end_of_pass_write_index: tw.end_of_pass_write_index,
                    })
                } else {
                    None
                };

            let hal_desc = hal::ComputePassDescriptor {
                label: hal_label(base.label.as_deref(), device.instance_flags),
                timestamp_writes,
            };

            unsafe {
                state.general.raw_encoder.begin_compute_pass(&hal_desc);
            }

            for command in base.commands.drain(..) {
                match command {
                    ArcComputeCommand::SetBindGroup {
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    } => {
                        let scope = PassErrorScope::SetBindGroup;
                        pass::set_bind_group::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_buf.as_ref(),
                            &base.dynamic_offsets,
                            index,
                            num_dynamic_offsets,
                            bind_group,
                            false,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPipeline(pipeline) => {
                        let scope = PassErrorScope::SetPipelineCompute;
                        set_pipeline(&mut state, cmd_buf.as_ref(), pipeline).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPushConstant {
                        offset,
                        size_bytes,
                        values_offset,
                    } => {
                        let scope = PassErrorScope::SetPushConstant;
                        pass::set_push_constant::<ComputePassErrorInner, _>(
                            &mut state.general,
                            &base.push_constant_data,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            size_bytes,
                            Some(values_offset),
                            |data_slice| {
                                let offset_in_elements =
                                    (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                let size_in_elements =
                                    (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                state.push_constants[offset_in_elements..][..size_in_elements]
                                    .copy_from_slice(data_slice);
                            },
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::Dispatch(groups) => {
                        let scope = PassErrorScope::Dispatch { indirect: false };
                        dispatch(&mut state, groups).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                        let scope = PassErrorScope::Dispatch { indirect: true };
                        dispatch_indirect(&mut state, cmd_buf.as_ref(), buffer, offset)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::PushDebugGroup { color: _, len } => {
                        pass::push_debug_group(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::PopDebugGroup => {
                        let scope = PassErrorScope::PopDebugGroup;
                        pass::pop_debug_group::<ComputePassErrorInner>(&mut state.general)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                        pass::insert_debug_marker(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::WriteTimestamp {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::WriteTimestamp;
                        pass::write_timestamp::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_buf.as_ref(),
                            None,
                            query_set,
                            query_index,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::BeginPipelineStatisticsQuery {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                        validate_and_begin_pipeline_statistics_query(
                            query_set,
                            state.general.raw_encoder,
                            &mut state.general.tracker.query_sets,
                            cmd_buf.as_ref(),
                            query_index,
                            None,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::EndPipelineStatisticsQuery => {
                        let scope = PassErrorScope::EndPipelineStatisticsQuery;
                        end_pipeline_statistics_query(
                            state.general.raw_encoder,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                }
            }

            unsafe {
                state.general.raw_encoder.end_compute_pass();
            }

            let State {
                general:
                    pass::BaseState {
                        tracker,
                        pending_discard_init_fixups,
                        ..
                    },
                intermediate_trackers,
                ..
            } = state;

            // Close the pass command buffer...
            encoder.close().map_pass_err(pass_scope)?;

            // ...then record texture init fixups and resource barriers into a
            // separate "pre pass" command buffer, swapped in so that it
            // executes before the pass itself.
            let transit = encoder
                .open_pass(Some("(wgpu internal) Pre Pass"))
                .map_pass_err(pass_scope)?;
            fixup_discarded_surfaces(
                pending_discard_init_fixups.into_iter(),
                transit,
                &mut tracker.textures,
                device,
                &snatch_guard,
            );
            CommandBuffer::insert_barriers_from_tracker(
                transit,
                tracker,
                &intermediate_trackers,
                &snatch_guard,
            );
            encoder.close_and_swap().map_pass_err(pass_scope)?;

            Ok(())
        })
    }
}

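/// Binds `pipeline` and rebinds any resources (bind groups, push constants)
/// affected by the layout change.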
fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .general
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .general
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.general,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // Resize the push constant shadow to match the new pipeline's
            // COMPUTE range, zero-filled.
            state.push_constants.clear();
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

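/// Validates and encodes a direct dispatch of `groups` workgroups.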
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state
        .general
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.general.raw_encoder.dispatch(groups);
    }
    Ok(())
}

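/// Validates and encodes an indirect dispatch reading its workgroup counts
/// from `buffer` at `offset`.
///
/// When the device has indirect validation enabled, an internal dispatch
/// first runs a validation shader over the arguments, and the real dispatch
/// then reads the validated arguments from an internal buffer instead.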
fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .general
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.general.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // three u32 workgroup counts: x, y, z
    state.general.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.general.device.indirect_validation {
        // Validate the indirect arguments on the GPU: a small internal
        // dispatch writes the checked arguments into `params.dst_buffer`,
        // which the real dispatch then reads from.
        let params =
            indirect_validation
                .dispatch
                .params(&state.general.device.limits, offset, buffer.size);

        unsafe {
            state
                .general
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.general.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.general.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer for reading by the validation shader,
        // and the destination buffer for writing.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.general.snatch_guard));
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader once.
        unsafe {
            state.general.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the pipeline, push constants, and bind groups that the
        // validation dispatch clobbered.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .general
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.general.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.general.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.general.snatch_guard)?;
                unsafe {
                    state.general.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Transition the validated arguments back for indirect consumption.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        unsafe {
            state
                .general
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No indirect validation: dispatch straight from the user's buffer.
        state
            .general
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
        unsafe {
            state.general.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

impl Global {
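    /// Records a bind group set at `index` on the pass.
    ///
    /// `bind_group_id` may be `None` to clear the slot. Redundant sets (the
    /// same group with the same dynamic offsets) are skipped.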
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}