Skip to main content

wgpu_core/command/
compute.rs

1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{
11    api_log,
12    binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
13    command::{
14        bind::{Binder, BinderError},
15        compute_command::ArcComputeCommand,
16        encoder::EncodingState,
17        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18        pass::{self, flush_bindings_helper},
19        pass_base, pass_try,
20        query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
21        ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
22        CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
23        PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
24        TimestampWritesError,
25    },
26    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
27    global::Global,
28    hal_label, id,
29    init_tracker::MemoryInitKind,
30    pipeline::ComputePipeline,
31    resource::{
32        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
33        MissingBufferUsageError, ParentDevice, RawResourceAccess, Trackable,
34    },
35    track::{ResourceUsageCompatibilityError, Tracker},
36    Label,
37};
38
/// Base pass storage specialized to compute commands and compute-pass errors.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
40
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records is stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes for this pass, resolved to `Arc`ed query sets.
    /// Moved out (`take`n) when the pass is ended.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state, used to skip recording redundant
    // bind-group / pipeline changes.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
65
66impl ComputePass {
67    /// If the parent command encoder is invalid, the returned pass will be invalid.
68    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
69        let ArcComputePassDescriptor {
70            label,
71            timestamp_writes,
72        } = desc;
73
74        Self {
75            base: BasePass::new(&label),
76            parent: Some(parent),
77            timestamp_writes,
78
79            current_bind_groups: BindGroupStateChange::new(),
80            current_pipeline: StateChange::new(),
81        }
82    }
83
84    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
85        Self {
86            base: BasePass::new_invalid(label, err),
87            parent: Some(parent),
88            timestamp_writes: None,
89            current_bind_groups: BindGroupStateChange::new(),
90            current_pipeline: StateChange::new(),
91        }
92    }
93
94    #[inline]
95    pub fn label(&self) -> Option<&str> {
96        self.base.label.as_deref()
97    }
98}
99
100impl fmt::Debug for ComputePass {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self.parent {
103            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
104            None => write!(f, "ComputePass {{ parent: None }}"),
105        }
106    }
107}
108
/// Describes a compute pass at creation time.
///
/// `PTW` is the timestamp-writes representation: id-based for the public API,
/// `Arc`-based once resolved (see `ArcComputePassDescriptor`).
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Debug label of the pass.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
115
/// cbindgen:ignore
///
/// [`ComputePassDescriptor`] with timestamp writes resolved to `Arc`ed resources.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
118
/// Errors raised while validating a dispatch (direct or indirect).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
133
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every dispatch failure is a WebGPU validation error, regardless of variant.
        ErrorType::Validation
    }
}
139
140/// Error encountered when performing a compute pass.
/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    #[error("Immediate data size must be aligned to 4 bytes")]
    // NOTE(review): variant name looks like a typo ("Dataize" → "DataSize").
    // Renaming would break code that matches on this public variant, so it is
    // flagged here rather than fixed.
    ImmediateDataizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4gb of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
195
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// Which recording operation the error occurred in.
    pub scope: PassErrorScope,
    /// The underlying error.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
