wgpu_core/command/
bind.rs

1use std::sync::Arc;
2
3use crate::{
4    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
5    device::SHADER_STAGE_COUNT,
6    pipeline::LateSizedBufferGroup,
7    resource::{Labeled, ResourceErrorIdent},
8};
9
10use arrayvec::ArrayVec;
11use thiserror::Error;
12
13mod compat {
14    use arrayvec::ArrayVec;
15    use thiserror::Error;
16    use wgt::{BindingType, ShaderStages};
17
18    use crate::{
19        binding_model::BindGroupLayout,
20        error::MultiError,
21        resource::{Labeled, ParentDevice, ResourceErrorIdent},
22    };
23    use std::{
24        num::NonZeroU32,
25        ops::Range,
26        sync::{Arc, Weak},
27    };
28
29    pub(crate) enum Error {
30        Incompatible {
31            expected_bgl: ResourceErrorIdent,
32            assigned_bgl: ResourceErrorIdent,
33            inner: MultiError,
34        },
35        Missing,
36    }
37
38    #[derive(Debug, Clone)]
39    struct Entry {
40        assigned: Option<Arc<BindGroupLayout>>,
41        expected: Option<Arc<BindGroupLayout>>,
42    }
43
44    impl Entry {
45        fn empty() -> Self {
46            Self {
47                assigned: None,
48                expected: None,
49            }
50        }
51        fn is_active(&self) -> bool {
52            self.assigned.is_some() && self.expected.is_some()
53        }
54
55        fn is_valid(&self) -> bool {
56            if let Some(expected_bgl) = self.expected.as_ref() {
57                if let Some(assigned_bgl) = self.assigned.as_ref() {
58                    expected_bgl.is_equal(assigned_bgl)
59                } else {
60                    false
61                }
62            } else {
63                true
64            }
65        }
66
67        fn is_incompatible(&self) -> bool {
68            self.expected.is_none() || !self.is_valid()
69        }
70
71        fn check(&self) -> Result<(), Error> {
72            if let Some(expected_bgl) = self.expected.as_ref() {
73                if let Some(assigned_bgl) = self.assigned.as_ref() {
74                    if expected_bgl.is_equal(assigned_bgl) {
75                        Ok(())
76                    } else {
77                        #[derive(Clone, Debug, Error)]
78                        #[error(
79                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
80                        )]
81                        struct IncompatibleExclusivePipelines {
82                            expected: String,
83                            assigned: String,
84                        }
85
86                        use crate::binding_model::ExclusivePipeline;
87                        match (
88                            expected_bgl.exclusive_pipeline.get().unwrap(),
89                            assigned_bgl.exclusive_pipeline.get().unwrap(),
90                        ) {
91                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
92                            (
93                                ExclusivePipeline::Render(e_pipeline),
94                                ExclusivePipeline::Render(a_pipeline),
95                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
96                            (
97                                ExclusivePipeline::Compute(e_pipeline),
98                                ExclusivePipeline::Compute(a_pipeline),
99                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100                            (expected, assigned) => {
101                                return Err(Error::Incompatible {
102                                    expected_bgl: expected_bgl.error_ident(),
103                                    assigned_bgl: assigned_bgl.error_ident(),
104                                    inner: MultiError::new(core::iter::once(
105                                        IncompatibleExclusivePipelines {
106                                            expected: expected.to_string(),
107                                            assigned: assigned.to_string(),
108                                        },
109                                    ))
110                                    .unwrap(),
111                                });
112                            }
113                        }
114
115                        #[derive(Clone, Debug, Error)]
116                        enum EntryError {
117                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
118                            Visibility {
119                                binding: u32,
120                                expected: ShaderStages,
121                                assigned: ShaderStages,
122                            },
123                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
124                            Type {
125                                binding: u32,
126                                expected: BindingType,
127                                assigned: BindingType,
128                            },
129                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
130                            Count {
131                                binding: u32,
132                                expected: Option<NonZeroU32>,
133                                assigned: Option<NonZeroU32>,
134                            },
135                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
136                            ExtraExpected { binding: u32 },
137                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
138                            ExtraAssigned { binding: u32 },
139                        }
140
141                        let mut errors = Vec::new();
142
143                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
144                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
145                                if assigned_entry.visibility != expected_entry.visibility {
146                                    errors.push(EntryError::Visibility {
147                                        binding,
148                                        expected: expected_entry.visibility,
149                                        assigned: assigned_entry.visibility,
150                                    });
151                                }
152                                if assigned_entry.ty != expected_entry.ty {
153                                    errors.push(EntryError::Type {
154                                        binding,
155                                        expected: expected_entry.ty,
156                                        assigned: assigned_entry.ty,
157                                    });
158                                }
159                                if assigned_entry.count != expected_entry.count {
160                                    errors.push(EntryError::Count {
161                                        binding,
162                                        expected: expected_entry.count,
163                                        assigned: assigned_entry.count,
164                                    });
165                                }
166                            } else {
167                                errors.push(EntryError::ExtraExpected { binding });
168                            }
169                        }
170
171                        for (&binding, _) in assigned_bgl.entries.iter() {
172                            if !expected_bgl.entries.contains_key(binding) {
173                                errors.push(EntryError::ExtraAssigned { binding });
174                            }
175                        }
176
177                        Err(Error::Incompatible {
178                            expected_bgl: expected_bgl.error_ident(),
179                            assigned_bgl: assigned_bgl.error_ident(),
180                            inner: MultiError::new(errors.drain(..)).unwrap(),
181                        })
182                    }
183                } else {
184                    Err(Error::Missing)
185                }
186            } else {
187                Ok(())
188            }
189        }
190    }
191
192    #[derive(Debug, Default)]
193    pub(crate) struct BoundBindGroupLayouts {
194        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
195    }
196
197    impl BoundBindGroupLayouts {
198        pub fn new() -> Self {
199            Self {
200                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
201            }
202        }
203
204        pub fn num_valid_entries(&self) -> usize {
205            // find first incompatible entry
206            self.entries
207                .iter()
208                .position(|e| e.is_incompatible())
209                .unwrap_or(self.entries.len())
210        }
211
212        fn make_range(&self, start_index: usize) -> Range<usize> {
213            let end = self.num_valid_entries();
214            start_index..end.max(start_index)
215        }
216
217        pub fn update_expectations(
218            &mut self,
219            expectations: &[Arc<BindGroupLayout>],
220        ) -> Range<usize> {
221            let start_index = self
222                .entries
223                .iter()
224                .zip(expectations)
225                .position(|(e, expect)| {
226                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
227                })
228                .unwrap_or(expectations.len());
229            for (e, expect) in self.entries[start_index..]
230                .iter_mut()
231                .zip(expectations[start_index..].iter())
232            {
233                e.expected = Some(expect.clone());
234            }
235            for e in self.entries[expectations.len()..].iter_mut() {
236                e.expected = None;
237            }
238            self.make_range(start_index)
239        }
240
241        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) -> Range<usize> {
242            self.entries[index].assigned = Some(value);
243            self.make_range(index)
244        }
245
246        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
247            self.entries
248                .iter()
249                .enumerate()
250                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
251        }
252
253        #[allow(clippy::result_large_err)]
254        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
255            for (index, entry) in self.entries.iter().enumerate() {
256                entry.check().map_err(|e| (index, e))?;
257            }
258            Ok(())
259        }
260    }
261}
262
/// Errors reported by `Binder::check_compatibility` when the bind groups that are
/// currently set do not satisfy the layouts expected by the pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// A layout is expected at `index`, but no bind group has been set there.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group set at `index` has a layout that is not the one `pipeline` expects.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        /// The individual differences detected between the two layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
281
/// Size bookkeeping for one buffer binding whose layout has no `min_binding_size`:
/// validation fails when `bound_size < shader_expect_size`.
#[derive(Debug)]
struct LateBufferBinding {
    /// Minimum size required by the current pipeline's shaders
    /// (0 until a pipeline layout has supplied it).
    shader_expect_size: wgt::BufferAddress,
    /// Size of the buffer range actually bound
    /// (0 until a bind group has supplied it).
    bound_size: wgt::BufferAddress,
}
287
288#[derive(Debug, Default)]
289pub(super) struct EntryPayload {
290    pub(super) group: Option<Arc<BindGroup>>,
291    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
292    late_buffer_bindings: Vec<LateBufferBinding>,
293    /// Since `LateBufferBinding` may contain information about the bindings
294    /// not used by the pipeline, we need to know when to stop validating.
295    pub(super) late_bindings_effective_count: usize,
296}
297
298impl EntryPayload {
299    fn reset(&mut self) {
300        self.group = None;
301        self.dynamic_offsets.clear();
302        self.late_buffer_bindings.clear();
303        self.late_bindings_effective_count = 0;
304    }
305}
306
307#[derive(Debug, Default)]
308pub(super) struct Binder {
309    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
310    manager: compat::BoundBindGroupLayouts,
311    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
312}
313
314impl Binder {
315    pub(super) fn new() -> Self {
316        Self {
317            pipeline_layout: None,
318            manager: compat::BoundBindGroupLayouts::new(),
319            payloads: Default::default(),
320        }
321    }
322    pub(super) fn reset(&mut self) {
323        self.pipeline_layout = None;
324        self.manager = compat::BoundBindGroupLayouts::new();
325        for payload in self.payloads.iter_mut() {
326            payload.reset();
327        }
328    }
329
330    pub(super) fn change_pipeline_layout<'a>(
331        &'a mut self,
332        new: &Arc<PipelineLayout>,
333        late_sized_buffer_groups: &[LateSizedBufferGroup],
334    ) -> (usize, &'a [EntryPayload]) {
335        let old_id_opt = self.pipeline_layout.replace(new.clone());
336
337        let mut bind_range = self.manager.update_expectations(&new.bind_group_layouts);
338
339        // Update the buffer binding sizes that are required by shaders.
340        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
341            payload.late_bindings_effective_count = late_group.shader_sizes.len();
342            for (late_binding, &shader_expect_size) in payload
343                .late_buffer_bindings
344                .iter_mut()
345                .zip(late_group.shader_sizes.iter())
346            {
347                late_binding.shader_expect_size = shader_expect_size;
348            }
349            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
350                for &shader_expect_size in
351                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
352                {
353                    payload.late_buffer_bindings.push(LateBufferBinding {
354                        shader_expect_size,
355                        bound_size: 0,
356                    });
357                }
358            }
359        }
360
361        if let Some(old) = old_id_opt {
362            // root constants are the base compatibility property
363            if old.push_constant_ranges != new.push_constant_ranges {
364                bind_range.start = 0;
365            }
366        }
367
368        (bind_range.start, &self.payloads[bind_range])
369    }
370
371    pub(super) fn assign_group<'a>(
372        &'a mut self,
373        index: usize,
374        bind_group: &Arc<BindGroup>,
375        offsets: &[wgt::DynamicOffset],
376    ) -> &'a [EntryPayload] {
377        let payload = &mut self.payloads[index];
378        payload.group = Some(bind_group.clone());
379        payload.dynamic_offsets.clear();
380        payload.dynamic_offsets.extend_from_slice(offsets);
381
382        // Fill out the actual binding sizes for buffers,
383        // whose layout doesn't specify `min_binding_size`.
384        for (late_binding, late_size) in payload
385            .late_buffer_bindings
386            .iter_mut()
387            .zip(bind_group.late_buffer_binding_sizes.iter())
388        {
389            late_binding.bound_size = late_size.get();
390        }
391        if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
392            for late_size in
393                bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
394            {
395                payload.late_buffer_bindings.push(LateBufferBinding {
396                    shader_expect_size: 0,
397                    bound_size: late_size.get(),
398                });
399            }
400        }
401
402        let bind_range = self.manager.assign(index, bind_group.layout.clone());
403        &self.payloads[bind_range]
404    }
405
406    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
407        let payloads = &self.payloads;
408        self.manager
409            .list_active()
410            .map(move |index| payloads[index].group.as_ref().unwrap())
411    }
412
413    #[cfg(feature = "indirect-validation")]
414    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
415        self.payloads
416            .iter()
417            .take(self.manager.num_valid_entries())
418            .enumerate()
419    }
420
421    pub(super) fn check_compatibility<T: Labeled>(
422        &self,
423        pipeline: &T,
424    ) -> Result<(), Box<BinderError>> {
425        self.manager.get_invalid().map_err(|(index, error)| {
426            Box::new(match error {
427                compat::Error::Incompatible {
428                    expected_bgl,
429                    assigned_bgl,
430                    inner,
431                } => BinderError::IncompatibleBindGroup {
432                    expected_bgl,
433                    assigned_bgl,
434                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
435                    index,
436                    pipeline: pipeline.error_ident(),
437                    inner,
438                },
439                compat::Error::Missing => BinderError::MissingBindGroup {
440                    index,
441                    pipeline: pipeline.error_ident(),
442                },
443            })
444        })
445    }
446
447    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
448    pub(super) fn check_late_buffer_bindings(
449        &self,
450    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
451        for group_index in self.manager.list_active() {
452            let payload = &self.payloads[group_index];
453            for (compact_index, late_binding) in payload.late_buffer_bindings
454                [..payload.late_bindings_effective_count]
455                .iter()
456                .enumerate()
457            {
458                if late_binding.bound_size < late_binding.shader_expect_size {
459                    return Err(LateMinBufferBindingSizeMismatch {
460                        group_index: group_index as u32,
461                        compact_index,
462                        shader_size: late_binding.shader_expect_size,
463                        bound_size: late_binding.bound_size,
464                    });
465                }
466            }
467        }
468        Ok(())
469    }
470}
471
472struct PushConstantChange {
473    stages: wgt::ShaderStages,
474    offset: u32,
475    enable: bool,
476}
477
478/// Break up possibly overlapping push constant ranges into a set of
479/// non-overlapping ranges which contain all the stage flags of the
480/// original ranges. This allows us to zero out (or write any value)
481/// to every possible value.
482pub fn compute_nonoverlapping_ranges(
483    ranges: &[wgt::PushConstantRange],
484) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
485    if ranges.is_empty() {
486        return ArrayVec::new();
487    }
488    debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
489
490    let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
491    for range in ranges {
492        breaks.push(PushConstantChange {
493            stages: range.stages,
494            offset: range.range.start,
495            enable: true,
496        });
497        breaks.push(PushConstantChange {
498            stages: range.stages,
499            offset: range.range.end,
500            enable: false,
501        });
502    }
503    breaks.sort_unstable_by_key(|change| change.offset);
504
505    let mut output_ranges = ArrayVec::new();
506    let mut position = 0_u32;
507    let mut stages = wgt::ShaderStages::NONE;
508
509    for bk in breaks {
510        if bk.offset - position > 0 && !stages.is_empty() {
511            output_ranges.push(wgt::PushConstantRange {
512                stages,
513                range: position..bk.offset,
514            })
515        }
516        position = bk.offset;
517        stages.set(bk.stages, bk.enable);
518    }
519
520    output_ranges
521}