1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{binding_model::BindError, resource::RawResourceAccess};
11use crate::{
12 binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
13 command::{
14 bind::{Binder, BinderError},
15 compute_command::ArcComputeCommand,
16 end_pipeline_statistics_query,
17 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18 pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
19 BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
20 PassTimestampWrites, QueryUseError, StateChange,
21 },
22 device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
23 global::Global,
24 hal_label, id,
25 init_tracker::MemoryInitKind,
26 pipeline::ComputePipeline,
27 resource::{
28 self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
29 },
30 track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
31 Label,
32};
33use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
34use crate::{
35 command::{
36 encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
37 EncoderStateError, PassStateError, TimestampWritesError,
38 },
39 device::Device,
40};
41
/// The recorded contents of a compute pass: resolved [`ArcComputeCommand`]s
/// plus any [`ComputePassError`] captured during encoding.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
43
/// A compute pass in the process of being recorded.
///
/// Created by [`Global::command_encoder_begin_compute_pass`] and consumed by
/// [`Global::compute_pass_end`].
pub struct ComputePass {
    // Commands recorded so far, or the first error encountered while recording.
    base: ComputeBasePass,

    // The command encoder this pass was begun on. `Some` while the pass is
    // open; taken (left `None`) when the pass is ended.
    parent: Option<Arc<CommandEncoder>>,

    // Timestamp writes requested by the pass descriptor, with the query set
    // resolved to an `Arc`.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundancy tracking so repeated identical `set_bind_group` /
    // `set_pipeline` calls can be skipped without recording a command.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
68
69impl ComputePass {
70 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
72 let ArcComputePassDescriptor {
73 label,
74 timestamp_writes,
75 } = desc;
76
77 Self {
78 base: BasePass::new(&label),
79 parent: Some(parent),
80 timestamp_writes,
81
82 current_bind_groups: BindGroupStateChange::new(),
83 current_pipeline: StateChange::new(),
84 }
85 }
86
87 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
88 Self {
89 base: BasePass::new_invalid(label, err),
90 parent: Some(parent),
91 timestamp_writes: None,
92 current_bind_groups: BindGroupStateChange::new(),
93 current_pipeline: StateChange::new(),
94 }
95 }
96
97 #[inline]
98 pub fn label(&self) -> Option<&str> {
99 self.base.label.as_deref()
100 }
101}
102
103impl fmt::Debug for ComputePass {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match self.parent {
106 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
107 None => write!(f, "ComputePass {{ parent: None }}"),
108 }
109 }
110}
111
/// Descriptor for a compute pass.
///
/// `PTW` selects the timestamp-writes representation: the id-based
/// [`PassTimestampWrites`] at the public boundary, or the `Arc`-resolved form
/// internally.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
118
/// [`ComputePassDescriptor`] with its timestamp-writes query set resolved to an `Arc`.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
121
/// Errors that can occur while validating a dispatch (direct or indirect).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    /// No compute pipeline was bound before dispatching.
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The bound bind groups do not satisfy the pipeline's layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    /// A workgroup-count dimension exceeds the device limit
    /// (`max_compute_workgroups_per_dimension`).
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// A late-sized buffer binding is smaller than the pipeline requires.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
136
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every dispatch error is a validation failure; none originate from
        // device loss or OOM.
        ErrorType::Validation
    }
}
142
/// The underlying cause of a [`ComputePassError`].
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
198
/// Error encountered while recording or ending a compute pass, tagged with
/// the command scope in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// The command or operation that was being processed when the error arose.
    pub scope: PassErrorScope,
    /// The underlying cause.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
208
209impl From<pass::MissingPipeline> for ComputePassErrorInner {
210 fn from(value: pass::MissingPipeline) -> Self {
211 Self::Dispatch(DispatchError::MissingPipeline(value))
212 }
213}
214
/// Attaches a [`PassErrorScope`] to any error convertible into
/// [`ComputePassErrorInner`], producing a complete [`ComputePassError`].
impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}
226
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        // Delegate to the inner error where it carries its own classification;
        // the remaining variants are plain validation errors.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
