use std::sync::Arc;

use crate::{
    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
    device::SHADER_STAGE_COUNT,
    pipeline::LateSizedBufferGroup,
    resource::{Labeled, ResourceErrorIdent},
};

use arrayvec::ArrayVec;
use thiserror::Error;

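/// Bind group compatibility tracking: for each bind group index, compares the
/// bind group layout the current pipeline expects against the layout of the
/// bind group actually assigned, and reports detailed errors on mismatch.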
mod compat {
    use arrayvec::ArrayVec;
    use thiserror::Error;
    use wgt::{BindingType, ShaderStages};

    use crate::{
        binding_model::BindGroupLayout,
        error::MultiError,
        resource::{Labeled, ParentDevice, ResourceErrorIdent},
    };
    use std::{
        num::NonZeroU32,
        ops::Range,
        sync::{Arc, Weak},
    };

    pub(crate) enum Error {
        Incompatible {
            expected_bgl: ResourceErrorIdent,
            assigned_bgl: ResourceErrorIdent,
            inner: MultiError,
        },
        Missing,
    }

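    /// A single bind group slot: the layout the current pipeline expects at
    /// this index, and the layout of the bind group actually assigned to it.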
    #[derive(Debug, Clone)]
    struct Entry {
        assigned: Option<Arc<BindGroupLayout>>,
        expected: Option<Arc<BindGroupLayout>>,
    }

    impl Entry {
        fn empty() -> Self {
            Self {
                assigned: None,
                expected: None,
            }
        }
        fn is_active(&self) -> bool {
            self.assigned.is_some() && self.expected.is_some()
        }

        fn is_valid(&self) -> bool {
            if let Some(expected_bgl) = self.expected.as_ref() {
                if let Some(assigned_bgl) = self.assigned.as_ref() {
                    expected_bgl.is_equal(assigned_bgl)
                } else {
                    false
                }
            } else {
                true
            }
        }

        fn is_incompatible(&self) -> bool {
            self.expected.is_none() || !self.is_valid()
        }

        fn check(&self) -> Result<(), Error> {
            if let Some(expected_bgl) = self.expected.as_ref() {
                if let Some(assigned_bgl) = self.assigned.as_ref() {
                    if expected_bgl.is_equal(assigned_bgl) {
                        Ok(())
                    } else {
                        #[derive(Clone, Debug, Error)]
                        #[error(
                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
                        )]
                        struct IncompatibleExclusivePipelines {
                            expected: String,
                            assigned: String,
                        }

                        use crate::binding_model::ExclusivePipeline;
                        match (
                            expected_bgl.exclusive_pipeline.get().unwrap(),
                            assigned_bgl.exclusive_pipeline.get().unwrap(),
                        ) {
                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
                            (
                                ExclusivePipeline::Render(e_pipeline),
                                ExclusivePipeline::Render(a_pipeline),
                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
                            (
                                ExclusivePipeline::Compute(e_pipeline),
                                ExclusivePipeline::Compute(a_pipeline),
                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
                            (expected, assigned) => {
                                return Err(Error::Incompatible {
                                    expected_bgl: expected_bgl.error_ident(),
                                    assigned_bgl: assigned_bgl.error_ident(),
                                    inner: MultiError::new(core::iter::once(
                                        IncompatibleExclusivePipelines {
                                            expected: expected.to_string(),
                                            assigned: assigned.to_string(),
                                        },
                                    ))
                                    .unwrap(),
                                });
                            }
                        }

                        #[derive(Clone, Debug, Error)]
                        enum EntryError {
                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
                            Visibility {
                                binding: u32,
                                expected: ShaderStages,
                                assigned: ShaderStages,
                            },
                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
                            Type {
                                binding: u32,
                                expected: BindingType,
                                assigned: BindingType,
                            },
                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
                            Count {
                                binding: u32,
                                expected: Option<NonZeroU32>,
                                assigned: Option<NonZeroU32>,
                            },
                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
                            ExtraExpected { binding: u32 },
                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
                            ExtraAssigned { binding: u32 },
                        }

                        let mut errors = Vec::new();

                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
                                if assigned_entry.visibility != expected_entry.visibility {
                                    errors.push(EntryError::Visibility {
                                        binding,
                                        expected: expected_entry.visibility,
                                        assigned: assigned_entry.visibility,
                                    });
                                }
                                if assigned_entry.ty != expected_entry.ty {
                                    errors.push(EntryError::Type {
                                        binding,
                                        expected: expected_entry.ty,
                                        assigned: assigned_entry.ty,
                                    });
                                }
                                if assigned_entry.count != expected_entry.count {
                                    errors.push(EntryError::Count {
                                        binding,
                                        expected: expected_entry.count,
                                        assigned: assigned_entry.count,
                                    });
                                }
                            } else {
                                errors.push(EntryError::ExtraExpected { binding });
                            }
                        }

                        for (&binding, _) in assigned_bgl.entries.iter() {
                            if !expected_bgl.entries.contains_key(binding) {
                                errors.push(EntryError::ExtraAssigned { binding });
                            }
                        }

                        Err(Error::Incompatible {
                            expected_bgl: expected_bgl.error_ident(),
                            assigned_bgl: assigned_bgl.error_ident(),
                            inner: MultiError::new(errors.drain(..)).unwrap(),
                        })
                    }
                } else {
                    Err(Error::Missing)
                }
            } else {
                Ok(())
            }
        }
    }

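    /// Expected and assigned bind group layouts for every bind group slot the
    /// backend supports (`hal::MAX_BIND_GROUPS`).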
    #[derive(Debug, Default)]
    pub(crate) struct BoundBindGroupLayouts {
        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
    }

    impl BoundBindGroupLayouts {
        pub fn new() -> Self {
            Self {
                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
            }
        }

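        /// Returns the length of the leading run of compatible entries, i.e.
        /// the index of the first incompatible entry (or the total number of
        /// entries if none is incompatible).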
        pub fn num_valid_entries(&self) -> usize {
            self.entries
                .iter()
                .position(|e| e.is_incompatible())
                .unwrap_or(self.entries.len())
        }

        fn make_range(&self, start_index: usize) -> Range<usize> {
            let end = self.num_valid_entries();
            start_index..end.max(start_index)
        }

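        /// Replaces the expected layouts with those of a new pipeline layout.
        /// Returns the range starting at the first index whose expectation
        /// changed and ending at the first incompatible entry.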
        pub fn update_expectations(
            &mut self,
            expectations: &[Arc<BindGroupLayout>],
        ) -> Range<usize> {
            let start_index = self
                .entries
                .iter()
                .zip(expectations)
                .position(|(e, expect)| {
                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
                })
                .unwrap_or(expectations.len());
            for (e, expect) in self.entries[start_index..]
                .iter_mut()
                .zip(expectations[start_index..].iter())
            {
                e.expected = Some(expect.clone());
            }
            for e in self.entries[expectations.len()..].iter_mut() {
                e.expected = None;
            }
            self.make_range(start_index)
        }

        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) -> Range<usize> {
            self.entries[index].assigned = Some(value);
            self.make_range(index)
        }

        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
            self.entries
                .iter()
                .enumerate()
                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
        }

        #[allow(clippy::result_large_err)]
        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
            for (index, entry) in self.entries.iter().enumerate() {
                entry.check().map_err(|e| (index, e))?;
            }
            Ok(())
        }
    }
}

#[derive(Clone, Debug, Error)]
pub enum BinderError {
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        #[source]
        inner: crate::error::MultiError,
    },
}

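/// Pairs the minimum size a shader expects for a late-sized buffer binding with
/// the size of the buffer range actually bound for it.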
#[derive(Debug)]
struct LateBufferBinding {
    shader_expect_size: wgt::BufferAddress,
    bound_size: wgt::BufferAddress,
}

#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    pub(super) group: Option<Arc<BindGroup>>,
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    late_buffer_bindings: Vec<LateBufferBinding>,
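    /// Number of entries in `late_buffer_bindings` that the current pipeline
    /// actually uses; only these are validated in `check_late_buffer_bindings`.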
    pub(super) late_bindings_effective_count: usize,
}

impl EntryPayload {
    fn reset(&mut self) {
        self.group = None;
        self.dynamic_offsets.clear();
        self.late_buffer_bindings.clear();
        self.late_bindings_effective_count = 0;
    }
}

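/// Bind group state for a pass: the bound pipeline layout, the per-slot bind
/// groups with their dynamic offsets, and the compatibility tracker used to
/// validate them.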
#[derive(Debug, Default)]
pub(super) struct Binder {
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    manager: compat::BoundBindGroupLayouts,
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}

impl Binder {
    pub(super) fn new() -> Self {
        Self {
            pipeline_layout: None,
            manager: compat::BoundBindGroupLayouts::new(),
            payloads: Default::default(),
        }
    }
    pub(super) fn reset(&mut self) {
        self.pipeline_layout = None;
        self.manager = compat::BoundBindGroupLayouts::new();
        for payload in self.payloads.iter_mut() {
            payload.reset();
        }
    }

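    /// Switches to a new pipeline layout. Returns the first bind group index
    /// affected by the change, together with the payloads from that index up
    /// to the first incompatible entry.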
    pub(super) fn change_pipeline_layout<'a>(
        &'a mut self,
        new: &Arc<PipelineLayout>,
        late_sized_buffer_groups: &[LateSizedBufferGroup],
    ) -> (usize, &'a [EntryPayload]) {
        let old_id_opt = self.pipeline_layout.replace(new.clone());

        let mut bind_range = self.manager.update_expectations(&new.bind_group_layouts);

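        // Bring the late-sized buffer bookkeeping in line with the new
        // pipeline: record how many late bindings each group's shaders use and
        // the minimum sizes they expect, growing the per-group lists as needed.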
        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
            payload.late_bindings_effective_count = late_group.shader_sizes.len();
            for (late_binding, &shader_expect_size) in payload
                .late_buffer_bindings
                .iter_mut()
                .zip(late_group.shader_sizes.iter())
            {
                late_binding.shader_expect_size = shader_expect_size;
            }
            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
                for &shader_expect_size in
                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
                {
                    payload.late_buffer_bindings.push(LateBufferBinding {
                        shader_expect_size,
                        bound_size: 0,
                    });
                }
            }
        }

        if let Some(old) = old_id_opt {
            if old.push_constant_ranges != new.push_constant_ranges {
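                // The push constant ranges differ from the old layout, so the
                // returned rebind range has to start at index 0.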
                bind_range.start = 0;
            }
        }

        (bind_range.start, &self.payloads[bind_range])
    }

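    /// Binds `bind_group` at `index` with the given dynamic offsets. Returns
    /// the payloads from `index` up to the first incompatible entry.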
    pub(super) fn assign_group<'a>(
        &'a mut self,
        index: usize,
        bind_group: &Arc<BindGroup>,
        offsets: &[wgt::DynamicOffset],
    ) -> &'a [EntryPayload] {
        let payload = &mut self.payloads[index];
        payload.group = Some(bind_group.clone());
        payload.dynamic_offsets.clear();
        payload.dynamic_offsets.extend_from_slice(offsets);

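        // Record the actual bound sizes of this bind group's late-sized buffer
        // bindings, extending the list if the group has more of them than the
        // slot has seen before; they are validated in
        // `check_late_buffer_bindings`.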
        for (late_binding, late_size) in payload
            .late_buffer_bindings
            .iter_mut()
            .zip(bind_group.late_buffer_binding_sizes.iter())
        {
            late_binding.bound_size = late_size.get();
        }
        if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
            for late_size in
                bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
            {
                payload.late_buffer_bindings.push(LateBufferBinding {
                    shader_expect_size: 0,
                    bound_size: late_size.get(),
                });
            }
        }

        let bind_range = self.manager.assign(index, bind_group.layout.clone());
        &self.payloads[bind_range]
    }

    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
        let payloads = &self.payloads;
        self.manager
            .list_active()
            .map(move |index| payloads[index].group.as_ref().unwrap())
    }

    #[cfg(feature = "indirect-validation")]
    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
        self.payloads
            .iter()
            .take(self.manager.num_valid_entries())
            .enumerate()
    }

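    /// Checks that every bind group slot holds a bind group whose layout is
    /// compatible with what `pipeline` expects, describing the first mismatch
    /// in the returned `BinderError`.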
    pub(super) fn check_compatibility<T: Labeled>(
        &self,
        pipeline: &T,
    ) -> Result<(), Box<BinderError>> {
        self.manager.get_invalid().map_err(|(index, error)| {
            Box::new(match error {
                compat::Error::Incompatible {
                    expected_bgl,
                    assigned_bgl,
                    inner,
                } => BinderError::IncompatibleBindGroup {
                    expected_bgl,
                    assigned_bgl,
                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
                    index,
                    pipeline: pipeline.error_ident(),
                    inner,
                },
                compat::Error::Missing => BinderError::MissingBindGroup {
                    index,
                    pipeline: pipeline.error_ident(),
                },
            })
        })
    }

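    /// Verifies that every late-sized buffer binding used by the current
    /// pipeline is at least as large as the size its shader requires.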
    pub(super) fn check_late_buffer_bindings(
        &self,
    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
        for group_index in self.manager.list_active() {
            let payload = &self.payloads[group_index];
            for (compact_index, late_binding) in payload.late_buffer_bindings
                [..payload.late_bindings_effective_count]
                .iter()
                .enumerate()
            {
                if late_binding.bound_size < late_binding.shader_expect_size {
                    return Err(LateMinBufferBindingSizeMismatch {
                        group_index: group_index as u32,
                        compact_index,
                        shader_size: late_binding.shader_expect_size,
                        bound_size: late_binding.bound_size,
                    });
                }
            }
        }
        Ok(())
    }
}

struct PushConstantChange {
    stages: wgt::ShaderStages,
    offset: u32,
    enable: bool,
}

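/// Breaks a set of possibly overlapping push constant ranges into
/// non-overlapping ranges, each carrying the union of the stage flags that are
/// active over its span of offsets.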
pub fn compute_nonoverlapping_ranges(
    ranges: &[wgt::PushConstantRange],
) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
    if ranges.is_empty() {
        return ArrayVec::new();
    }
    debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);

    let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
    for range in ranges {
        breaks.push(PushConstantChange {
            stages: range.stages,
            offset: range.range.start,
            enable: true,
        });
        breaks.push(PushConstantChange {
            stages: range.stages,
            offset: range.range.end,
            enable: false,
        });
    }
    breaks.sort_unstable_by_key(|change| change.offset);

    let mut output_ranges = ArrayVec::new();
    let mut position = 0_u32;
    let mut stages = wgt::ShaderStages::NONE;

    for bk in breaks {
        if bk.offset - position > 0 && !stages.is_empty() {
            output_ranges.push(wgt::PushConstantRange {
                stages,
                range: position..bk.offset,
            })
        }
        position = bk.offset;
        stages.set(bk.stages, bk.enable);
    }

    output_ranges
}