wgpu_core/command/
compute.rs

1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{fmt, str};
9
10use crate::command::{pass, EncoderStateError, PassStateError, TimestampWritesError};
11use crate::resource::DestroyedResourceError;
12use crate::{binding_model::BindError, resource::RawResourceAccess};
13use crate::{
14    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
15    command::{
16        bind::{Binder, BinderError},
17        compute_command::ArcComputeCommand,
18        end_pipeline_statistics_query,
19        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
20        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
21        BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr,
22        PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
23    },
24    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
25    global::Global,
26    hal_label, id,
27    init_tracker::MemoryInitKind,
28    pipeline::ComputePipeline,
29    resource::{
30        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
31    },
32    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
33    Label,
34};
35
/// The recorded command stream of a compute pass, specialized to resolved
/// (`Arc`-based) compute commands and compute-pass errors.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
37
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records is stored here.
    base: ComputeBasePass,

    /// Parent command buffer that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandBuffer>>,

    /// Timestamp writes for this pass, consumed (`take`n) when the pass is
    /// ended and its commands are replayed into the HAL encoder.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state, used to skip recording redundant
    // bind-group / pipeline changes.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
62
impl ComputePass {
    /// Creates a pass in the "open" state that records into `parent`.
    ///
    /// If the parent command buffer is invalid, the returned pass will be invalid.
    fn new(parent: Arc<CommandBuffer>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    /// Creates a pass that is invalid from the start, carrying `err` in its
    /// `base.error` slot so it can be reported when the pass is ended.
    fn new_invalid(parent: Arc<CommandBuffer>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    /// Returns the pass's debug label, if one was provided at creation.
    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}
96
97impl fmt::Debug for ComputePass {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        match self.parent {
100            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
101            None => write!(f, "ComputePass {{ parent: None }}"),
102        }
103    }
104}
105
/// Parameters for beginning a compute pass. Generic over the representation
/// of timestamp writes (`PTW`): id-based for the public API, `Arc`-based once
/// resolved.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Debug label for the pass, if any.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
115
/// Errors detected when validating a dispatch (direct or indirect).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The currently bound bind groups do not match the pipeline's layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// A late-sized buffer binding is smaller than the pipeline requires.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
130
impl WebGpuError for DispatchError {
    // Every dispatch error is a WebGPU validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
136
/// Error encountered when performing a compute pass.
///
/// Most variants are thin wrappers (`#[error(transparent)]` / `#[from]`)
/// around more specific error types raised while replaying pass commands.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    InvalidPopDebugGroup(#[from] pass::InvalidPopDebugGroup),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
192
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// Which pass command (or pass-level operation) the error occurred in.
    pub scope: PassErrorScope,
    /// The underlying error; exposed as this error's `source`.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
202
// Routes a missing-pipeline error through the `Dispatch` variant.
impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}
208
// Blanket impl: anything convertible into `ComputePassErrorInner` can be
// tagged with a `PassErrorScope` to form a full `ComputePassError`.
impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}
220
impl WebGpuError for ComputePassError {
    /// Delegates to the inner error's WebGPU error type where the inner
    /// error implements [`WebGpuError`]; all remaining variants are plain
    /// validation errors.
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,
            ComputePassErrorInner::InvalidPopDebugGroup(e) => e,

            // Variants with no nested `WebGpuError` implementation are all
            // validation errors.
            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
253
/// Mutable state maintained while replaying a compute pass's command stream
/// into the HAL encoder.
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    /// The currently-set compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State common to all pass kinds (binder, trackers, raw encoder, ...).
    general: pass::BaseState<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>,

    /// The pipeline-statistics query that has been begun but not yet ended,
    /// together with its query index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side copy of the pass's push constant data; cleared when the
    /// pipeline changes (see `set_pipeline`).
    push_constants: Vec<u32>,