259
/// Mutable state threaded through [`encode_compute_pass`] while replaying the
/// recorded commands against the HAL encoder.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    // The currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    // Encoding state shared with other pass kinds: binder, usage scope,
    // raw encoder, memory-init bookkeeping, etc.
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    // The pipeline-statistics query currently in flight, with its index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    // Shadow copy of the push-constant words set for the bound pipeline.
    // Re-applied after indirect-dispatch validation overwrites the hardware
    // push constants (see `dispatch_indirect`).
    push_constants: Vec<u32>,

    // Trackers into which bind-group (and indirect-buffer) usages are moved
    // so barriers can be computed and recorded (see `State::flush_states`).
    intermediate_trackers: Tracker,
}
271
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch may be issued: a pipeline must be bound, and the
    /// bound bind groups must be compatible with its layout and satisfy its
    /// late-sized buffer bindings.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the usages of all active bind groups (plus, for indirect
    /// dispatches, the indirect buffer) into the pass usage scope, moves them
    /// into `intermediate_trackers`, and records the resulting barriers into
    /// the raw encoder.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        // First merge every active bind group so incompatible usages are
        // detected (and reported) before any tracker state is committed.
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Now that no error can occur, move the merged usages out of the
        // scope and into the intermediate trackers.
        for bind_group in self.pass.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used)
            }
        }

        // The indirect buffer (if any) was merged into the scope by the
        // caller; transfer it the same way.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer,
                );
        }

        // Emit any transitions the trackers accumulated into the encoder.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
