wgpu_core/command/
compute.rs

1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{binding_model::BindError, resource::RawResourceAccess};
11use crate::{
12    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
13    command::{
14        bind::{Binder, BinderError},
15        compute_command::ArcComputeCommand,
16        end_pipeline_statistics_query,
17        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
19        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
20        PassTimestampWrites, QueryUseError, StateChange,
21    },
22    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
23    global::Global,
24    hal_label, id,
25    init_tracker::MemoryInitKind,
26    pipeline::ComputePipeline,
27    resource::{
28        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
29    },
30    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
31    Label,
32};
33use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
34use crate::{
35    command::{
36        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
37        EncoderStateError, PassStateError, TimestampWritesError,
38    },
39    device::Device,
40};
41
/// The [`BasePass`] specialization used by compute passes: stores resolved
/// [`ArcComputeCommand`]s plus any [`ComputePassError`] recorded during
/// encoding.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
43
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records is stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested for this pass, already resolved to
    /// `Arc`-backed query sets; moved out when the pass is ended
    /// (see `Global::compute_pass_end`).
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state, used to elide redundant bind-group and
    // pipeline sets. NOTE(review): the dedupe logic itself lives in the
    // recording methods, outside this chunk — confirm there.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
68
69impl ComputePass {
70    /// If the parent command encoder is invalid, the returned pass will be invalid.
71    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
72        let ArcComputePassDescriptor {
73            label,
74            timestamp_writes,
75        } = desc;
76
77        Self {
78            base: BasePass::new(&label),
79            parent: Some(parent),
80            timestamp_writes,
81
82            current_bind_groups: BindGroupStateChange::new(),
83            current_pipeline: StateChange::new(),
84        }
85    }
86
87    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
88        Self {
89            base: BasePass::new_invalid(label, err),
90            parent: Some(parent),
91            timestamp_writes: None,
92            current_bind_groups: BindGroupStateChange::new(),
93            current_pipeline: StateChange::new(),
94        }
95    }
96
97    #[inline]
98    pub fn label(&self) -> Option<&str> {
99        self.base.label.as_deref()
100    }
101}
102
103impl fmt::Debug for ComputePass {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match self.parent {
106            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
107            None => write!(f, "ComputePass {{ parent: None }}"),
108        }
109    }
110}
111
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Optional debug label for the pass.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
118
/// [`ComputePassDescriptor`] whose timestamp writes have been resolved to
/// `Arc`-backed query sets.
///
/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
121
/// Validation errors raised by `dispatch` / `dispatch_indirect` in a compute
/// pass.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
136
impl WebGpuError for DispatchError {
    // Every `DispatchError` variant is a validation failure.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
142
/// Error encountered when performing a compute pass.
///
/// Wrapped in a [`ComputePassError`] together with the [`PassErrorScope`]
/// identifying the command that triggered it.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
198
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// The command/scope during which the error occurred.
    pub scope: PassErrorScope,
    /// The underlying cause of the failure.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
208
209impl From<pass::MissingPipeline> for ComputePassErrorInner {
210    fn from(value: pass::MissingPipeline) -> Self {
211        Self::Dispatch(DispatchError::MissingPipeline(value))
212    }
213}
214
215impl<E> MapPassErr<ComputePassError> for E
216where
217    E: Into<ComputePassErrorInner>,
218{
219    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
220        ComputePassError {
221            scope,
222            inner: self.into(),
223        }
224    }
225}
226
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        // Variants wrapping another error delegate to that error's type; the
        // remaining variants are plain validation errors. The match is kept
        // exhaustive (no `_` arm) so adding a variant forces a decision here.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
259
/// Mutable state carried across the commands of a single compute pass while
/// it is being encoded into HAL commands (see `encode_compute_pass`).
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The currently bound compute pipeline, if one has been set.
    pipeline: Option<Arc<ComputePipeline>>,

    /// Pass state shared with other pass kinds: binder, usage scope, raw
    /// encoder, memory-init bookkeeping, etc.
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// The pipeline-statistics query that has begun but not yet ended,
    /// together with its query index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// Shadow copy of the pass's push-constant data; resized/cleared when the
    /// pipeline changes and updated by `SetPushConstant`.
    push_constants: Vec<u32>,

    /// Trackers accumulated while encoding the pass body; afterwards used to
    /// generate the barriers inserted *before* the pass.
    intermediate_trackers: Tracker,
}
271
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch may proceed: a pipeline is set, the bound bind
    /// groups are compatible with its layout, and late-sized buffer bindings
    /// satisfy the pipeline's minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the resource usage of all active bind groups (and, optionally,
    /// an indirect buffer) into the usage scope, moves the affected states
    /// into `intermediate_trackers`, and records the resulting transition
    /// barriers into the raw encoder.
    ///
    /// `indirect_buffer` represents the indirect dispatch buffer, which is
    /// also part of the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.pass.binder.list_active() {
            // SAFETY: relies on `merge_bind_group`'s contract.
            // NOTE(review): presumably the bind group's usage was validated
            // when it was created — confirm against the tracker docs.
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.pass.binder.list_active() {
            // SAFETY: the same bind groups were merged into `self.pass.scope`
            // in the loop above, as the sparse-removal contract requires.
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        // SAFETY: see the sparse-removal contract noted above.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer,
                );
        }

        // Emit the transition barriers computed above into the pass encoder.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