205
206impl From<pass::MissingPipeline> for ComputePassErrorInner {
207    fn from(value: pass::MissingPipeline) -> Self {
208        Self::Dispatch(DispatchError::MissingPipeline(value))
209    }
210}
211
212impl<E> MapPassErr<ComputePassError> for E
213where
214    E: Into<ComputePassErrorInner>,
215{
216    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
217        ComputePassError {
218            scope,
219            inner: self.into(),
220        }
221    }
222}
223
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Destructure so that adding a field to `ComputePassError` forces a
        // revisit of this impl.
        let Self { scope: _, inner } = self;
        match inner {
            // Wrapper variants delegate to the wrapped error's classification.
            ComputePassErrorInner::Device(e) => e.webgpu_error_type(),
            ComputePassErrorInner::EncoderState(e) => e.webgpu_error_type(),
            ComputePassErrorInner::DebugGroupError(e) => e.webgpu_error_type(),
            ComputePassErrorInner::DestroyedResource(e) => e.webgpu_error_type(),
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingBufferUsage(e) => e.webgpu_error_type(),
            ComputePassErrorInner::Dispatch(e) => e.webgpu_error_type(),
            ComputePassErrorInner::Bind(e) => e.webgpu_error_type(),
            ComputePassErrorInner::ImmediateData(e) => e.webgpu_error_type(),
            ComputePassErrorInner::QueryUse(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingFeatures(e) => e.webgpu_error_type(),
            ComputePassErrorInner::MissingDownlevelFlags(e) => e.webgpu_error_type(),
            ComputePassErrorInner::InvalidResource(e) => e.webgpu_error_type(),
            ComputePassErrorInner::TimestampWrites(e) => e.webgpu_error_type(),
            ComputePassErrorInner::InvalidValuesOffset(e) => e.webgpu_error_type(),

            // Variants defined in this module are all validation errors.
            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => ErrorType::Validation,
        }
    }
}
255
/// Mutable state maintained while encoding a compute pass's commands.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The currently-set compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// Pass state shared with other pass kinds (binder, usage scope,
    /// raw encoder, trackers, etc.).
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// The currently-open pipeline statistics query, with its query index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side shadow of the pass's immediate data, sized from the pipeline
    /// layout on `SetPipeline` and written by `SetImmediate`.
    immediates: Vec<u32>,

    /// Tracker that accumulates per-dispatch usage so barriers can be drained
    /// before each dispatch (see `flush_bindings`).
    intermediate_trackers: Tracker,
}
267
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch may proceed: a pipeline must be set, the bound
    /// bind groups must be compatible with it, and late-sized buffer bindings
    /// must satisfy their minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// The `indirect_buffer` argument should be passed for any indirect
    /// dispatch (with or without validation). It will be checked for
    /// conflicting usages according to WebGPU rules. For the purpose of
    /// these rules, the fact that we have actually processed the buffer in
    /// the validation pass is an implementation detail.
    ///
    /// The `track_indirect_buffer` argument should be set when doing indirect
    /// dispatch *without* validation. In this case, the indirect buffer will
    /// be added to the tracker in order to generate any necessary transitions
    /// for that usage.
    ///
    /// When doing indirect dispatch *with* validation, the indirect buffer is
    /// processed by the validation pass and is not used by the actual dispatch.
    /// The indirect validation code handles transitions for the validation
    /// pass.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        // Merge the resources of every active bind group into this dispatch's
        // usage scope, detecting conflicting usages.
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Add the indirect buffer. Because usage scopes are per-dispatch, this
        // is the only place where INDIRECT usage could be added, and it is safe
        // for us to remove it below.
        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // For compute, usage scopes are associated with each dispatch and not
        // with the pass as a whole. However, because the cost of creating and
        // dropping `UsageScope`s is significant (even with the pool), we
        // add and then remove usage from a single usage scope.

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        // Emit any pending transitions into the raw encoder before the dispatch.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
370
371// Running the compute pass.
372
impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                // Release the data lock before touching the device; the pass
                // constructors below don't need it.
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve the id-based timestamp writes to Arc-based ones,
                // validating them against the device in the process.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encode has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since there's no visible side-effect of
                // the pass being valid and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Ends the pass, unlocks the parent encoder, and hands the recorded
    /// commands to it as a single `RunComputePass` command.
    ///
    /// Returns `EncoderStateError::Ended` if the pass was already ended, and
    /// propagates lock errors when the pass was opened within another pass or
    /// on a finished encoder. Other pass errors are deferred to `finish()`.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` transitions the pass to the "ended" state.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = base
        {
            // Most encoding errors are detected and raised within `finish()`.
            //
            // However, we raise a validation error here if the pass was opened
            // within another pass, or on a finished encoder. The latter is
            // particularly important, because in that case reporting errors via
            // `CommandEncoder::finish` is not possible.
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
510
/// Encodes a recorded compute pass into the parent encoder's HAL command
/// stream.
///
/// Opens a dedicated HAL encoder for the pass body, replays every recorded
/// command into it, then opens a second "transit" encoder that is swapped in
/// *before* the pass body and holds the barriers and discarded-surface fixups
/// the pass requires.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    // Tracks PushDebugGroup/PopDebugGroup nesting; checked for balance below.
    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(
            device.ordered_buffer_usages,
            device.ordered_texture_usages,
        ),
    };

    // Pre-size the trackers to the device's current resource counts.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Resolve timestamp writes to HAL form, resetting the affected queries.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // Range should always be Some, both values being None should lead to a validation error.
            // But no point in erroring over that nuance here!
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay every recorded command against the HAL encoder.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    // Mirror the uploaded bytes into the CPU-side shadow copy.
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every `PushDebugGroup` must have a matching `PopDebugGroup`.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
789
790fn set_pipeline(
791    state: &mut State,
792    device: &Arc<Device>,
793    pipeline: Arc<ComputePipeline>,
794) -> Result<(), ComputePassErrorInner> {
795    pipeline.same_device(device)?;
796
797    state.pipeline = Some(pipeline.clone());
798
799    let pipeline = state
800        .pass
801        .base
802        .tracker
803        .compute_pipelines
804        .insert_single(pipeline)
805        .clone();
806
807    unsafe {
808        state
809            .pass
810            .base
811            .raw_encoder
812            .set_compute_pipeline(pipeline.raw());
813    }
814
815    // Rebind resources
816    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
817        &mut state.pass,
818        &pipeline.layout,
819        &pipeline.late_sized_buffer_groups,
820        || {
821            // This only needs to be here for compute pipelines because they use immediates for
822            // validating indirect draws.
823            state.immediates.clear();
824            // Note that can only be one range for each stage. See the `MoreThanOneImmediateRangePerStage` error.
825            if pipeline.layout.immediate_size != 0 {
826                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
827                let len = pipeline.layout.immediate_size as usize
828                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
829                state.immediates.extend(core::iter::repeat_n(0, len));
830            }
831        },
832    )
833}
834
835fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
836    api_log!("ComputePass::dispatch {groups:?}");
837
838    state.is_ready()?;
839
840    state.flush_bindings(None, false)?;
841
842    let groups_size_limit = state
843        .pass
844        .base
845        .device
846        .limits
847        .max_compute_workgroups_per_dimension;
848
849    if groups.iter().copied().any(|g| g > groups_size_limit) {
850        return Err(ComputePassErrorInner::Dispatch(
851            DispatchError::InvalidGroupSize {
852                current: groups,
853                limit: groups_size_limit,
854            },
855        ));
856    }
857
858    unsafe {
859        state.pass.base.raw_encoder.dispatch(groups);
860    }
861    Ok(())
862}
863
/// Records an indirect `dispatch` whose workgroup counts are read from
/// `buffer` at `offset`.
///
/// Validates the buffer (device, usage, alignment, bounds, liveness) and then
/// either records the dispatch directly, or — when the device has indirect
/// validation enabled — first runs a small validation compute dispatch that
/// writes sanitized args into an internal buffer, and dispatches indirectly
/// from that buffer instead.
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    // A pipeline must be bound and compatible before any dispatch.
    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    // The indirect args must be 4-byte aligned.
    if !offset.is_multiple_of(4) {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The full `DispatchIndirectArgs` struct must fit inside the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // Require the args region (three u32 workgroup counts) to be initialized
    // before the pass runs.
    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        // Validation path: run a helper compute dispatch that reads the user's
        // args and writes sanitized args into `params.dst_buffer`
        // (presumably clamping out-of-limit workgroup counts — confirm against
        // the `indirect_validation` module).
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        // Bind the internal validation pipeline. The user's pipeline and
        // bindings are restored below, so ordering in this branch matters.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // The validation shader receives the args offset remainder (in u32
        // words) as an immediate.
        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: destination (sanitized args) buffer.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                params.dst_bind_group,
                &[],
            );
        }
        // Group 1: the user's indirect buffer, bound at the aligned offset via
        // a dynamic offset.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                buffer
                    .indirect_validation_bind_groups
                    .get(state.pass.base.snatch_guard)
                    .unwrap()
                    .dispatch
                    .as_ref(),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer so the validation shader may read it
        // as read-only storage.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable by the validation shader.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader (a single workgroup).
        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        {
            // `is_ready()` above guarantees a pipeline is bound.
            let pipeline = state.pipeline.as_ref().unwrap();

            // Restore the user's compute pipeline.
            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            // Restore the user's immediates, if any were set.
            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            // Restore every valid user bind group clobbered above.
            for (i, group, dynamic_offsets) in state.pass.binder.list_valid() {
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        raw_bg,
                        dynamic_offsets,
                    );
                }
            }
        }

        // Return the destination buffer to INDIRECT usage so it can feed the
        // real dispatch below.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        // NOTE(review): `false` here vs `true` in the non-validated branch —
        // presumably the INDIRECT transition for the user buffer is unneeded
        // since the real dispatch reads `dst_buffer`; confirm against
        // `flush_bindings`' signature.
        state.flush_bindings(Some(&buffer), false)?;
        // Dispatch from the sanitized args at offset 0.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation: dispatch straight from the user's buffer.
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1057
1058// Recording a compute pass.
1059//
1060// The only error that should be returned from these methods is
1061// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1062// validation error is raised.
1063//
1064// All other errors should be stored in the pass for later reporting when
1065// `CommandEncoder.finish()` is called.
1066//
1067// The `pass_try!` macro should be used to handle errors appropriately. Note
1068// that the `pass_try!` and `pass_base!` macros may return early from the
1069// function that invokes them, like the `?` operator.
impl Global {
    /// Records a `SetBindGroup` command on `pass`.
    ///
    /// Redundant rebinds (same group, same offsets) are skipped. The bind
    /// group is resolved from the hub immediately; resolution failure is
    /// stored in the pass via `pass_try!` rather than returned.
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        // `None` unbinds; `Some` resolves the id to an `Arc<BindGroup>`.
        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    /// Records a `SetPipeline` command on `pass`, skipping redundant rebinds
    /// of the same pipeline.
    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended.
        // It's important the error check comes before the early-out for `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    /// Records a `SetImmediate` command, copying `data` into the pass's
    /// immediate-data arena as native-endian u32 words.
    ///
    /// Both `offset` and `data.len()` must be multiples of
    /// `IMMEDIATE_DATA_ALIGNMENT` (the bitmask check assumes the alignment is
    /// a power of two).
    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        // NOTE(review): `ImmediateDataizeAlignment` looks like a typo for
        // `ImmediateDataSizeAlignment`; the variant is declared elsewhere, so
        // it can't be renamed here.
        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataizeAlignment),
            )
        }
        // Index into `immediates_data` where this write's words begin.
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    /// Records a direct `Dispatch` command; workgroup counts are validated
    /// later, when the pass is replayed.
    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    /// Records a `DispatchIndirect` command reading args from `buffer_id` at
    /// `offset`. Buffer validation happens at replay time.
    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    /// Records a `PushDebugGroup` command; the label bytes are appended to the
    /// pass's shared string arena and referenced by length.
    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `PopDebugGroup` command; balance against pushes is checked at
    /// replay time.
    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    /// Records an `InsertDebugMarker` command; like debug groups, the label is
    /// stored in the pass's string arena.
    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `WriteTimestamp` command; the query set is resolved from the
    /// hub immediately.
    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    /// Records a `BeginPipelineStatisticsQuery` command; pairing with the
    /// matching end command is validated at replay time.
    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    /// Records an `EndPipelineStatisticsQuery` command.
    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}