320
impl Global {
    /// Begins a compute pass on `encoder_id`.
    ///
    /// Always returns a `ComputePass`. If anything goes wrong, the error is
    /// usually stored inside the returned (invalid) pass and surfaces when
    /// the pass is ended; the second tuple element is `Some` only when the
    /// encoder was already ended or submitted, in which case the error is
    /// also reported immediately.
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        // Lock the encoder so no other pass or command can be recorded on it
        // until this pass ends.
        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve the timestamp-writes query set (if any) from ids to
                // Arcs, validating it against the device.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Another pass is already open on this encoder; the encoder
                // itself becomes invalid.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder can no longer record; report immediately as
                // well as through the invalid pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Replay helper: begins a pass, resolves id-based commands to Arcs,
    /// and immediately ends the pass. Panics if the encoder itself is in an
    /// error state. With the `trace` feature, also records the pass into the
    /// encoder's trace command list.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.trace_commands {
                list.push(crate::command::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Swap in the resolved command list; the pass created above starts
        // out empty.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

    /// Ends the compute pass, unlocking the parent encoder and appending a
    /// `RunComputePass` command (or the recorded error) to it.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` makes ending twice report `Ended`.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // If the pass itself was invalidated by an end-after-end, report that
        // as an encoder-state error directly rather than deferring it.
        if matches!(
            base,
            Err(ComputePassError {
                inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                scope: _,
            })
        ) {
            return Err(EncoderStateError::Ended);
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
528
/// Replays a recorded compute pass against the parent encoder's HAL encoder.
///
/// Opens a dedicated pass encoder, executes every command with validation and
/// state tracking, then closes the pass and inserts a "Pre Pass" transit
/// encoder (swapped in before the pass) carrying discarded-surface fixups and
/// the barriers accumulated by the pass.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Close whatever raw encoder is currently open and open a fresh one for
    // this pass, labeled with the pass label.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Size the buffer/texture trackers up front from the device's index
    // allocators.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Resolve timestamp writes into the HAL form, resetting the queries that
    // will be written so stale results cannot leak through.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Reset the smallest contiguous range covering both write
            // indices (or the single one that is present).
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay each recorded command, attaching its scope to any error.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    // Mirror the written words into the shadow copy so they
                    // can be re-applied after indirect-dispatch validation.
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every PushDebugGroup must have been matched by a PopDebugGroup.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Open a transit encoder that will run *before* the pass (close_and_swap
    // below): it initializes discarded surfaces and performs the transitions
    // the pass requires.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
804
/// Binds `pipeline` as the current compute pipeline and rebinds any bind
/// groups / push constants the new layout requires.
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        // Resize the push-constant shadow copy (zero-filled) to match the
        // COMPUTE-stage range of the new layout, if it declares one.
        || {
            state.push_constants.clear();
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Range is in bytes; the shadow copy stores u32 words.
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
854
855fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
856 state.is_ready()?;
857
858 state.flush_states(None)?;
859
860 let groups_size_limit = state
861 .pass
862 .base
863 .device
864 .limits
865 .max_compute_workgroups_per_dimension;
866
867 if groups[0] > groups_size_limit
868 || groups[1] > groups_size_limit
869 || groups[2] > groups_size_limit
870 {
871 return Err(ComputePassErrorInner::Dispatch(
872 DispatchError::InvalidGroupSize {
873 current: groups,
874 limit: groups_size_limit,
875 },
876 ));
877 }
878
879 unsafe {
880 state.pass.base.raw_encoder.dispatch(groups);
881 }
882 Ok(())
883}
884
/// Validates and records an indirect dispatch reading its workgroup counts
/// from `buffer` at `offset`.
///
/// When the device has indirect validation enabled, the dispatch arguments
/// are first clamped on the GPU by a validation pipeline into a scratch
/// buffer, and the dispatch reads from that buffer instead; the user's
/// pipeline, push constants, and bind groups are restored afterwards.
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // The indirect args must be 4-byte aligned and fully inside the buffer.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    // The dispatch reads three u32 args; that range must be initialized.
    let stride = 3 * 4;
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        // Run the validation pipeline: it reads the user's args and writes
        // (clamped) args into the validation scratch buffer.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer so the validation shader can read it.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Scratch buffer: INDIRECT -> writable for the validation shader.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Single workgroup copies/clamps the three dispatch args.
        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the user's pipeline, push constants, and bind groups,
        // which the validation dispatch just clobbered.
        {
            // Infallible: `is_ready()` above guarantees a pipeline is bound.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Scratch buffer back to INDIRECT for the real dispatch.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        // Dispatch from the validated copy (always at offset 0).
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation: dispatch straight from the user's buffer.
        state
            .pass
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1087
1088impl Global {
1101 pub fn compute_pass_set_bind_group(
1102 &self,
1103 pass: &mut ComputePass,
1104 index: u32,
1105 bind_group_id: Option<id::BindGroupId>,
1106 offsets: &[DynamicOffset],
1107 ) -> Result<(), PassStateError> {
1108 let scope = PassErrorScope::SetBindGroup;
1109
1110 let base = pass_base!(pass, scope);
1114
1115 if pass.current_bind_groups.set_and_check_redundant(
1116 bind_group_id,
1117 index,
1118 &mut base.dynamic_offsets,
1119 offsets,
1120 ) {
1121 return Ok(());
1122 }
1123
1124 let mut bind_group = None;
1125 if bind_group_id.is_some() {
1126 let bind_group_id = bind_group_id.unwrap();
1127
1128 let hub = &self.hub;
1129 bind_group = Some(pass_try!(
1130 base,
1131 scope,
1132 hub.bind_groups.get(bind_group_id).get(),
1133 ));
1134 }
1135
1136 base.commands.push(ArcComputeCommand::SetBindGroup {
1137 index,
1138 num_dynamic_offsets: offsets.len(),
1139 bind_group,
1140 });
1141
1142 Ok(())
1143 }
1144
1145 pub fn compute_pass_set_pipeline(
1146 &self,
1147 pass: &mut ComputePass,
1148 pipeline_id: id::ComputePipelineId,
1149 ) -> Result<(), PassStateError> {
1150 let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1151
1152 let scope = PassErrorScope::SetPipelineCompute;
1153
1154 let base = pass_base!(pass, scope);
1157
1158 if redundant {
1159 return Ok(());
1160 }
1161
1162 let hub = &self.hub;
1163 let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
1164
1165 base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
1166
1167 Ok(())
1168 }
1169
1170 pub fn compute_pass_set_push_constants(
1171 &self,
1172 pass: &mut ComputePass,
1173 offset: u32,
1174 data: &[u8],
1175 ) -> Result<(), PassStateError> {
1176 let scope = PassErrorScope::SetPushConstant;
1177 let base = pass_base!(pass, scope);
1178
1179 if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1180 pass_try!(
1181 base,
1182 scope,
1183 Err(ComputePassErrorInner::PushConstantOffsetAlignment),
1184 );
1185 }
1186
1187 if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1188 pass_try!(
1189 base,
1190 scope,
1191 Err(ComputePassErrorInner::PushConstantSizeAlignment),
1192 )
1193 }
1194 let value_offset = pass_try!(
1195 base,
1196 scope,
1197 base.push_constant_data
1198 .len()
1199 .try_into()
1200 .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
1201 );
1202
1203 base.push_constant_data.extend(
1204 data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
1205 .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1206 );
1207
1208 base.commands.push(ArcComputeCommand::SetPushConstant {
1209 offset,
1210 size_bytes: data.len() as u32,
1211 values_offset: value_offset,
1212 });
1213
1214 Ok(())
1215 }
1216
1217 pub fn compute_pass_dispatch_workgroups(
1218 &self,
1219 pass: &mut ComputePass,
1220 groups_x: u32,
1221 groups_y: u32,
1222 groups_z: u32,
1223 ) -> Result<(), PassStateError> {
1224 let scope = PassErrorScope::Dispatch { indirect: false };
1225
1226 pass_base!(pass, scope)
1227 .commands
1228 .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
1229
1230 Ok(())
1231 }
1232
1233 pub fn compute_pass_dispatch_workgroups_indirect(
1234 &self,
1235 pass: &mut ComputePass,
1236 buffer_id: id::BufferId,
1237 offset: BufferAddress,
1238 ) -> Result<(), PassStateError> {
1239 let hub = &self.hub;
1240 let scope = PassErrorScope::Dispatch { indirect: true };
1241 let base = pass_base!(pass, scope);
1242
1243 let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1244
1245 base.commands
1246 .push(ArcComputeCommand::DispatchIndirect { buffer, offset });
1247
1248 Ok(())
1249 }
1250
1251 pub fn compute_pass_push_debug_group(
1252 &self,
1253 pass: &mut ComputePass,
1254 label: &str,
1255 color: u32,
1256 ) -> Result<(), PassStateError> {
1257 let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1258
1259 let bytes = label.as_bytes();
1260 base.string_data.extend_from_slice(bytes);
1261
1262 base.commands.push(ArcComputeCommand::PushDebugGroup {
1263 color,
1264 len: bytes.len(),
1265 });
1266
1267 Ok(())
1268 }
1269
1270 pub fn compute_pass_pop_debug_group(
1271 &self,
1272 pass: &mut ComputePass,
1273 ) -> Result<(), PassStateError> {
1274 let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1275
1276 base.commands.push(ArcComputeCommand::PopDebugGroup);
1277
1278 Ok(())
1279 }
1280
1281 pub fn compute_pass_insert_debug_marker(
1282 &self,
1283 pass: &mut ComputePass,
1284 label: &str,
1285 color: u32,
1286 ) -> Result<(), PassStateError> {
1287 let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1288
1289 let bytes = label.as_bytes();
1290 base.string_data.extend_from_slice(bytes);
1291
1292 base.commands.push(ArcComputeCommand::InsertDebugMarker {
1293 color,
1294 len: bytes.len(),
1295 });
1296
1297 Ok(())
1298 }
1299
1300 pub fn compute_pass_write_timestamp(
1301 &self,
1302 pass: &mut ComputePass,
1303 query_set_id: id::QuerySetId,
1304 query_index: u32,
1305 ) -> Result<(), PassStateError> {
1306 let scope = PassErrorScope::WriteTimestamp;
1307 let base = pass_base!(pass, scope);
1308
1309 let hub = &self.hub;
1310 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1311
1312 base.commands.push(ArcComputeCommand::WriteTimestamp {
1313 query_set,
1314 query_index,
1315 });
1316
1317 Ok(())
1318 }
1319
1320 pub fn compute_pass_begin_pipeline_statistics_query(
1321 &self,
1322 pass: &mut ComputePass,
1323 query_set_id: id::QuerySetId,
1324 query_index: u32,
1325 ) -> Result<(), PassStateError> {
1326 let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1327 let base = pass_base!(pass, scope);
1328
1329 let hub = &self.hub;
1330 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1331
1332 base.commands
1333 .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1334 query_set,
1335 query_index,
1336 });
1337
1338 Ok(())
1339 }
1340
1341 pub fn compute_pass_end_pipeline_statistics_query(
1342 &self,
1343 pass: &mut ComputePass,
1344 ) -> Result<(), PassStateError> {
1345 pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1346 .commands
1347 .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1348
1349 Ok(())
1350 }
1351}