320
321// Running the compute pass.
322
impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        // Borrow the label; it is only promoted to an owned/`Arc`'d
        // descriptor once validation succeeds.
        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        // Transition the encoder into the locked state; each failure mode
        // below is handled according to the WebGPU encoder state machine.
        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                // Release the data lock before any further validation;
                // `ComputePass` only retains the encoder handle itself.
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve and validate the (optional) timestamp writes up
                // front so the pass stores `Arc`'d query sets.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encode has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since there's no visible side-effect of
                // the pass being valid and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Note that this differs from [`Self::compute_pass_end`], it will
    /// create a new pass, replay the commands and end the pass.
    ///
    /// # Panics
    /// On any error.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        // When tracing, record the unresolved pass verbatim before replaying.
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.trace_commands {
                list.push(crate::command::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Replace the (empty) recorded pass data with the replayed commands,
        // resolving raw ids into `Arc`'d resources.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

    /// Ends the pass and hands its recorded commands (or its recorded error)
    /// over to the parent encoder, unlocking the encoder.
    ///
    /// Returns `Err(EncoderStateError::Ended)` if the pass was already ended,
    /// or if the parent encoder had already finished when the pass was opened.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` moves the pass into the "ended" state; a second
        // call to end the same pass fails here.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if matches!(
            base,
            Err(ComputePassError {
                inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                scope: _,
            })
        ) {
            // If the encoder was already finished at time of pass creation,
            // then it was not put in the locked state, so we need to
            // generate a validation error here and now due to the encoder not
            // being locked. The encoder already holds an error from when the
            // pass was opened, or earlier.
            //
            // All other errors are propagated to the encoder within `push_with`,
            // and will be reported later.
            return Err(EncoderStateError::Ended);
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
528
/// Encodes a finished compute pass's commands into the parent encoder's HAL
/// command stream.
///
/// The pass body is recorded into its own HAL encoder; afterwards a
/// "pre pass" encoder is opened, filled with the barriers and texture-init
/// fixups the body turned out to need, and swapped in *before* the body.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    // Tracks Push/PopDebugGroup nesting; must be zero when the pass ends.
    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Pre-size the buffer/texture trackers to the device's index ranges.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Translate the validated timestamp writes into their HAL form, resetting
    // the touched queries first.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // Range should always be Some, both values being None should lead to a validation error.
            // But no point in erroring over that nuance here!
            if let Some(range) = range {
                // SAFETY: relies on `reset_queries`' HAL contract.
                // NOTE(review): the range is derived from the validated
                // timestamp-write indices above — confirm the HAL requirement.
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    // SAFETY: relies on the HAL encoder contract; the encoder pass was opened
    // above and is ended by `end_compute_pass` below.
    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay each recorded command, mapping any failure to its scope.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        // Mirror the uploaded range into the shadow copy held
                        // in `state.push_constants`.
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every PushDebugGroup must have been matched by a PopDebugGroup.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    // SAFETY: matches the `begin_compute_pass` call above.
    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
804
/// Handles `SetPipeline`: validates the pipeline's device, records it in the
/// tracker, binds it on the HAL encoder, and rebinds resources (including
/// resizing the push-constant shadow buffer) for the new layout.
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    // SAFETY: relies on the HAL encoder contract; a compute pass is open on
    // this encoder while commands are being replayed.
    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines because they use push constants for
            // validating indirect draws.
            state.push_constants.clear();
            // Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
                // Zero-fill the shadow buffer to the layout's declared size.
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
854
855fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
856    state.is_ready()?;
857
858    state.flush_states(None)?;
859
860    let groups_size_limit = state
861        .pass
862        .base
863        .device
864        .limits
865        .max_compute_workgroups_per_dimension;
866
867    if groups[0] > groups_size_limit
868        || groups[1] > groups_size_limit
869        || groups[2] > groups_size_limit
870    {
871        return Err(ComputePassErrorInner::Dispatch(
872            DispatchError::InvalidGroupSize {
873                current: groups,
874                limit: groups_size_limit,
875            },
876        ));
877    }
878
879    unsafe {
880        state.pass.base.raw_encoder.dispatch(groups);
881    }
882    Ok(())
883}
884
/// Records an indirect dispatch whose workgroup counts are read from
/// `buffer` at `offset`.
///
/// Validates the buffer (same device, INDIRECT usage, not destroyed,
/// 4-byte-aligned offset, arguments fully in bounds) and then either:
/// * when `device.indirect_validation` is present, runs an internal
///   validation dispatch that rewrites the arguments into an internal
///   destination buffer and dispatches indirectly from that buffer, or
/// * otherwise dispatches indirectly straight from the user's buffer.
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device(device)?;

    // A pipeline must be set and its bind groups satisfied.
    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // The indirect offset must be 4-byte aligned.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The three u32 dispatch arguments must lie entirely within the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    // The argument bytes must be initialized before the GPU reads them.
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        // Bind the internal validation pipeline.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // Tell the validation shader where (in u32 units) the arguments
        // start within the aligned binding.
        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: internal destination the validated arguments go to.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        // Group 1: the user's indirect buffer, bound at the aligned offset.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer so the validation shader can read it.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable for the validation shader.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader (a single workgroup).
        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        {
            // Restore the user's pipeline, push constants and bind groups,
            // all of which were clobbered by the validation dispatch above.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Transition the destination buffer back so it can be consumed as
        // an indirect-arguments buffer.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        // Dispatch from the start of the validated destination buffer.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No indirect validation: record INDIRECT usage in the pass scope so
        // the needed barriers are emitted, then dispatch straight from the
        // user's buffer.
        state
            .pass
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1087
1088// Recording a compute pass.
1089//
1090// The only error that should be returned from these methods is
1091// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1092// validation error is raised.
1093//
1094// All other errors should be stored in the pass for later reporting when
1095// `CommandEncoder.finish()` is called.
1096//
1097// The `pass_try!` macro should be used to handle errors appropriately. Note
1098// that the `pass_try!` and `pass_base!` macros may return early from the
1099// function that invokes them, like the `?` operator.
1100impl Global {
1101    pub fn compute_pass_set_bind_group(
1102        &self,
1103        pass: &mut ComputePass,
1104        index: u32,
1105        bind_group_id: Option<id::BindGroupId>,
1106        offsets: &[DynamicOffset],
1107    ) -> Result<(), PassStateError> {
1108        let scope = PassErrorScope::SetBindGroup;
1109
1110        // This statement will return an error if the pass is ended. It's
1111        // important the error check comes before the early-out for
1112        // `set_and_check_redundant`.
1113        let base = pass_base!(pass, scope);
1114
1115        if pass.current_bind_groups.set_and_check_redundant(
1116            bind_group_id,
1117            index,
1118            &mut base.dynamic_offsets,
1119            offsets,
1120        ) {
1121            return Ok(());
1122        }
1123
1124        let mut bind_group = None;
1125        if bind_group_id.is_some() {
1126            let bind_group_id = bind_group_id.unwrap();
1127
1128            let hub = &self.hub;
1129            bind_group = Some(pass_try!(
1130                base,
1131                scope,
1132                hub.bind_groups.get(bind_group_id).get(),
1133            ));
1134        }
1135
1136        base.commands.push(ArcComputeCommand::SetBindGroup {
1137            index,
1138            num_dynamic_offsets: offsets.len(),
1139            bind_group,
1140        });
1141
1142        Ok(())
1143    }
1144
1145    pub fn compute_pass_set_pipeline(
1146        &self,
1147        pass: &mut ComputePass,
1148        pipeline_id: id::ComputePipelineId,
1149    ) -> Result<(), PassStateError> {
1150        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1151
1152        let scope = PassErrorScope::SetPipelineCompute;
1153
1154        // This statement will return an error if the pass is ended.
1155        // Its important the error check comes before the early-out for `redundant`.
1156        let base = pass_base!(pass, scope);
1157
1158        if redundant {
1159            return Ok(());
1160        }
1161
1162        let hub = &self.hub;
1163        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
1164
1165        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
1166
1167        Ok(())
1168    }
1169
1170    pub fn compute_pass_set_push_constants(
1171        &self,
1172        pass: &mut ComputePass,
1173        offset: u32,
1174        data: &[u8],
1175    ) -> Result<(), PassStateError> {
1176        let scope = PassErrorScope::SetPushConstant;
1177        let base = pass_base!(pass, scope);
1178
1179        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1180            pass_try!(
1181                base,
1182                scope,
1183                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
1184            );
1185        }
1186
1187        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1188            pass_try!(
1189                base,
1190                scope,
1191                Err(ComputePassErrorInner::PushConstantSizeAlignment),
1192            )
1193        }
1194        let value_offset = pass_try!(
1195            base,
1196            scope,
1197            base.push_constant_data
1198                .len()
1199                .try_into()
1200                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
1201        );
1202
1203        base.push_constant_data.extend(
1204            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
1205                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1206        );
1207
1208        base.commands.push(ArcComputeCommand::SetPushConstant {
1209            offset,
1210            size_bytes: data.len() as u32,
1211            values_offset: value_offset,
1212        });
1213
1214        Ok(())
1215    }
1216
1217    pub fn compute_pass_dispatch_workgroups(
1218        &self,
1219        pass: &mut ComputePass,
1220        groups_x: u32,
1221        groups_y: u32,
1222        groups_z: u32,
1223    ) -> Result<(), PassStateError> {
1224        let scope = PassErrorScope::Dispatch { indirect: false };
1225
1226        pass_base!(pass, scope)
1227            .commands
1228            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
1229
1230        Ok(())
1231    }
1232
1233    pub fn compute_pass_dispatch_workgroups_indirect(
1234        &self,
1235        pass: &mut ComputePass,
1236        buffer_id: id::BufferId,
1237        offset: BufferAddress,
1238    ) -> Result<(), PassStateError> {
1239        let hub = &self.hub;
1240        let scope = PassErrorScope::Dispatch { indirect: true };
1241        let base = pass_base!(pass, scope);
1242
1243        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1244
1245        base.commands
1246            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });
1247
1248        Ok(())
1249    }
1250
1251    pub fn compute_pass_push_debug_group(
1252        &self,
1253        pass: &mut ComputePass,
1254        label: &str,
1255        color: u32,
1256    ) -> Result<(), PassStateError> {
1257        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1258
1259        let bytes = label.as_bytes();
1260        base.string_data.extend_from_slice(bytes);
1261
1262        base.commands.push(ArcComputeCommand::PushDebugGroup {
1263            color,
1264            len: bytes.len(),
1265        });
1266
1267        Ok(())
1268    }
1269
1270    pub fn compute_pass_pop_debug_group(
1271        &self,
1272        pass: &mut ComputePass,
1273    ) -> Result<(), PassStateError> {
1274        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1275
1276        base.commands.push(ArcComputeCommand::PopDebugGroup);
1277
1278        Ok(())
1279    }
1280
1281    pub fn compute_pass_insert_debug_marker(
1282        &self,
1283        pass: &mut ComputePass,
1284        label: &str,
1285        color: u32,
1286    ) -> Result<(), PassStateError> {
1287        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1288
1289        let bytes = label.as_bytes();
1290        base.string_data.extend_from_slice(bytes);
1291
1292        base.commands.push(ArcComputeCommand::InsertDebugMarker {
1293            color,
1294            len: bytes.len(),
1295        });
1296
1297        Ok(())
1298    }
1299
1300    pub fn compute_pass_write_timestamp(
1301        &self,
1302        pass: &mut ComputePass,
1303        query_set_id: id::QuerySetId,
1304        query_index: u32,
1305    ) -> Result<(), PassStateError> {
1306        let scope = PassErrorScope::WriteTimestamp;
1307        let base = pass_base!(pass, scope);
1308
1309        let hub = &self.hub;
1310        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1311
1312        base.commands.push(ArcComputeCommand::WriteTimestamp {
1313            query_set,
1314            query_index,
1315        });
1316
1317        Ok(())
1318    }
1319
1320    pub fn compute_pass_begin_pipeline_statistics_query(
1321        &self,
1322        pass: &mut ComputePass,
1323        query_set_id: id::QuerySetId,
1324        query_index: u32,
1325    ) -> Result<(), PassStateError> {
1326        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1327        let base = pass_base!(pass, scope);
1328
1329        let hub = &self.hub;
1330        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1331
1332        base.commands
1333            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1334                query_set,
1335                query_index,
1336            });
1337
1338        Ok(())
1339    }
1340
1341    pub fn compute_pass_end_pipeline_statistics_query(
1342        &self,
1343        pass: &mut ComputePass,
1344    ) -> Result<(), PassStateError> {
1345        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1346            .commands
1347            .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1348
1349        Ok(())
1350    }
1351}