    /// Trackers used to compute barriers between dispatches; see
    /// [`State::flush_states`].
    intermediate_trackers: Tracker,
}
265
impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    /// Checks that the pass is ready to dispatch: a pipeline must be set, the
    /// bound bind groups must be compatible with it, and late-sized buffer
    /// bindings must pass their minimum-size checks.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.general.binder.check_compatibility(pipeline.as_ref())?;
            self.general.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the resource usage of all active bind groups into the usage
    /// scope, moves those usages into `intermediate_trackers`, and emits the
    /// resulting barriers into the raw encoder.
    ///
    /// `indirect_buffer` is there to represent the indirect buffer that is
    /// also part of the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.general.binder.list_active() {
            unsafe { self.general.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.general.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(
                        &mut self.general.scope,
                        &bind_group.used,
                    )
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.general.scope.buffers,
                    indirect_buffer,
                );
        }

        // Record the barriers implied by the accumulated state transitions.
        CommandBuffer::drain_barriers(
            self.general.raw_encoder,
            &mut self.intermediate_trackers,
            self.general.snatch_guard,
        );
        Ok(())
    }
}
319
320// Running the compute pass.
321
impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let mut cmd_buf_data = cmd_buf.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                // Beginning a pass on an invalid device yields an invalid
                // pass, not an immediate error.
                if let Err(err) = cmd_buf.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve the timestamp-write query set up front; failure
                // likewise yields an invalid pass.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_buf.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_buf, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encoder has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since there's no visible side-effect of
                // the pass being valid and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Note that this differs from [`Self::compute_pass_end`], it will
    /// create a new pass, replay the commands and end the pass.
    ///
    /// # Panics
    /// On any error.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        // When tracing is enabled, also record the replayed pass into the
        // parent encoder's trace command list.
        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Swap in the given command stream, resolving raw ids to `Arc`s.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

    /// Ends the pass and replays its recorded commands into the parent
    /// encoder's HAL command stream.
    ///
    /// Returns [`EncoderStateError::Ended`] if the pass was already ended, or
    /// if the parent encoder had already finished when the pass was created.
    /// Other validation failures invalidate the parent encoder instead of
    /// being returned from here.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        let pass_scope = PassErrorScope::Pass;
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` moves the pass to the "ended" state; a second call
        // on the same pass fails here.
        let cmd_buf = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_buf.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // If the encoder was already finished at time of pass creation,
                // then it was not put in the locked state, so we need to
                // generate a validation error here due to the encoder not being
                // locked. The encoder already has a copy of the error.
                return Err(EncoderStateError::Ended);
            } else {
                // If the pass is invalid, invalidate the parent encoder and return.
                // Since we do not track the state of an invalid encoder, it is not
                // necessary to unlock it.
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            let device = &cmd_buf.device;
            device.check_is_valid().map_pass_err(pass_scope)?;

            let base = &mut pass.base;

            let encoder = &mut cmd_buf_data.encoder;

            // We automatically keep extending command buffers over time, and because
            // we want to insert a command buffer _before_ what we're about to record,
            // we need to make sure to close the previous one.
            encoder.close_if_open().map_pass_err(pass_scope)?;
            let raw_encoder = encoder
                .open_pass(base.label.as_deref())
                .map_pass_err(pass_scope)?;

            let snatch_guard = device.snatchable_lock.read();

            let mut state = State {
                pipeline: None,

                general: pass::BaseState {
                    device,
                    raw_encoder,
                    tracker: &mut cmd_buf_data.trackers,
                    buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                    texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                    as_actions: &mut cmd_buf_data.as_actions,
                    binder: Binder::new(),
                    temp_offsets: Vec::new(),
                    dynamic_offset_count: 0,

                    pending_discard_init_fixups: SurfacesInDiscardState::new(),

                    snatch_guard: &snatch_guard,
                    scope: device.new_usage_scope(),

                    debug_scope_depth: 0,
                    string_offset: 0,
                },
                active_query: None,

                push_constants: Vec::new(),

                intermediate_trackers: Tracker::new(),
            };

            // Pre-size the buffer/texture trackers to the device's current
            // tracker index ranges.
            let indices = &state.general.device.tracker_indices;
            state
                .general
                .tracker
                .buffers
                .set_size(indices.buffers.size());
            state
                .general
                .tracker
                .textures
                .set_size(indices.textures.size());

            // Convert the pass's timestamp writes into their HAL form,
            // resetting the affected queries first.
            let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
                if let Some(tw) = pass.timestamp_writes.take() {
                    tw.query_set
                        .same_device_as(cmd_buf.as_ref())
                        .map_pass_err(pass_scope)?;

                    let query_set = state.general.tracker.query_sets.insert_single(tw.query_set);

                    // Unlike in render passes we can't delay resetting the query sets since
                    // there is no auxiliary pass.
                    let range = if let (Some(index_a), Some(index_b)) =
                        (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                    {
                        Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                    } else {
                        tw.beginning_of_pass_write_index
                            .or(tw.end_of_pass_write_index)
                            .map(|i| i..i + 1)
                    };
                    // Range should always be Some, both values being None should lead to a validation error.
                    // But no point in erroring over that nuance here!
                    if let Some(range) = range {
                        unsafe {
                            state
                                .general
                                .raw_encoder
                                .reset_queries(query_set.raw(), range);
                        }
                    }

                    Some(hal::PassTimestampWrites {
                        query_set: query_set.raw(),
                        beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                        end_of_pass_write_index: tw.end_of_pass_write_index,
                    })
                } else {
                    None
                };

            let hal_desc = hal::ComputePassDescriptor {
                label: hal_label(base.label.as_deref(), device.instance_flags),
                timestamp_writes,
            };

            unsafe {
                state.general.raw_encoder.begin_compute_pass(&hal_desc);
            }

            // Replay the recorded command stream into the HAL encoder,
            // tagging each failure with its command's error scope.
            for command in base.commands.drain(..) {
                match command {
                    ArcComputeCommand::SetBindGroup {
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    } => {
                        let scope = PassErrorScope::SetBindGroup;
                        pass::set_bind_group::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_buf.as_ref(),
                            &base.dynamic_offsets,
                            index,
                            num_dynamic_offsets,
                            bind_group,
                            false,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPipeline(pipeline) => {
                        let scope = PassErrorScope::SetPipelineCompute;
                        set_pipeline(&mut state, cmd_buf.as_ref(), pipeline).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPushConstant {
                        offset,
                        size_bytes,
                        values_offset,
                    } => {
                        let scope = PassErrorScope::SetPushConstant;
                        pass::set_push_constant::<ComputePassErrorInner, _>(
                            &mut state.general,
                            &base.push_constant_data,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            size_bytes,
                            Some(values_offset),
                            |data_slice| {
                                // Mirror the written range into the CPU-side
                                // shadow copy kept in `state.push_constants`.
                                let offset_in_elements =
                                    (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                let size_in_elements =
                                    (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                state.push_constants[offset_in_elements..][..size_in_elements]
                                    .copy_from_slice(data_slice);
                            },
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::Dispatch(groups) => {
                        let scope = PassErrorScope::Dispatch { indirect: false };
                        dispatch(&mut state, groups).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                        let scope = PassErrorScope::Dispatch { indirect: true };
                        dispatch_indirect(&mut state, cmd_buf.as_ref(), buffer, offset)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::PushDebugGroup { color: _, len } => {
                        pass::push_debug_group(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::PopDebugGroup => {
                        let scope = PassErrorScope::PopDebugGroup;
                        pass::pop_debug_group::<ComputePassErrorInner>(&mut state.general)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                        pass::insert_debug_marker(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::WriteTimestamp {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::WriteTimestamp;
                        pass::write_timestamp::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_buf.as_ref(),
                            None,
                            query_set,
                            query_index,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::BeginPipelineStatisticsQuery {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                        validate_and_begin_pipeline_statistics_query(
                            query_set,
                            state.general.raw_encoder,
                            &mut state.general.tracker.query_sets,
                            cmd_buf.as_ref(),
                            query_index,
                            None,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::EndPipelineStatisticsQuery => {
                        let scope = PassErrorScope::EndPipelineStatisticsQuery;
                        end_pipeline_statistics_query(
                            state.general.raw_encoder,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                }
            }

            unsafe {
                state.general.raw_encoder.end_compute_pass();
            }

            // Pull out what outlives the replay: the command buffer's tracker,
            // any pending discard fixups, and the barrier trackers.
            let State {
                general:
                    pass::BaseState {
                        tracker,
                        pending_discard_init_fixups,
                        ..
                    },
                intermediate_trackers,
                ..
            } = state;

            // Stop the current command buffer.
            encoder.close().map_pass_err(pass_scope)?;

            // Create a new command buffer, which we will insert _before_ the body of the compute pass.
            //
            // Use that buffer to insert barriers and clear discarded images.
            let transit = encoder
                .open_pass(Some("(wgpu internal) Pre Pass"))
                .map_pass_err(pass_scope)?;
            fixup_discarded_surfaces(
                pending_discard_init_fixups.into_iter(),
                transit,
                &mut tracker.textures,
                device,
                &snatch_guard,
            );
            CommandBuffer::insert_barriers_from_tracker(
                transit,
                tracker,
                &intermediate_trackers,
                &snatch_guard,
            );
            // Close the command buffer, and swap it with the previous.
            encoder.close_and_swap().map_pass_err(pass_scope)?;

            Ok(())
        })
    }
}
782
783fn set_pipeline(
784    state: &mut State,
785    cmd_buf: &CommandBuffer,
786    pipeline: Arc<ComputePipeline>,
787) -> Result<(), ComputePassErrorInner> {
788    pipeline.same_device_as(cmd_buf)?;
789
790    state.pipeline = Some(pipeline.clone());
791
792    let pipeline = state
793        .general
794        .tracker
795        .compute_pipelines
796        .insert_single(pipeline)
797        .clone();
798
799    unsafe {
800        state
801            .general
802            .raw_encoder
803            .set_compute_pipeline(pipeline.raw());
804    }
805
806    // Rebind resources
807    pass::rebind_resources::<ComputePassErrorInner, _>(
808        &mut state.general,
809        &pipeline.layout,
810        &pipeline.late_sized_buffer_groups,
811        || {
812            // This only needs to be here for compute pipelines because they use push constants for
813            // validating indirect draws.
814            state.push_constants.clear();
815            // Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
816            if let Some(push_constant_range) =
817                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
818                    pcr.stages
819                        .contains(wgt::ShaderStages::COMPUTE)
820                        .then_some(pcr.range.clone())
821                })
822            {
823                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
824                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
825                state.push_constants.extend(core::iter::repeat_n(0, len));
826            }
827        },
828    )
829}
830
831fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
832    state.is_ready()?;
833
834    state.flush_states(None)?;
835
836    let groups_size_limit = state
837        .general
838        .device
839        .limits
840        .max_compute_workgroups_per_dimension;
841
842    if groups[0] > groups_size_limit
843        || groups[1] > groups_size_limit
844        || groups[2] > groups_size_limit
845    {
846        return Err(ComputePassErrorInner::Dispatch(
847            DispatchError::InvalidGroupSize {
848                current: groups,
849                limit: groups_size_limit,
850            },
851        ));
852    }
853
854    unsafe {
855        state.general.raw_encoder.dispatch(groups);
856    }
857    Ok(())
858}
859
/// Records an indirect dispatch reading its workgroup counts from `buffer`
/// at `offset`.
///
/// Validates the buffer (same device, INDIRECT usage, not destroyed,
/// 4-byte-aligned offset, args fit within the buffer size) and marks the
/// argument range for memory initialization. When the device has indirect
/// validation enabled, the arguments are first run through a validation
/// compute pipeline into an internal destination buffer, and the dispatch
/// reads from that buffer instead; otherwise the dispatch reads `buffer`
/// directly.
fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    // A pipeline must be bound and its bindings satisfied.
    state.is_ready()?;

    state
        .general
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.general.snatch_guard)?;

    // Indirect offsets must be 4-byte aligned.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The full argument struct must lie within the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    // Require the argument range to be initialized before use.
    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.general.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.general.device.indirect_validation {
        // Validation path: clamp/copy the args with an internal compute
        // pipeline, then dispatch from the validated destination buffer.
        let params =
            indirect_validation
                .dispatch
                .params(&state.general.device.limits, offset, buffer.size);

        unsafe {
            state
                .general
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // Tell the validation shader where (in words) the args start.
        unsafe {
            state.general.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: destination (validated args) buffer.
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        // Group 1: the user's source buffer, bound at the aligned offset.
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.general.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer so the validation shader may read it.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.general.snatch_guard));
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable by the validation shader.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader (a single workgroup).
        unsafe {
            state.general.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state: restore the user's pipeline, push constants, and bind
        // groups, which the validation pass just clobbered.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .general
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.general.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.general.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.general.snatch_guard)?;
                unsafe {
                    state.general.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Return the destination buffer to INDIRECT usage for the dispatch.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        // Dispatch from the validated copy, which always starts at offset 0.
        unsafe {
            state
                .general
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation: track INDIRECT usage on the user buffer and
        // dispatch from it directly.
        state
            .general
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
        unsafe {
            state.general.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1050
1051// Recording a compute pass.
1052//
1053// The only error that should be returned from these methods is
1054// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1055// validation error is raised.
1056//
1057// All other errors should be stored in the pass for later reporting when
1058// `CommandEncoder.finish()` is called.
1059//
1060// The `pass_try!` macro should be used to handle errors appropriately. Note
1061// that the `pass_try!` and `pass_base!` macros may return early from the
1062// function that invokes them, like the `?` operator.
1063impl Global {
1064    pub fn compute_pass_set_bind_group(
1065        &self,
1066        pass: &mut ComputePass,
1067        index: u32,
1068        bind_group_id: Option<id::BindGroupId>,
1069        offsets: &[DynamicOffset],
1070    ) -> Result<(), PassStateError> {
1071        let scope = PassErrorScope::SetBindGroup;
1072
1073        // This statement will return an error if the pass is ended. It's
1074        // important the error check comes before the early-out for
1075        // `set_and_check_redundant`.
1076        let base = pass_base!(pass, scope);
1077
1078        if pass.current_bind_groups.set_and_check_redundant(
1079            bind_group_id,
1080            index,
1081            &mut base.dynamic_offsets,
1082            offsets,
1083        ) {
1084            return Ok(());
1085        }
1086
1087        let mut bind_group = None;
1088        if bind_group_id.is_some() {
1089            let bind_group_id = bind_group_id.unwrap();
1090
1091            let hub = &self.hub;
1092            bind_group = Some(pass_try!(
1093                base,
1094                scope,
1095                hub.bind_groups.get(bind_group_id).get(),
1096            ));
1097        }
1098
1099        base.commands.push(ArcComputeCommand::SetBindGroup {
1100            index,
1101            num_dynamic_offsets: offsets.len(),
1102            bind_group,
1103        });
1104
1105        Ok(())
1106    }
1107
1108    pub fn compute_pass_set_pipeline(
1109        &self,
1110        pass: &mut ComputePass,
1111        pipeline_id: id::ComputePipelineId,
1112    ) -> Result<(), PassStateError> {
1113        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1114
1115        let scope = PassErrorScope::SetPipelineCompute;
1116
1117        // This statement will return an error if the pass is ended.
1118        // Its important the error check comes before the early-out for `redundant`.
1119        let base = pass_base!(pass, scope);
1120
1121        if redundant {
1122            return Ok(());
1123        }
1124
1125        let hub = &self.hub;
1126        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
1127
1128        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
1129
1130        Ok(())
1131    }
1132
1133    pub fn compute_pass_set_push_constants(
1134        &self,
1135        pass: &mut ComputePass,
1136        offset: u32,
1137        data: &[u8],
1138    ) -> Result<(), PassStateError> {
1139        let scope = PassErrorScope::SetPushConstant;
1140        let base = pass_base!(pass, scope);
1141
1142        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1143            pass_try!(
1144                base,
1145                scope,
1146                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
1147            );
1148        }
1149
1150        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1151            pass_try!(
1152                base,
1153                scope,
1154                Err(ComputePassErrorInner::PushConstantSizeAlignment),
1155            )
1156        }
1157        let value_offset = pass_try!(
1158            base,
1159            scope,
1160            base.push_constant_data
1161                .len()
1162                .try_into()
1163                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
1164        );
1165
1166        base.push_constant_data.extend(
1167            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
1168                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1169        );
1170
1171        base.commands.push(ArcComputeCommand::SetPushConstant {
1172            offset,
1173            size_bytes: data.len() as u32,
1174            values_offset: value_offset,
1175        });
1176
1177        Ok(())
1178    }
1179
1180    pub fn compute_pass_dispatch_workgroups(
1181        &self,
1182        pass: &mut ComputePass,
1183        groups_x: u32,
1184        groups_y: u32,
1185        groups_z: u32,
1186    ) -> Result<(), PassStateError> {
1187        let scope = PassErrorScope::Dispatch { indirect: false };
1188
1189        pass_base!(pass, scope)
1190            .commands
1191            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
1192
1193        Ok(())
1194    }
1195
1196    pub fn compute_pass_dispatch_workgroups_indirect(
1197        &self,
1198        pass: &mut ComputePass,
1199        buffer_id: id::BufferId,
1200        offset: BufferAddress,
1201    ) -> Result<(), PassStateError> {
1202        let hub = &self.hub;
1203        let scope = PassErrorScope::Dispatch { indirect: true };
1204        let base = pass_base!(pass, scope);
1205
1206        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1207
1208        base.commands
1209            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });
1210
1211        Ok(())
1212    }
1213
1214    pub fn compute_pass_push_debug_group(
1215        &self,
1216        pass: &mut ComputePass,
1217        label: &str,
1218        color: u32,
1219    ) -> Result<(), PassStateError> {
1220        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1221
1222        let bytes = label.as_bytes();
1223        base.string_data.extend_from_slice(bytes);
1224
1225        base.commands.push(ArcComputeCommand::PushDebugGroup {
1226            color,
1227            len: bytes.len(),
1228        });
1229
1230        Ok(())
1231    }
1232
1233    pub fn compute_pass_pop_debug_group(
1234        &self,
1235        pass: &mut ComputePass,
1236    ) -> Result<(), PassStateError> {
1237        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1238
1239        base.commands.push(ArcComputeCommand::PopDebugGroup);
1240
1241        Ok(())
1242    }
1243
1244    pub fn compute_pass_insert_debug_marker(
1245        &self,
1246        pass: &mut ComputePass,
1247        label: &str,
1248        color: u32,
1249    ) -> Result<(), PassStateError> {
1250        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1251
1252        let bytes = label.as_bytes();
1253        base.string_data.extend_from_slice(bytes);
1254
1255        base.commands.push(ArcComputeCommand::InsertDebugMarker {
1256            color,
1257            len: bytes.len(),
1258        });
1259
1260        Ok(())
1261    }
1262
1263    pub fn compute_pass_write_timestamp(
1264        &self,
1265        pass: &mut ComputePass,
1266        query_set_id: id::QuerySetId,
1267        query_index: u32,
1268    ) -> Result<(), PassStateError> {
1269        let scope = PassErrorScope::WriteTimestamp;
1270        let base = pass_base!(pass, scope);
1271
1272        let hub = &self.hub;
1273        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1274
1275        base.commands.push(ArcComputeCommand::WriteTimestamp {
1276            query_set,
1277            query_index,
1278        });
1279
1280        Ok(())
1281    }
1282
1283    pub fn compute_pass_begin_pipeline_statistics_query(
1284        &self,
1285        pass: &mut ComputePass,
1286        query_set_id: id::QuerySetId,
1287        query_index: u32,
1288    ) -> Result<(), PassStateError> {
1289        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1290        let base = pass_base!(pass, scope);
1291
1292        let hub = &self.hub;
1293        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1294
1295        base.commands
1296            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1297                query_set,
1298                query_index,
1299            });
1300
1301        Ok(())
1302    }
1303
1304    pub fn compute_pass_end_pipeline_statistics_query(
1305        &self,
1306        pass: &mut ComputePass,
1307    ) -> Result<(), PassStateError> {
1308        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1309            .commands
1310            .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1311
1312        Ok(())
1313    }
1314}